Skip to content

Commit f740784

Browse files
committed
Fix dispatch tests for dms
1 parent cfacbd8 commit f740784

File tree

1 file changed

+3
-19
lines changed

1 file changed

+3
-19
lines changed

tests/test_utils/test_dispatch.py

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def test_find_summary_network_invalid_type():
247247

248248

249249
def test_find_noise_schedule_by_name():
250-
from bayesflow.experimental.diffusion_model import CosineNoiseSchedule, EDMNoiseSchedule
250+
from bayesflow.experimental.diffusion_model.schedules import CosineNoiseSchedule, EDMNoiseSchedule
251251

252252
schedule = find_noise_schedule("cosine")
253253
assert isinstance(schedule, CosineNoiseSchedule)
@@ -262,7 +262,7 @@ def test_find_noise_schedule_unknown_name():
262262

263263

264264
def test_pass_noise_schedule():
265-
from bayesflow.experimental.diffusion_model import NoiseSchedule
265+
from bayesflow.experimental.diffusion_model.schedules.noise_schedule import NoiseSchedule
266266

267267
class CustomNoiseSchedule(NoiseSchedule):
268268
def __init__(self):
@@ -282,29 +282,13 @@ def derivative_log_snr(self, log_snr_t, training):
282282

283283

284284
def test_pass_noise_schedule_type():
285-
from bayesflow.experimental.diffusion_model import EDMNoiseSchedule
285+
from bayesflow.experimental.diffusion_model.schedules import EDMNoiseSchedule
286286

287287
schedule = find_noise_schedule(EDMNoiseSchedule, sigma_data=10.0)
288288
assert isinstance(schedule, EDMNoiseSchedule)
289289
assert schedule.sigma_data == 10.0
290290

291291

292-
def test_find_noise_schedule_by_dict():
293-
from bayesflow.experimental.diffusion_model import CosineNoiseSchedule, EDMNoiseSchedule
294-
295-
schedule = find_noise_schedule({"name": "cosine"})
296-
assert isinstance(schedule, CosineNoiseSchedule)
297-
298-
schedule = find_noise_schedule({"name": "edm", "sigma_data": 10})
299-
assert isinstance(schedule, EDMNoiseSchedule)
300-
assert schedule.sigma_data == 10
301-
302-
303-
def test_find_noise_schedule_unknown_name_in_dict():
304-
with pytest.raises(ValueError):
305-
find_noise_schedule({"name": "unknown_noise_schedule"})
306-
307-
308292
def test_find_noise_schedule_invalid_class():
309293
with pytest.raises(TypeError):
310294
find_noise_schedule(int)

0 commit comments

Comments
 (0)