1 | 1 | """Tests for compositional sampling and prior score computation with adapters.""" |
2 | 2 |
|
3 | 3 | import numpy as np |
4 | | -import keras |
5 | 4 |
|
6 | 5 | from bayesflow import ContinuousApproximator |
7 | | -from bayesflow.utils import expand_right_as |
8 | 6 |
|
9 | 7 |
|
10 | 8 | def mock_prior_score_original_space(data_dict): |
11 | | - """Mock prior score function that expects data in original (loc, scale) space.""" |
12 | | - # The function receives data in the same format the compute_prior_score_pre creates |
13 | | - # after running the inverse adapter |
| 9 | + """Mock prior score function that expects data in original space.""" |
14 | 10 | loc = data_dict["loc"] |
15 | | - scale = data_dict["scale"] |
16 | 11 |
|
17 | | - # Simple prior: N(0,1) for loc, LogNormal(0,0.5) for scale |
| 12 | + # Simple prior: N(0,1) for loc |
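+    # (the score of a standard normal is d/dx log N(x; 0, 1) = -x)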
     loc_score = -loc
-    scale_score = -1.0 / scale - np.log(scale) / (0.25 * scale)
+    return {"loc": loc_score}

-    return {"loc": loc_score, "scale": scale_score}

-
-def test_prior_score_transforming_adapter(simple_log_simulator, transforming_adapter, diffusion_network):
+def test_prior_score_identity_adapter(simple_log_simulator, identity_adapter, diffusion_network):
25 | 18 | """Test that prior scores work correctly with transforming adapter (log transformation).""" |
26 | 19 |
|
27 | 20 | # Create approximator with transforming adapter |
     approximator = ContinuousApproximator(
-        adapter=transforming_adapter,
+        adapter=identity_adapter,
         inference_network=diffusion_network,
     )

     # Generate test data and adapt it
     data = simple_log_simulator.sample((2,))
-    adapted_data = transforming_adapter(data)
+    adapted_data = identity_adapter(data)

     # Build approximator
     approximator.build_from_data(adapted_data)

     # Test compositional sampling
     n_datasets, n_compositional = 3, 5
     conditions = {"conditions": np.random.normal(0.0, 1.0, (n_datasets, n_compositional, 3)).astype("float32")}
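+    # conditions: shape (n_datasets, n_compositional, 3), i.e. 5 conditioning sets per dataset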
-
-    # This should work - the compute_prior_score_pre function should handle the inverse transformation
     samples = approximator.compositional_sample(
         num_samples=10,
         conditions=conditions,
         compute_prior_score=mock_prior_score_original_space,
     )

     assert "loc" in samples
-    assert "scale" in samples
     assert samples["loc"].shape == (n_datasets, 10, 2)
-    assert samples["scale"].shape == (n_datasets, 10, 2)
-
-
-def test_prior_score_jacobian_correction(simple_log_simulator, transforming_adapter, diffusion_network):
-    """Test that Jacobian correction is applied correctly in compute_prior_score_pre."""
-
-    # Create approximator with transforming adapter
-    approximator = ContinuousApproximator(
-        adapter=transforming_adapter, inference_network=diffusion_network, standardize=[]
-    )
-
-    # Build with dummy data
-    dummy_data_dict = simple_log_simulator.sample((1,))
-    adapted_dummy_data = transforming_adapter(dummy_data_dict)
-    approximator.build_from_data(adapted_dummy_data)
-
-    # Get the internal compute_prior_score_pre function
-    def get_compute_prior_score_pre():
-        def compute_prior_score_pre(_samples):
-            if "inference_variables" in approximator.standardize:
-                _samples, log_det_jac_standardize = approximator.standardize_layers["inference_variables"](
-                    _samples, forward=False, log_det_jac=True
-                )
-            else:
-                log_det_jac_standardize = keras.ops.cast(0.0, dtype="float32")
-
-            _samples = keras.tree.map_structure(keras.ops.convert_to_numpy, {"inference_variables": _samples})
-            adapted_samples, log_det_jac = approximator.adapter(_samples, inverse=True, strict=False, log_det_jac=True)
-
-            prior_score = mock_prior_score_original_space(adapted_samples)
-            for key in adapted_samples:
-                if isinstance(prior_score[key], np.ndarray):
-                    prior_score[key] = prior_score[key].astype("float32")
-                if len(log_det_jac) > 0 and key in log_det_jac:
-                    prior_score[key] -= expand_right_as(log_det_jac[key], prior_score[key])
-
-            prior_score = keras.tree.map_structure(keras.ops.convert_to_tensor, prior_score)
-            out = keras.ops.concatenate([prior_score[key] for key in adapted_samples], axis=-1)
-            return out - keras.ops.expand_dims(log_det_jac_standardize, axis=-1)
-
-        return compute_prior_score_pre
-
-    compute_prior_score_pre = get_compute_prior_score_pre()
-
-    # Test with a known transformation
-    y_samples = adapted_dummy_data["inference_variables"]
-    scores = compute_prior_score_pre(y_samples)
-    scores_np = keras.ops.convert_to_numpy(scores)[0]  # Remove batch dimension
-
-    # With Jacobian correction: score_transformed = score_original - log|J|
-    old_scores = mock_prior_score_original_space(dummy_data_dict)
-    # order of parameters is flipped due to concatenation in adapter
-    det_jac_scale = y_samples[0, :2].sum()
-    expected_scores = np.array([old_scores["scale"][0] - det_jac_scale, old_scores["loc"][0]]).flatten()

-    # Check that scores are reasonably close
-    np.testing.assert_allclose(scores_np, expected_scores, rtol=1e-5, atol=1e-6)
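
For reference, the score change-of-variables identity that the deleted test exercised can be checked in isolation. Below is a minimal standalone sketch, assuming the LogNormal(0, 0.5) prior from the removed mock and the log transform y = log(x); it verifies the mathematical identity itself, not BayesFlow's internal bookkeeping:

import numpy as np

sigma2 = 0.25  # variance of log(x) when x ~ LogNormal(0, 0.5)


def score_x(x):
    # Original-space score of LogNormal(0, 0.5): d/dx log p(x)
    return -1.0 / x - np.log(x) / (sigma2 * x)


def score_y(y):
    # Log-space score: y = log(x) is N(0, sigma2), so d/dy log p(y) = -y / sigma2
    return -y / sigma2


y = np.linspace(-1.0, 1.0, 7)
x = np.exp(y)
# Change of variables with x = exp(y): dx/dy = exp(y) and d/dy log|dx/dy| = 1, so
# score_y(y) = score_x(exp(y)) * exp(y) + 1
np.testing.assert_allclose(score_x(x) * x + 1.0, score_y(y), rtol=1e-6)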