Commit 922040d

fix prior scores standardize
1 parent 93b59ba commit 922040d

File tree

3 files changed: +58 −84 lines

bayesflow/approximators/continuous_approximator.py

Lines changed: 37 additions & 11 deletions
@@ -14,7 +14,6 @@
1414
squeeze_inner_estimates_dict,
1515
concatenate_valid,
1616
concatenate_valid_shapes,
17-
expand_right_as,
1817
)
1918
from bayesflow.utils.serialization import serialize, deserialize, serializable
2019

@@ -695,25 +694,52 @@ def compositional_sample(
         # Prepare prior scores to handle adapter
         def compute_prior_score_pre(_samples: Tensor) -> Tensor:
             if "inference_variables" in self.standardize:
-                _samples, log_det_jac_standardize = self.standardize_layers["inference_variables"](
-                    _samples, forward=False, log_det_jac=True
-                )
-            else:
-                log_det_jac_standardize = keras.ops.cast(0.0, dtype="float32")
+                _samples = self.standardize_layers["inference_variables"](_samples, forward=False)
             _samples = keras.tree.map_structure(keras.ops.convert_to_numpy, {"inference_variables": _samples})
             adapted_samples, log_det_jac = self.adapter(
                 _samples, inverse=True, strict=False, log_det_jac=True, **kwargs
             )
+
+            if len(log_det_jac) > 0:
+                problematic_keys = [key for key in log_det_jac if log_det_jac[key] != 0.0]
+                raise NotImplementedError(
+                    f"Cannot use compositional sampling with adapters "
+                    f"that have non-zero log_det_jac. Problematic keys: {problematic_keys}"
+                )
+
             prior_score = compute_prior_score(adapted_samples)
             for key in adapted_samples:
-                if isinstance(prior_score[key], np.ndarray):
-                    prior_score[key] = prior_score[key].astype("float32")
-                if len(log_det_jac) > 0 and key in log_det_jac:
-                    prior_score[key] -= expand_right_as(log_det_jac[key], prior_score[key])
+                prior_score[key] = prior_score[key].astype(np.float32)
 
             prior_score = keras.tree.map_structure(keras.ops.convert_to_tensor, prior_score)
             out = keras.ops.concatenate([prior_score[key] for key in adapted_samples], axis=-1)
-            return out - keras.ops.expand_dims(log_det_jac_standardize, axis=-1)
+
+            if "inference_variables" in self.standardize:
+                # Apply jacobian correction from standardization
+                # For standardization T^{-1}(z) = z * std + mean, the jacobian is diagonal with std on diagonal
+                # The gradient of log|det(J)| w.r.t. z is 0 since log|det(J)| = sum(log(std)) is constant w.r.t. z
+                # But we need to transform the score: score_z = score_x * std where x = T^{-1}(z)
+                standardize_layer = self.standardize_layers["inference_variables"]
+
+                # Compute the correct standard deviation for all components
+                std_components = []
+                for idx in range(len(standardize_layer.moving_mean)):
+                    std_val = standardize_layer.moving_std(idx)
+                    std_components.append(std_val)
+
+                # Concatenate std components to match the shape of out
+                if len(std_components) == 1:
+                    std = std_components[0]
+                else:
+                    std = keras.ops.concatenate(std_components, axis=-1)
+
+                # Expand std to match batch dimension of out
+                std_expanded = keras.ops.expand_dims(std, (0, 1))  # Add batch, sample dimensions
+                std_expanded = keras.ops.tile(std_expanded, [n_datasets, num_samples, 1])
+
+                # Apply the jacobian: score_z = score_x * std
+                out = out * std_expanded
+            return out
 
         # Test prior score function, useful for debugging
         test = self.inference_network.base_distribution.sample((n_datasets, num_samples))
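The comment block in the new code states the key identity: for the de-standardization x = z * std + mean, log|det(J)| = sum(log(std)) is constant in z, so the score picks up only the chain-rule factor std. Below is a minimal NumPy sketch verifying that identity with finite differences, independent of BayesFlow; the Gaussian prior and all names here are illustrative assumptions, not part of the commit.

import numpy as np

rng = np.random.default_rng(0)
mean, std = 1.5, 2.0    # standardization statistics: x = z * std + mean
mu0, sigma0 = 0.3, 0.7  # arbitrary Gaussian prior over x (an assumption for this check)

def log_p_x(x):
    return -0.5 * ((x - mu0) / sigma0) ** 2 - np.log(sigma0 * np.sqrt(2.0 * np.pi))

def score_x(x):
    return -(x - mu0) / sigma0 ** 2

z = rng.normal(size=5)
x = z * std + mean

# log p_z(z) = log p_x(z * std + mean) + log(std); the additive log(std) term
# is constant in z, so d/dz log p_z(z) = std * score_x(x).
eps = 1e-6
finite_diff = (log_p_x((z + eps) * std + mean) - log_p_x((z - eps) * std + mean)) / (2 * eps)
np.testing.assert_allclose(finite_diff, std * score_x(x), rtol=1e-5, atol=1e-8)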

tests/test_approximators/conftest.py

Lines changed: 15 additions & 0 deletions
@@ -249,6 +249,21 @@ def sample(self, batch_shape: Shape) -> dict[str, Tensor]:
     return SimpleSimulator()
 
 
+@pytest.fixture
+def identity_adapter():
+    """Create an adapter that applies no transformation to the parameters."""
+    from bayesflow.adapters import Adapter
+
+    adapter = Adapter()
+    adapter.to_array()
+    adapter.convert_dtype("float64", "float32")
+
+    adapter.concatenate(["loc"], into="inference_variables")
+    adapter.concatenate(["conditions"], into="inference_conditions")
+    adapter.keep(["inference_variables", "inference_conditions"])
+    return adapter
+
+
 @pytest.fixture
 def transforming_adapter():
     """Create an adapter that applies log transformation to scale parameters."""
Lines changed: 6 additions & 73 deletions
@@ -1,110 +1,43 @@
11
"""Tests for compositional sampling and prior score computation with adapters."""
22

33
import numpy as np
4-
import keras
54

65
from bayesflow import ContinuousApproximator
7-
from bayesflow.utils import expand_right_as
86

97

108
def mock_prior_score_original_space(data_dict):
11-
"""Mock prior score function that expects data in original (loc, scale) space."""
12-
# The function receives data in the same format the compute_prior_score_pre creates
13-
# after running the inverse adapter
9+
"""Mock prior score function that expects data in original space."""
1410
loc = data_dict["loc"]
15-
scale = data_dict["scale"]
1611

17-
# Simple prior: N(0,1) for loc, LogNormal(0,0.5) for scale
12+
# Simple prior: N(0,1) for loc
1813
loc_score = -loc
19-
scale_score = -1.0 / scale - np.log(scale) / (0.25 * scale)
14+
return {"loc": loc_score}
2015

21-
return {"loc": loc_score, "scale": scale_score}
2216

23-
24-
def test_prior_score_transforming_adapter(simple_log_simulator, transforming_adapter, diffusion_network):
17+
def test_prior_score_identity_adapter(simple_log_simulator, identity_adapter, diffusion_network):
2518
"""Test that prior scores work correctly with transforming adapter (log transformation)."""
2619

2720
# Create approximator with transforming adapter
2821
approximator = ContinuousApproximator(
29-
adapter=transforming_adapter,
22+
adapter=identity_adapter,
3023
inference_network=diffusion_network,
3124
)
3225

3326
# Generate test data and adapt it
3427
data = simple_log_simulator.sample((2,))
35-
adapted_data = transforming_adapter(data)
28+
adapted_data = identity_adapter(data)
3629

3730
# Build approximator
3831
approximator.build_from_data(adapted_data)
3932

4033
# Test compositional sampling
4134
n_datasets, n_compositional = 3, 5
4235
conditions = {"conditions": np.random.normal(0.0, 1.0, (n_datasets, n_compositional, 3)).astype("float32")}
43-
44-
# This should work - the compute_prior_score_pre function should handle the inverse transformation
4536
samples = approximator.compositional_sample(
4637
num_samples=10,
4738
conditions=conditions,
4839
compute_prior_score=mock_prior_score_original_space,
4940
)
5041

5142
assert "loc" in samples
52-
assert "scale" in samples
5343
assert samples["loc"].shape == (n_datasets, 10, 2)
54-
assert samples["scale"].shape == (n_datasets, 10, 2)
55-
56-
57-
def test_prior_score_jacobian_correction(simple_log_simulator, transforming_adapter, diffusion_network):
58-
"""Test that Jacobian correction is applied correctly in compute_prior_score_pre."""
59-
60-
# Create approximator with transforming adapter
61-
approximator = ContinuousApproximator(
62-
adapter=transforming_adapter, inference_network=diffusion_network, standardize=[]
63-
)
64-
65-
# Build with dummy data
66-
dummy_data_dict = simple_log_simulator.sample((1,))
67-
adapted_dummy_data = transforming_adapter(dummy_data_dict)
68-
approximator.build_from_data(adapted_dummy_data)
69-
70-
# Get the internal compute_prior_score_pre function
71-
def get_compute_prior_score_pre():
72-
def compute_prior_score_pre(_samples):
73-
if "inference_variables" in approximator.standardize:
74-
_samples, log_det_jac_standardize = approximator.standardize_layers["inference_variables"](
75-
_samples, forward=False, log_det_jac=True
76-
)
77-
else:
78-
log_det_jac_standardize = keras.ops.cast(0.0, dtype="float32")
79-
80-
_samples = keras.tree.map_structure(keras.ops.convert_to_numpy, {"inference_variables": _samples})
81-
adapted_samples, log_det_jac = approximator.adapter(_samples, inverse=True, strict=False, log_det_jac=True)
82-
83-
prior_score = mock_prior_score_original_space(adapted_samples)
84-
for key in adapted_samples:
85-
if isinstance(prior_score[key], np.ndarray):
86-
prior_score[key] = prior_score[key].astype("float32")
87-
if len(log_det_jac) > 0 and key in log_det_jac:
88-
prior_score[key] -= expand_right_as(log_det_jac[key], prior_score[key])
89-
90-
prior_score = keras.tree.map_structure(keras.ops.convert_to_tensor, prior_score)
91-
out = keras.ops.concatenate([prior_score[key] for key in adapted_samples], axis=-1)
92-
return out - keras.ops.expand_dims(log_det_jac_standardize, axis=-1)
93-
94-
return compute_prior_score_pre
95-
96-
compute_prior_score_pre = get_compute_prior_score_pre()
97-
98-
# Test with a known transformation
99-
y_samples = adapted_dummy_data["inference_variables"]
100-
scores = compute_prior_score_pre(y_samples)
101-
scores_np = keras.ops.convert_to_numpy(scores)[0] # Remove batch dimension
102-
103-
# With Jacobian correction: score_transformed = score_original - log|J|
104-
old_scores = mock_prior_score_original_space(dummy_data_dict)
105-
# order of parameters is flipped due to concatenation in adapter
106-
det_jac_scale = y_samples[0, :2].sum()
107-
expected_scores = np.array([old_scores["scale"][0] - det_jac_scale, old_scores["loc"][0]]).flatten()
108-
109-
# Check that scores are reasonably close
110-
np.testing.assert_allclose(scores_np, expected_scores, rtol=1e-5, atol=1e-6)
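The deleted test pinned the old behavior (subtracting the adapter's log-det from the score via expand_right_as), which this commit replaces with an explicit NotImplementedError for adapters whose inverse has a non-zero log-det Jacobian. A hedged sketch of a test that could pin down the new guard instead, reusing the existing fixtures; the test name and shapes are assumptions.

import numpy as np
import pytest

from bayesflow import ContinuousApproximator

def test_transforming_adapter_raises(simple_log_simulator, transforming_adapter, diffusion_network):
    approximator = ContinuousApproximator(
        adapter=transforming_adapter,
        inference_network=diffusion_network,
    )
    approximator.build_from_data(transforming_adapter(simple_log_simulator.sample((2,))))

    conditions = {"conditions": np.random.normal(0.0, 1.0, (3, 5, 3)).astype("float32")}
    # The guard fires inside compute_prior_score_pre, before the prior score
    # function is ever called, so the mock only needs to exist.
    with pytest.raises(NotImplementedError):
        approximator.compositional_sample(
            num_samples=10,
            conditions=conditions,
            compute_prior_score=mock_prior_score_original_space,
        )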
