File tree Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Original file line number Diff line number Diff line change @@ -722,10 +722,10 @@ def test_add_noise_device(self):
722722 scaled_sample = scheduler .scale_model_input (sample , 0.0 )
723723 self .assertEqual (sample .shape , scaled_sample .shape )
724724
725- noise = torch .randn_like (scaled_sample ).to (torch_device )
726- # t = scheduler.timesteps[5].expand(noise.shape[0])
727- # noised = scheduler.add_noise(scaled_sample, noise, t)
728- # self.assertEqual(noised.shape, scaled_sample.shape)
725+ noise = torch .randn (scaled_sample . shape ).to (torch_device )
726+ t = scheduler .timesteps [5 ][ None ]
727+ noised = scheduler .add_noise (scaled_sample , noise , t )
728+ self .assertEqual (noised .shape , scaled_sample .shape )
729729
730730 def test_deprecated_kwargs (self ):
731731 for scheduler_class in self .scheduler_classes :
You can’t perform that action at this time.
0 commit comments