Skip to content

Commit ec2a0e2

Browse files
fix issue #328 (#333)
* fix issue #328 * also fix the issue for model_comparison_approximators * minor cleaning * add comment
1 parent 4a33711 commit ec2a0e2

File tree

2 files changed

+4
-0
lines changed

2 files changed

+4
-0
lines changed

bayesflow/approximators/continuous_approximator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,8 @@ def sample(
140140
**kwargs,
141141
) -> dict[str, np.ndarray]:
142142
conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)
143+
# at inference time, inference_variables are estimated by the networks and thus ignored in conditions
144+
conditions.pop("inference_variables", None)
143145
conditions = keras.tree.map_structure(keras.ops.convert_to_tensor, conditions)
144146
conditions = {"inference_variables": self._sample(num_samples=num_samples, **conditions, **kwargs)}
145147
conditions = keras.tree.map_structure(keras.ops.convert_to_numpy, conditions)

bayesflow/approximators/model_comparison_approximator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,8 @@ def predict(
208208
**kwargs,
209209
) -> np.ndarray:
210210
conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)
211+
# at inference time, model_indices are predicted by the networks and thus ignored in conditions
212+
conditions.pop("model_indices", None)
211213
conditions = keras.tree.map_structure(keras.ops.convert_to_tensor, conditions)
212214

213215
output = self._predict(**conditions, **kwargs)

0 commit comments

Comments
 (0)