Skip to content

Commit e983cf7

Browse files
committed
add some tests
1 parent 5969bd3 commit e983cf7

File tree

2 files changed

+179
-1
lines changed

2 files changed

+179
-1
lines changed

bayesflow/networks/diffusion_model/diffusion_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -737,7 +737,7 @@ def compositional_score(
737737
weighted_prior_score = (1.0 - time) * compute_prior_score(xz)
738738

739739
# Combine scores using compositional formula, mean over individual scores and scale with n to get sum
740-
weighted_individual_scores = individual_scores - weighted_prior_score
740+
weighted_individual_scores = individual_scores - keras.ops.expand_dims(weighted_prior_score, axis=1)
741741
summed_individual_scores = n_compositional * ops.mean(weighted_individual_scores, axis=1)
742742

743743
# Combined score
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
import keras
2+
import pytest
3+
4+
5+
@pytest.fixture
6+
def simple_diffusion_model():
7+
"""Create a simple diffusion model for testing compositional sampling."""
8+
from bayesflow.networks.diffusion_model import DiffusionModel
9+
from bayesflow.networks import MLP
10+
11+
return DiffusionModel(
12+
subnet=MLP(widths=[32, 32]),
13+
noise_schedule="cosine",
14+
prediction_type="noise",
15+
loss_type="noise",
16+
)
17+
18+
19+
@pytest.fixture
20+
def compositional_conditions():
21+
"""Create test conditions for compositional sampling."""
22+
batch_size = 2
23+
n_compositional = 3
24+
n_samples = 4
25+
condition_dim = 5
26+
27+
return keras.random.normal((batch_size, n_compositional, n_samples, condition_dim))
28+
29+
30+
@pytest.fixture
31+
def compositional_state():
32+
"""Create test state for compositional sampling."""
33+
batch_size = 2
34+
n_samples = 4
35+
param_dim = 3
36+
37+
return keras.random.normal((batch_size, n_samples, param_dim))
38+
39+
40+
@pytest.fixture
41+
def mock_prior_score():
42+
"""Create a mock prior score function for testing."""
43+
44+
def prior_score_fn(theta):
45+
# Simple quadratic prior: -0.5 * ||theta||^2
46+
return -theta
47+
48+
return prior_score_fn
49+
50+
51+
def test_compositional_score_shape(
52+
simple_diffusion_model, compositional_state, compositional_conditions, mock_prior_score
53+
):
54+
"""Test that compositional score returns correct shapes."""
55+
# Build the model
56+
state_shape = keras.ops.shape(compositional_state)
57+
conditions_shape = keras.ops.shape(compositional_conditions)
58+
simple_diffusion_model.build(state_shape, conditions_shape)
59+
60+
time = 0.5
61+
62+
score = simple_diffusion_model.compositional_score(
63+
xz=compositional_state,
64+
time=time,
65+
conditions=compositional_conditions,
66+
compute_prior_score=mock_prior_score,
67+
training=False,
68+
)
69+
70+
expected_shape = keras.ops.shape(compositional_state)
71+
actual_shape = keras.ops.shape(score)
72+
73+
assert keras.ops.all(keras.ops.equal(expected_shape, actual_shape)), (
74+
f"Expected shape {expected_shape}, got {actual_shape}"
75+
)
76+
77+
78+
def test_compositional_score_no_conditions_raises_error(simple_diffusion_model, compositional_state, mock_prior_score):
79+
"""Test that compositional score raises error when conditions is None."""
80+
simple_diffusion_model.build(keras.ops.shape(compositional_state), None)
81+
82+
with pytest.raises(ValueError, match="Conditions are required for compositional sampling"):
83+
simple_diffusion_model.compositional_score(
84+
xz=compositional_state, time=0.5, conditions=None, compute_prior_score=mock_prior_score, training=False
85+
)
86+
87+
88+
def test_inverse_compositional_basic(
89+
simple_diffusion_model, compositional_state, compositional_conditions, mock_prior_score
90+
):
91+
"""Test basic compositional inverse sampling."""
92+
state_shape = keras.ops.shape(compositional_state)
93+
conditions_shape = keras.ops.shape(compositional_conditions)
94+
simple_diffusion_model.build(state_shape, conditions_shape)
95+
96+
# Test inverse sampling with ODE method
97+
result = simple_diffusion_model._inverse_compositional(
98+
z=compositional_state,
99+
conditions=compositional_conditions,
100+
compute_prior_score=mock_prior_score,
101+
density=False,
102+
training=False,
103+
method="euler",
104+
steps=5,
105+
start_time=1.0,
106+
stop_time=0.0,
107+
)
108+
109+
expected_shape = keras.ops.shape(compositional_state)
110+
actual_shape = keras.ops.shape(result)
111+
112+
assert keras.ops.all(keras.ops.equal(expected_shape, actual_shape)), (
113+
f"Expected shape {expected_shape}, got {actual_shape}"
114+
)
115+
116+
117+
def test_inverse_compositional_euler_maruyama_with_corrector(
118+
simple_diffusion_model, compositional_state, compositional_conditions, mock_prior_score
119+
):
120+
"""Test compositional inverse sampling with Euler-Maruyama and corrector steps."""
121+
state_shape = keras.ops.shape(compositional_state)
122+
conditions_shape = keras.ops.shape(compositional_conditions)
123+
simple_diffusion_model.build(state_shape, conditions_shape)
124+
125+
result = simple_diffusion_model._inverse_compositional(
126+
z=compositional_state,
127+
conditions=compositional_conditions,
128+
compute_prior_score=mock_prior_score,
129+
density=False,
130+
training=False,
131+
method="euler_maruyama",
132+
steps=5,
133+
corrector_steps=2,
134+
start_time=1.0,
135+
stop_time=0.0,
136+
)
137+
138+
expected_shape = keras.ops.shape(compositional_state)
139+
actual_shape = keras.ops.shape(result)
140+
141+
assert keras.ops.all(keras.ops.equal(expected_shape, actual_shape)), (
142+
f"Expected shape {expected_shape}, got {actual_shape}"
143+
)
144+
145+
146+
@pytest.mark.parametrize("noise_schedule_name", ["cosine", "edm"])
147+
def test_compositional_sampling_with_different_schedules(
148+
noise_schedule_name, compositional_state, compositional_conditions, mock_prior_score
149+
):
150+
"""Test compositional sampling with different noise schedules."""
151+
from bayesflow.networks.diffusion_model import DiffusionModel
152+
from bayesflow.networks import MLP
153+
154+
diffusion_model = DiffusionModel(
155+
subnet=MLP(widths=[32, 32]),
156+
noise_schedule=noise_schedule_name,
157+
prediction_type="noise",
158+
loss_type="noise",
159+
)
160+
161+
state_shape = keras.ops.shape(compositional_state)
162+
conditions_shape = keras.ops.shape(compositional_conditions)
163+
diffusion_model.build(state_shape, conditions_shape)
164+
165+
score = diffusion_model.compositional_score(
166+
xz=compositional_state,
167+
time=0.5,
168+
conditions=compositional_conditions,
169+
compute_prior_score=mock_prior_score,
170+
training=False,
171+
)
172+
173+
expected_shape = keras.ops.shape(compositional_state)
174+
actual_shape = keras.ops.shape(score)
175+
176+
assert keras.ops.all(keras.ops.equal(expected_shape, actual_shape)), (
177+
f"Expected shape {expected_shape}, got {actual_shape}"
178+
)

0 commit comments

Comments
 (0)