@@ -149,12 +149,11 @@ def test_inference(self):
         self.assertEqual(generated_video.shape, (9, 3, 16, 16))

         # fmt: off
-        expected_slice = torch.tensor([0.4531, 0.4527, 0.4498, 0.4542, 0.4526, 0.4527, 0.4534, 0.4534, 0.5061, 0.5185, 0.5283, 0.5181, 0.5309, 0.5365, 0.5113, 0.5244])
+        expected_slice = torch.tensor([0.4525, 0.4525, 0.4497, 0.4536, 0.452, 0.4529, 0.454, 0.4535, 0.5072, 0.5527, 0.5165, 0.5244, 0.5481, 0.5282, 0.5208, 0.5214])
         # fmt: on

         generated_slice = generated_video.flatten()
         generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
-        print("image2video:", [round(x, 4) for x in generated_slice.tolist()])
         self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))

     @unittest.skip("Test not supported")
@@ -166,7 +165,25 @@ def test_inference_batch_single_identical(self):
         pass


-class WanFLFToVideoPipelineFastTests(WanImageToVideoPipelineFastTests):
+class WanFLFToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+    pipeline_class = WanImageToVideoPipeline
+    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "height", "width"}
+    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+    required_optional_params = frozenset(
+        [
+            "num_inference_steps",
+            "generator",
+            "latents",
+            "return_dict",
+            "callback_on_step_end",
+            "callback_on_step_end_tensor_inputs",
+        ]
+    )
+    test_xformers_attention = False
+    supports_dduf = False
+
     def get_dummy_components(self):
         torch.manual_seed(0)
         vae = AutoencoderKLWan(
@@ -251,3 +268,32 @@ def get_dummy_inputs(self, device, seed=0):
251268 "output_type" : "pt" ,
252269 }
253270 return inputs
271+
272+ def test_inference (self ):
273+ device = "cpu"
274+
275+ components = self .get_dummy_components ()
276+ pipe = self .pipeline_class (** components )
277+ pipe .to (device )
278+ pipe .set_progress_bar_config (disable = None )
279+
280+ inputs = self .get_dummy_inputs (device )
281+ video = pipe (** inputs ).frames
282+ generated_video = video [0 ]
283+ self .assertEqual (generated_video .shape , (9 , 3 , 16 , 16 ))
284+
285+ # fmt: off
286+ expected_slice = torch .tensor ([0.4531 , 0.4527 , 0.4498 , 0.4542 , 0.4526 , 0.4527 , 0.4534 , 0.4534 , 0.5061 , 0.5185 , 0.5283 , 0.5181 , 0.5309 , 0.5365 , 0.5113 , 0.5244 ])
287+ # fmt: on
288+
289+ generated_slice = generated_video .flatten ()
290+ generated_slice = torch .cat ([generated_slice [:8 ], generated_slice [- 8 :]])
291+ self .assertTrue (torch .allclose (generated_slice , expected_slice , atol = 1e-3 ))
292+
293+ @unittest .skip ("Test not supported" )
294+ def test_attention_slicing_forward_pass (self ):
295+ pass
296+
297+ @unittest .skip ("TODO: revisit failing as it requires a very high threshold to pass" )
298+ def test_inference_batch_single_identical (self ):
299+ pass