@@ -711,16 +711,16 @@ def test_add_noise_device(self):
         scheduler = scheduler_class(**scheduler_config)
         scheduler.set_timesteps(self.default_num_inference_steps)

-        # sample = self.dummy_sample.to(torch_device)
-        # if scheduler_class == CMStochasticIterativeScheduler:
-        #     # Get valid timestep based on sigma_max, which should always be in timestep schedule.
-        #     scaled_sigma_max = scheduler.sigma_to_t(scheduler.config.sigma_max)
-        #     scaled_sample = scheduler.scale_model_input(sample, scaled_sigma_max)
-        # elif scheduler_class == EDMEulerScheduler:
-        #     scaled_sample = scheduler.scale_model_input(sample, scheduler.timesteps[-1])
-        # else:
-        #     scaled_sample = scheduler.scale_model_input(sample, 0.0)
-        # self.assertEqual(sample.shape, scaled_sample.shape)
+        sample = self.dummy_sample.to(torch_device)
+        if scheduler_class == CMStochasticIterativeScheduler:
+            # Get valid timestep based on sigma_max, which should always be in timestep schedule.
+            scaled_sigma_max = scheduler.sigma_to_t(scheduler.config.sigma_max)
+            scaled_sample = scheduler.scale_model_input(sample, scaled_sigma_max)
+        elif scheduler_class == EDMEulerScheduler:
+            scaled_sample = scheduler.scale_model_input(sample, scheduler.timesteps[-1])
+        else:
+            scaled_sample = scheduler.scale_model_input(sample, 0.0)
+        self.assertEqual(sample.shape, scaled_sample.shape)

         # noise = torch.randn_like(scaled_sample).to(torch_device)
         # t = scheduler.timesteps[5][None]
0 commit comments