Skip to content

Commit 93b59ba

Browse files
committed
fix order of prior scores
1 parent 9a1ba32 commit 93b59ba

File tree

3 files changed

+6
-5
lines changed

3 files changed

+6
-5
lines changed

bayesflow/approximators/continuous_approximator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -712,7 +712,7 @@ def compute_prior_score_pre(_samples: Tensor) -> Tensor:
712712
prior_score[key] -= expand_right_as(log_det_jac[key], prior_score[key])
713713

714714
prior_score = keras.tree.map_structure(keras.ops.convert_to_tensor, prior_score)
715-
out = keras.ops.concatenate(list(prior_score.values()), axis=-1)
715+
out = keras.ops.concatenate([prior_score[key] for key in adapted_samples], axis=-1)
716716
return out - keras.ops.expand_dims(log_det_jac_standardize, axis=-1)
717717

718718
# Test prior score function, useful for debugging

tests/test_approximators/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ def transforming_adapter():
261261
# Apply log transformation to scale parameters (to make them unbounded)
262262
adapter.log(["scale"])
263263

264-
adapter.concatenate(["loc", "scale"], into="inference_variables")
264+
adapter.concatenate(["scale", "loc"], into="inference_variables")
265265
adapter.concatenate(["conditions"], into="inference_conditions")
266266
adapter.keep(["inference_variables", "inference_conditions"])
267267
return adapter

tests/test_approximators/test_compositional_prior_score.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def compute_prior_score_pre(_samples):
8888
prior_score[key] -= expand_right_as(log_det_jac[key], prior_score[key])
8989

9090
prior_score = keras.tree.map_structure(keras.ops.convert_to_tensor, prior_score)
91-
out = keras.ops.concatenate(list(prior_score.values()), axis=-1)
91+
out = keras.ops.concatenate([prior_score[key] for key in adapted_samples], axis=-1)
9292
return out - keras.ops.expand_dims(log_det_jac_standardize, axis=-1)
9393

9494
return compute_prior_score_pre
@@ -102,8 +102,9 @@ def compute_prior_score_pre(_samples):
102102

103103
# With Jacobian correction: score_transformed = score_original - log|J|
104104
old_scores = mock_prior_score_original_space(dummy_data_dict)
105-
det_jac_scale = y_samples[0, 2:].sum()
106-
expected_scores = np.array([old_scores["loc"][0], old_scores["scale"][0] - det_jac_scale]).flatten()
105+
# order of parameters is flipped due to concatenation in adapter
106+
det_jac_scale = y_samples[0, :2].sum()
107+
expected_scores = np.array([old_scores["scale"][0] - det_jac_scale, old_scores["loc"][0]]).flatten()
107108

108109
# Check that scores are reasonably close
109110
np.testing.assert_allclose(scores_np, expected_scores, rtol=1e-5, atol=1e-6)

0 commit comments

Comments
 (0)