@@ -894,10 +894,10 @@ def test_ddpm_ddim_equality(self):
894
894
generator = torch .manual_seed (0 )
895
895
ddim_image = ddim (generator = generator , num_inference_steps = 1000 , eta = 1.0 , output_type = "numpy" )["sample" ]
896
896
897
- # the values aren't exactly equal, but the images look the same upon visual inspection
897
+ # the values aren't exactly equal, but the images look the same visually
898
898
assert np .abs (ddpm_image - ddim_image ).max () < 1e-1
899
899
900
- @slow
900
+ @unittest . skip ( "(Anton) The test is failing for large batch sizes, needs investigation" )
901
901
def test_ddpm_ddim_equality_batched (self ):
902
902
model_id = "google/ddpm-cifar10-32"
903
903
@@ -909,12 +909,12 @@ def test_ddpm_ddim_equality_batched(self):
909
909
ddim = DDIMPipeline (unet = unet , scheduler = ddim_scheduler )
910
910
911
911
generator = torch .manual_seed (0 )
912
- ddpm_images = ddpm (batch_size = 2 , generator = generator , output_type = "numpy" )["sample" ]
912
+ ddpm_images = ddpm (batch_size = 4 , generator = generator , output_type = "numpy" )["sample" ]
913
913
914
914
generator = torch .manual_seed (0 )
915
- ddim_images = ddim (batch_size = 2 , generator = generator , num_inference_steps = 1000 , eta = 1.0 , output_type = "numpy" )[
915
+ ddim_images = ddim (batch_size = 4 , generator = generator , num_inference_steps = 1000 , eta = 1.0 , output_type = "numpy" )[
916
916
"sample"
917
917
]
918
918
919
- # the values aren't exactly equal, but the images look the same upon visual inspection
919
+ # the values aren't exactly equal, but the images look the same visually
920
920
assert np .abs (ddpm_images - ddim_images ).max () < 1e-1
0 commit comments