Skip to content

Commit 58b6081

Browse files
committed
test_add_noise_device
1 parent ba8f1df commit 58b6081

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

tests/schedulers/test_schedulers.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -711,16 +711,16 @@ def test_add_noise_device(self):
711711
scheduler = scheduler_class(**scheduler_config)
712712
scheduler.set_timesteps(self.default_num_inference_steps)
713713

714-
# sample = self.dummy_sample.to(torch_device)
715-
# if scheduler_class == CMStochasticIterativeScheduler:
716-
# # Get valid timestep based on sigma_max, which should always be in timestep schedule.
717-
# scaled_sigma_max = scheduler.sigma_to_t(scheduler.config.sigma_max)
718-
# scaled_sample = scheduler.scale_model_input(sample, scaled_sigma_max)
719-
# elif scheduler_class == EDMEulerScheduler:
720-
# scaled_sample = scheduler.scale_model_input(sample, scheduler.timesteps[-1])
721-
# else:
722-
# scaled_sample = scheduler.scale_model_input(sample, 0.0)
723-
# self.assertEqual(sample.shape, scaled_sample.shape)
714+
sample = self.dummy_sample.to(torch_device)
715+
if scheduler_class == CMStochasticIterativeScheduler:
716+
# Get valid timestep based on sigma_max, which should always be in timestep schedule.
717+
scaled_sigma_max = scheduler.sigma_to_t(scheduler.config.sigma_max)
718+
scaled_sample = scheduler.scale_model_input(sample, scaled_sigma_max)
719+
elif scheduler_class == EDMEulerScheduler:
720+
scaled_sample = scheduler.scale_model_input(sample, scheduler.timesteps[-1])
721+
else:
722+
scaled_sample = scheduler.scale_model_input(sample, 0.0)
723+
self.assertEqual(sample.shape, scaled_sample.shape)
724724

725725
# noise = torch.randn_like(scaled_sample).to(torch_device)
726726
# t = scheduler.timesteps[5][None]

0 commit comments

Comments
 (0)