Skip to content

Commit adb4238

Browse files
committed
test_add_noise_device
1 parent 93dcc72 commit adb4238

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

tests/schedulers/test_schedulers.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -722,10 +722,10 @@ def test_add_noise_device(self):
722722
scaled_sample = scheduler.scale_model_input(sample, 0.0)
723723
self.assertEqual(sample.shape, scaled_sample.shape)
724724

725-
noise = torch.randn_like(scaled_sample).to(torch_device)
726-
# t = scheduler.timesteps[5].expand(noise.shape[0])
727-
# noised = scheduler.add_noise(scaled_sample, noise, t)
728-
# self.assertEqual(noised.shape, scaled_sample.shape)
725+
noise = torch.randn(scaled_sample.shape).to(torch_device)
726+
t = scheduler.timesteps[5][None]
727+
noised = scheduler.add_noise(scaled_sample, noise, t)
728+
self.assertEqual(noised.shape, scaled_sample.shape)
729729

730730
def test_deprecated_kwargs(self):
731731
for scheduler_class in self.scheduler_classes:

0 commit comments

Comments
 (0)