Skip to content

Commit 987ce60

Browse files
committed
change to simpler separate test
The cost of the continuous models on the CI is too high
1 parent f819250 commit 987ce60

File tree

2 files changed

+31
-13
lines changed

2 files changed

+31
-13
lines changed

tests/test_networks/conftest.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,12 @@ def get_config(self):
6161

6262

6363
@pytest.fixture()
64-
def diffusion_model_edm_F_subnet_concatenate():
64+
def diffusion_model_edm_F_subnet_separate_inputs():
6565
from bayesflow.networks import DiffusionModel
6666

6767
return DiffusionModel(
6868
subnet=ConcatenateMLP([8, 8]),
69-
integrate_kwargs={"method": "rk45", "steps": 250},
69+
integrate_kwargs={"method": "rk45", "steps": 4},
7070
noise_schedule="edm",
7171
prediction_type="F",
7272
concatenate_subnet_input=False,
@@ -144,11 +144,11 @@ def flow_matching():
144144

145145

146146
@pytest.fixture()
147-
def flow_matching_subnet_concatenate():
147+
def flow_matching_subnet_separate_inputs():
148148
from bayesflow.networks import FlowMatching
149149

150150
return FlowMatching(
151-
subnet=ConcatenateMLP([8, 8]), integrate_kwargs={"method": "rk45", "steps": 100}, concatenate_subnet_input=False
151+
subnet=ConcatenateMLP([8, 8]), integrate_kwargs={"method": "rk45", "steps": 4}, concatenate_subnet_input=False
152152
)
153153

154154

@@ -160,10 +160,10 @@ def consistency_model():
160160

161161

162162
@pytest.fixture()
163-
def consistency_model_subnet_concatenate():
163+
def consistency_model_subnet_separate_inputs():
164164
from bayesflow.networks import ConsistencyModel
165165

166-
return ConsistencyModel(total_steps=100, subnet=ConcatenateMLP([8, 8]), concatenate_subnet_input=False)
166+
return ConsistencyModel(total_steps=4, subnet=ConcatenateMLP([8, 8]), concatenate_subnet_input=False)
167167

168168

169169
@pytest.fixture()
@@ -230,12 +230,9 @@ def typical_point_inference_network_subnet():
230230
"affine_coupling_flow",
231231
"spline_coupling_flow",
232232
"flow_matching",
233-
pytest.param("flow_matching_subnet_concatenate"),
234233
"free_form_flow",
235234
"consistency_model",
236-
pytest.param("consistency_model_subnet_concatenate"),
237235
pytest.param("diffusion_model_edm_F"),
238-
pytest.param("diffusion_model_edm_F_subnet_concatenate"),
239236
pytest.param("diffusion_model_edm_noise", marks=pytest.mark.slow),
240237
pytest.param("diffusion_model_cosine_velocity", marks=pytest.mark.slow),
241238
pytest.param("diffusion_model_cosine_F", marks=pytest.mark.slow),
@@ -263,16 +260,12 @@ def inference_network_subnet(request):
263260

264261
@pytest.fixture(
265262
params=[
266-
pytest.param("diffusion_model_edm_F_subnet_concatenate"),
267263
"affine_coupling_flow",
268264
"spline_coupling_flow",
269265
"flow_matching",
270-
pytest.param("flow_matching_subnet_concatenate"),
271266
"free_form_flow",
272267
"consistency_model",
273-
pytest.param("consistency_model_subnet_concatenate"),
274268
pytest.param("diffusion_model_edm_F"),
275-
pytest.param("diffusion_model_edm_F_subnet_concatenate"),
276269
pytest.param(
277270
"diffusion_model_edm_noise",
278271
marks=[
@@ -309,6 +302,18 @@ def generative_inference_network(request):
309302
return request.getfixturevalue(request.param)
310303

311304

305+
@pytest.fixture(
306+
params=[
307+
pytest.param("flow_matching_subnet_separate_inputs"),
308+
pytest.param("consistency_model_subnet_separate_inputs"),
309+
pytest.param("diffusion_model_edm_F_subnet_separate_inputs"),
310+
],
311+
scope="function",
312+
)
313+
def inference_network_subnet_separate_inputs(request):
314+
return request.getfixturevalue(request.param)
315+
316+
312317
@pytest.fixture(scope="function")
313318
def time_series_network(summary_dim):
314319
from bayesflow.networks import TimeSeriesNetwork

tests/test_networks/test_inference_networks.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,16 @@ def test_compute_metrics(inference_network, random_samples, random_conditions):
162162

163163
metrics = inference_network.compute_metrics(random_samples, conditions=random_conditions)
164164
assert "loss" in metrics
165+
166+
167+
def test_subnet_separate_inputs(inference_network_subnet_separate_inputs, random_samples, random_conditions):
168+
xz_shape = keras.ops.shape(random_samples)
169+
conditions_shape = keras.ops.shape(random_conditions) if random_conditions is not None else None
170+
inference_network_subnet_separate_inputs.build(xz_shape, conditions_shape)
171+
172+
assert inference_network_subnet_separate_inputs.built is True
173+
174+
# check the model has variables
175+
assert inference_network_subnet_separate_inputs.variables, "Model has no variables."
176+
177+
inference_network_subnet_separate_inputs(random_samples, random_conditions, inverse=True)

0 commit comments

Comments
 (0)