@@ -30,9 +30,10 @@ class DDPMPipelineFastTests(unittest.TestCase):
30
30
def dummy_uncond_unet (self ):
31
31
torch .manual_seed (0 )
32
32
model = UNet2DModel (
33
- block_out_channels = (32 , 64 ),
34
- layers_per_block = 2 ,
35
- sample_size = 32 ,
33
+ block_out_channels = (4 , 8 ),
34
+ layers_per_block = 1 ,
35
+ norm_num_groups = 4 ,
36
+ sample_size = 8 ,
36
37
in_channels = 3 ,
37
38
out_channels = 3 ,
38
39
down_block_types = ("DownBlock2D" , "AttnDownBlock2D" ),
@@ -58,10 +59,8 @@ def test_fast_inference(self):
58
59
image_slice = image [0 , - 3 :, - 3 :, - 1 ]
59
60
image_from_tuple_slice = image_from_tuple [0 , - 3 :, - 3 :, - 1 ]
60
61
61
- assert image .shape == (1 , 32 , 32 , 3 )
62
- expected_slice = np .array (
63
- [9.956e-01 , 5.785e-01 , 4.675e-01 , 9.930e-01 , 0.0 , 1.000 , 1.199e-03 , 2.648e-04 , 5.101e-04 ]
64
- )
62
+ assert image .shape == (1 , 8 , 8 , 3 )
63
+ expected_slice = np .array ([0.0 , 0.9996672 , 0.00329116 , 1.0 , 0.9995991 , 1.0 , 0.0060907 , 0.00115037 , 0.0 ])
65
64
66
65
assert np .abs (image_slice .flatten () - expected_slice ).max () < 1e-2
67
66
assert np .abs (image_from_tuple_slice .flatten () - expected_slice ).max () < 1e-2
@@ -83,7 +82,7 @@ def test_inference_predict_sample(self):
83
82
image_slice = image [0 , - 3 :, - 3 :, - 1 ]
84
83
image_eps_slice = image_eps [0 , - 3 :, - 3 :, - 1 ]
85
84
86
- assert image .shape == (1 , 32 , 32 , 3 )
85
+ assert image .shape == (1 , 8 , 8 , 3 )
87
86
tolerance = 1e-2 if torch_device != "mps" else 3e-2
88
87
assert np .abs (image_slice .flatten () - image_eps_slice .flatten ()).max () < tolerance
89
88
0 commit comments