-# Copyright 2024 The HuggingFace Team.
+# Copyright 2025 The HuggingFace Team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -107,7 +107,7 @@ def get_dummy_inputs(self, device, seed=0):
             "video": video,
             "mask": mask,
             "prompt": "dance monkey",
-            "negative_prompt": "negative",  # TODO
+            "negative_prompt": "negative",
             "generator": generator,
             "num_inference_steps": 2,
             "guidance_scale": 6.0,
@@ -128,13 +128,17 @@ def test_inference(self):
         pipe.set_progress_bar_config(disable=None)

         inputs = self.get_dummy_inputs(device)
-        video = pipe(**inputs).frames
-        generated_video = video[0]
+        video = pipe(**inputs).frames[0]
+        self.assertEqual(video.shape, (17, 3, 16, 16))

-        self.assertEqual(generated_video.shape, (17, 3, 16, 16))
-        expected_video = torch.randn(17, 3, 16, 16)
-        max_diff = np.abs(generated_video - expected_video).max()
-        self.assertLessEqual(max_diff, 1e10)
+        # fmt: off
+        expected_slice = [0.4523, 0.45198, 0.44872, 0.45326, 0.45211, 0.45258, 0.45344, 0.453, 0.52431, 0.52572, 0.50701, 0.5118, 0.53717, 0.53093, 0.50557, 0.51402]
+        # fmt: on
+
+        video_slice = video.flatten()
+        video_slice = torch.cat([video_slice[:8], video_slice[-8:]])
+        video_slice = [round(x, 5) for x in video_slice.tolist()]
+        self.assertTrue(np.allclose(video_slice, expected_slice, atol=1e-3))

     def test_inference_with_single_reference_image(self):
         device = "cpu"
@@ -146,13 +150,17 @@ def test_inference_with_single_reference_image(self):

         inputs = self.get_dummy_inputs(device)
         inputs["reference_images"] = Image.new("RGB", (16, 16))
-        video = pipe(**inputs).frames
-        generated_video = video[0]
+        video = pipe(**inputs).frames[0]
+        self.assertEqual(video.shape, (17, 3, 16, 16))
+
+        # fmt: off
+        expected_slice = [0.45247, 0.45214, 0.44874, 0.45314, 0.45171, 0.45299, 0.45428, 0.45317, 0.51378, 0.52658, 0.53361, 0.52303, 0.46204, 0.50435, 0.52555, 0.51342]
+        # fmt: on

-        self.assertEqual(generated_video.shape, (17, 3, 16, 16))
-        expected_video = torch.randn(17, 3, 16, 16)
-        max_diff = np.abs(generated_video - expected_video).max()
-        self.assertLessEqual(max_diff, 1e10)
+        video_slice = video.flatten()
+        video_slice = torch.cat([video_slice[:8], video_slice[-8:]])
+        video_slice = [round(x, 5) for x in video_slice.tolist()]
+        self.assertTrue(np.allclose(video_slice, expected_slice, atol=1e-3))

     def test_inference_with_multiple_reference_image(self):
         device = "cpu"
@@ -164,13 +172,17 @@ def test_inference_with_multiple_reference_image(self):

         inputs = self.get_dummy_inputs(device)
         inputs["reference_images"] = [[Image.new("RGB", (16, 16))] * 2]
-        video = pipe(**inputs).frames
-        generated_video = video[0]
+        video = pipe(**inputs).frames[0]
+        self.assertEqual(video.shape, (17, 3, 16, 16))
+
+        # fmt: off
+        expected_slice = [0.45321, 0.45221, 0.44818, 0.45375, 0.45268, 0.4519, 0.45271, 0.45253, 0.51244, 0.52223, 0.51253, 0.51321, 0.50743, 0.51177, 0.51626, 0.50983]
+        # fmt: on

-        self.assertEqual(generated_video.shape, (17, 3, 16, 16))
-        expected_video = torch.randn(17, 3, 16, 16)
-        max_diff = np.abs(generated_video - expected_video).max()
-        self.assertLessEqual(max_diff, 1e10)
+        video_slice = video.flatten()
+        video_slice = torch.cat([video_slice[:8], video_slice[-8:]])
+        video_slice = [round(x, 5) for x in video_slice.tolist()]
+        self.assertTrue(np.allclose(video_slice, expected_slice, atol=1e-3))

     @unittest.skip("Test not supported")
     def test_attention_slicing_forward_pass(self):
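For reference, the updated assertions above all follow the same slice-comparison pattern. A self-contained sketch of that pattern is below; the random tensor and the reused `expected_slice` are stand-ins for illustration only, not outputs or reference values from this commit.

import numpy as np
import torch

# Stand-in for `pipe(**inputs).frames[0]`: any (num_frames, channels, height, width) tensor.
video = torch.rand(17, 3, 16, 16)

# Reduce the full output to a small, deterministic fingerprint:
# the first and last 8 flattened values, rounded to 5 decimals.
video_slice = video.flatten()
video_slice = torch.cat([video_slice[:8], video_slice[-8:]])
video_slice = [round(x, 5) for x in video_slice.tolist()]

# In the real tests `expected_slice` is a hard-coded list of 16 reference values;
# here it simply reuses the freshly computed slice so the check trivially passes.
expected_slice = video_slice
assert np.allclose(video_slice, expected_slice, atol=1e-3)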