We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent f7fb73e commit ac2b820Copy full SHA for ac2b820
tests/schedulers/test_schedulers.py
@@ -723,9 +723,9 @@ def test_add_noise_device(self):
723
self.assertEqual(sample.shape, scaled_sample.shape)
724
725
noise = torch.randn_like(scaled_sample).to(torch_device)
726
- t = scheduler.timesteps[5][None]
727
- # noised = scheduler.add_noise(scaled_sample, noise, t)
728
- # self.assertEqual(noised.shape, scaled_sample.shape)
+ t = scheduler.timesteps[5].expand(noise.shape[0])
+ noised = scheduler.add_noise(scaled_sample, noise, t)
+ self.assertEqual(noised.shape, scaled_sample.shape)
729
730
def test_deprecated_kwargs(self):
731
for scheduler_class in self.scheduler_classes:
0 commit comments