@@ -247,7 +247,7 @@ def test_find_summary_network_invalid_type():
247247
248248
249249def test_find_noise_schedule_by_name ():
250- from bayesflow .experimental .diffusion_model import CosineNoiseSchedule , EDMNoiseSchedule
250+ from bayesflow .experimental .diffusion_model . schedules import CosineNoiseSchedule , EDMNoiseSchedule
251251
252252 schedule = find_noise_schedule ("cosine" )
253253 assert isinstance (schedule , CosineNoiseSchedule )
@@ -262,7 +262,7 @@ def test_find_noise_schedule_unknown_name():
262262
263263
264264def test_pass_noise_schedule ():
265- from bayesflow .experimental .diffusion_model import NoiseSchedule
265+ from bayesflow .experimental .diffusion_model . schedules . noise_schedule import NoiseSchedule
266266
267267 class CustomNoiseSchedule (NoiseSchedule ):
268268 def __init__ (self ):
@@ -282,29 +282,13 @@ def derivative_log_snr(self, log_snr_t, training):
282282
283283
284284def test_pass_noise_schedule_type ():
285- from bayesflow .experimental .diffusion_model import EDMNoiseSchedule
285+ from bayesflow .experimental .diffusion_model . schedules import EDMNoiseSchedule
286286
287287 schedule = find_noise_schedule (EDMNoiseSchedule , sigma_data = 10.0 )
288288 assert isinstance (schedule , EDMNoiseSchedule )
289289 assert schedule .sigma_data == 10.0
290290
291291
292- def test_find_noise_schedule_by_dict ():
293- from bayesflow .experimental .diffusion_model import CosineNoiseSchedule , EDMNoiseSchedule
294-
295- schedule = find_noise_schedule ({"name" : "cosine" })
296- assert isinstance (schedule , CosineNoiseSchedule )
297-
298- schedule = find_noise_schedule ({"name" : "edm" , "sigma_data" : 10 })
299- assert isinstance (schedule , EDMNoiseSchedule )
300- assert schedule .sigma_data == 10
301-
302-
303- def test_find_noise_schedule_unknown_name_in_dict ():
304- with pytest .raises (ValueError ):
305- find_noise_schedule ({"name" : "unknown_noise_schedule" })
306-
307-
308292def test_find_noise_schedule_invalid_class ():
309293 with pytest .raises (TypeError ):
310294 find_noise_schedule (int )
0 commit comments