Skip to content

Commit 9a1ba32

Browse files
committed
add test for compute_prior_score_pre
1 parent 2a9b0e1 commit 9a1ba32

File tree

5 files changed

+214
-54
lines changed

5 files changed

+214
-54
lines changed

bayesflow/approximators/continuous_approximator.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -699,7 +699,7 @@ def compute_prior_score_pre(_samples: Tensor) -> Tensor:
699699
_samples, forward=False, log_det_jac=True
700700
)
701701
else:
702-
log_det_jac_standardize = 0
702+
log_det_jac_standardize = keras.ops.cast(0.0, dtype="float32")
703703
_samples = keras.tree.map_structure(keras.ops.convert_to_numpy, {"inference_variables": _samples})
704704
adapted_samples, log_det_jac = self.adapter(
705705
_samples, inverse=True, strict=False, log_det_jac=True, **kwargs
@@ -708,15 +708,12 @@ def compute_prior_score_pre(_samples: Tensor) -> Tensor:
708708
for key in adapted_samples:
709709
if isinstance(prior_score[key], np.ndarray):
710710
prior_score[key] = prior_score[key].astype("float32")
711-
if len(log_det_jac) > 0:
712-
prior_score[key] += log_det_jac[key]
711+
if len(log_det_jac) > 0 and key in log_det_jac:
712+
prior_score[key] -= expand_right_as(log_det_jac[key], prior_score[key])
713713

714714
prior_score = keras.tree.map_structure(keras.ops.convert_to_tensor, prior_score)
715-
# make a tensor
716-
out = keras.ops.concatenate(
717-
list(prior_score.values()), axis=-1
718-
) # todo: assumes same order, might be incorrect
719-
return out + expand_right_as(log_det_jac_standardize, out)
715+
out = keras.ops.concatenate(list(prior_score.values()), axis=-1)
716+
return out - keras.ops.expand_dims(log_det_jac_standardize, axis=-1)
720717

721718
# Test prior score function, useful for debugging
722719
test = self.inference_network.base_distribution.sample((n_datasets, num_samples))

tests/test_approximators/conftest.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,3 +220,56 @@ def approximator_with_summaries(request):
220220
)
221221
case _:
222222
raise ValueError("Invalid param for approximator class.")
223+
224+
225+
@pytest.fixture
def simple_log_simulator():
    """Create a simple simulator for testing."""
    import numpy as np

    from bayesflow.simulators import Simulator
    from bayesflow.types import Shape, Tensor
    from bayesflow.utils.decorators import allow_batch_size

    class SimpleSimulator(Simulator):
        """Simple simulator that generates mean and scale parameters."""

        @allow_batch_size
        def sample(self, batch_shape: Shape) -> dict[str, Tensor]:
            # Draw everything in the original (untransformed) space:
            # unbounded locations, strictly positive scales, dummy conditions.
            draws = {
                "loc": np.random.normal(0.0, 1.0, size=batch_shape + (2,)),
                "scale": np.random.lognormal(0.0, 0.5, size=batch_shape + (2,)),
                "conditions": np.random.normal(0.0, 1.0, size=batch_shape + (3,)),
            }
            return {name: values.astype("float32") for name, values in draws.items()}

    return SimpleSimulator()
250+
251+
252+
@pytest.fixture
def transforming_adapter():
    """Create an adapter that applies log transformation to scale parameters."""
    from bayesflow.adapters import Adapter

    ad = Adapter()
    ad.to_array()
    ad.convert_dtype("float64", "float32")

    # The log transform maps the positive scale parameters to an
    # unbounded space before they enter the inference network.
    ad.log(["scale"])

    # Pack variables and conditions into the canonical keys and drop the rest.
    ad.concatenate(["loc", "scale"], into="inference_variables")
    ad.concatenate(["conditions"], into="inference_conditions")
    ad.keep(["inference_variables", "inference_conditions"])
    return ad
268+
269+
270+
@pytest.fixture
def diffusion_network():
    """Create a diffusion network for compositional sampling."""
    from bayesflow.networks import DiffusionModel, MLP

    # A small MLP keeps the test fast while exercising the full pipeline.
    subnet = MLP(widths=[32, 32])
    return DiffusionModel(subnet=subnet)
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
"""Tests for compositional sampling and prior score computation with adapters."""
2+
3+
import numpy as np
4+
import keras
5+
6+
from bayesflow import ContinuousApproximator
7+
from bayesflow.utils import expand_right_as
8+
9+
10+
def mock_prior_score_original_space(data_dict):
    """Mock prior score function that expects data in original (loc, scale) space."""
    # Receives the dict produced by the inverse adapter pass inside
    # compute_prior_score_pre, keyed by the original variable names.
    loc, scale = data_dict["loc"], data_dict["scale"]

    # Prior model: loc ~ N(0, 1), scale ~ LogNormal(0, 0.5); the returned
    # values are the gradients of the respective log-densities.
    return {
        "loc": -loc,
        "scale": -1.0 / scale - np.log(scale) / (0.25 * scale),
    }
22+
23+
24+
def test_prior_score_transforming_adapter(simple_log_simulator, transforming_adapter, diffusion_network):
    """Test that prior scores work correctly with transforming adapter (log transformation)."""
    approximator = ContinuousApproximator(
        adapter=transforming_adapter,
        inference_network=diffusion_network,
    )

    # Simulate a small batch, adapt it, and build the approximator from it.
    raw_data = simple_log_simulator.sample((2,))
    approximator.build_from_data(transforming_adapter(raw_data))

    # Compositional setup: several datasets, each with multiple conditions.
    n_datasets, n_compositional = 3, 5
    conditions = {"conditions": np.random.normal(0.0, 1.0, (n_datasets, n_compositional, 3)).astype("float32")}

    # compute_prior_score_pre must run the inverse adapter internally, so a
    # score defined in the original (loc, scale) space is valid here.
    samples = approximator.compositional_sample(
        num_samples=10,
        conditions=conditions,
        compute_prior_score=mock_prior_score_original_space,
    )

    assert "loc" in samples
    assert "scale" in samples
    assert samples["loc"].shape == (n_datasets, 10, 2)
    assert samples["scale"].shape == (n_datasets, 10, 2)
55+
56+
57+
def test_prior_score_jacobian_correction(simple_log_simulator, transforming_adapter, diffusion_network):
    """Test that Jacobian correction is applied correctly in compute_prior_score_pre."""
    # standardize=[] disables the standardization layer so only the
    # adapter's Jacobian correction is exercised.
    approximator = ContinuousApproximator(
        adapter=transforming_adapter, inference_network=diffusion_network, standardize=[]
    )

    # Build the approximator from one adapted dummy draw.
    dummy_data_dict = simple_log_simulator.sample((1,))
    adapted_dummy_data = transforming_adapter(dummy_data_dict)
    approximator.build_from_data(adapted_dummy_data)

    # Re-implementation of the internal compute_prior_score_pre logic,
    # mirroring ContinuousApproximator so the correction can be checked.
    def compute_prior_score_pre(latent):
        if "inference_variables" in approximator.standardize:
            latent, ldj_standardize = approximator.standardize_layers["inference_variables"](
                latent, forward=False, log_det_jac=True
            )
        else:
            ldj_standardize = keras.ops.cast(0.0, dtype="float32")

        latent = keras.tree.map_structure(keras.ops.convert_to_numpy, {"inference_variables": latent})
        adapted, ldj_adapter = approximator.adapter(latent, inverse=True, strict=False, log_det_jac=True)

        score = mock_prior_score_original_space(adapted)
        for key in adapted:
            if isinstance(score[key], np.ndarray):
                score[key] = score[key].astype("float32")
            if len(ldj_adapter) > 0 and key in ldj_adapter:
                score[key] -= expand_right_as(ldj_adapter[key], score[key])

        score = keras.tree.map_structure(keras.ops.convert_to_tensor, score)
        stacked = keras.ops.concatenate(list(score.values()), axis=-1)
        return stacked - keras.ops.expand_dims(ldj_standardize, axis=-1)

    # Evaluate the scores at the adapted (log-space) dummy sample.
    y_samples = adapted_dummy_data["inference_variables"]
    scores_np = keras.ops.convert_to_numpy(compute_prior_score_pre(y_samples))[0]  # drop batch dim

    # With Jacobian correction: score_transformed = score_original - log|J|,
    # where log|J| for the log transform is the sum of the log-scale entries.
    old_scores = mock_prior_score_original_space(dummy_data_dict)
    det_jac_scale = y_samples[0, 2:].sum()
    expected_scores = np.array([old_scores["loc"][0], old_scores["scale"][0] - det_jac_scale]).flatten()

    np.testing.assert_allclose(scores_np, expected_scores, rtol=1e-5, atol=1e-6)

tests/test_networks/test_diffusion_model/conftest.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pytest
2+
import keras
23

34

45
@pytest.fixture()
@@ -21,3 +22,49 @@ def edm_noise_schedule():
2122
)
2223
def noise_schedule(request):
2324
return request.getfixturevalue(request.param)
25+
26+
27+
@pytest.fixture
def simple_diffusion_model():
    """Create a simple diffusion model for testing compositional sampling."""
    from bayesflow.networks import MLP
    from bayesflow.networks.diffusion_model import DiffusionModel

    # Noise-prediction parameterization with a cosine schedule; the tiny
    # subnet keeps the fixture cheap to build.
    model = DiffusionModel(
        subnet=MLP(widths=[32, 32]),
        noise_schedule="cosine",
        prediction_type="noise",
        loss_type="noise",
    )
    return model
39+
40+
41+
@pytest.fixture
def compositional_conditions():
    """Create test conditions for compositional sampling."""
    # (batch_size, n_compositional, n_samples, condition_dim)
    shape = (2, 3, 4, 5)
    return keras.random.normal(shape)
50+
51+
52+
@pytest.fixture
def compositional_state():
    """Create test state for compositional sampling."""
    # (batch_size, n_samples, param_dim)
    shape = (2, 4, 3)
    return keras.random.normal(shape)
60+
61+
62+
@pytest.fixture
def mock_prior_score():
    """Create a mock prior score function for testing."""

    def _score(theta):
        # Gradient of the quadratic log-prior -0.5 * ||theta||^2 is -theta.
        return -theta

    return _score

tests/test_networks/test_diffusion_model/test_compositional_sampling.py

Lines changed: 0 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -2,52 +2,6 @@
22
import pytest
33

44

5-
@pytest.fixture
6-
def simple_diffusion_model():
7-
"""Create a simple diffusion model for testing compositional sampling."""
8-
from bayesflow.networks.diffusion_model import DiffusionModel
9-
from bayesflow.networks import MLP
10-
11-
return DiffusionModel(
12-
subnet=MLP(widths=[32, 32]),
13-
noise_schedule="cosine",
14-
prediction_type="noise",
15-
loss_type="noise",
16-
)
17-
18-
19-
@pytest.fixture
20-
def compositional_conditions():
21-
"""Create test conditions for compositional sampling."""
22-
batch_size = 2
23-
n_compositional = 3
24-
n_samples = 4
25-
condition_dim = 5
26-
27-
return keras.random.normal((batch_size, n_compositional, n_samples, condition_dim))
28-
29-
30-
@pytest.fixture
31-
def compositional_state():
32-
"""Create test state for compositional sampling."""
33-
batch_size = 2
34-
n_samples = 4
35-
param_dim = 3
36-
37-
return keras.random.normal((batch_size, n_samples, param_dim))
38-
39-
40-
@pytest.fixture
41-
def mock_prior_score():
42-
"""Create a mock prior score function for testing."""
43-
44-
def prior_score_fn(theta):
45-
# Simple quadratic prior: -0.5 * ||theta||^2
46-
return -theta
47-
48-
return prior_score_fn
49-
50-
515
def test_compositional_score_shape(
526
simple_diffusion_model, compositional_state, compositional_conditions, mock_prior_score
537
):

0 commit comments

Comments
 (0)