
import unittest

-import numpy as np
import torch
from PIL import Image
from transformers import (
@@ -147,11 +146,15 @@ def test_inference(self):
        inputs = self.get_dummy_inputs(device)
        video = pipe(**inputs).frames
        generated_video = video[0]
-
        self.assertEqual(generated_video.shape, (9, 3, 16, 16))
-        expected_video = torch.randn(9, 3, 16, 16)
-        max_diff = np.abs(generated_video - expected_video).max()
-        self.assertLessEqual(max_diff, 1e10)
+
+        # fmt: off
+        expected_slice = torch.tensor([0.4525, 0.4525, 0.4497, 0.4536, 0.452, 0.4529, 0.454, 0.4535, 0.5072, 0.5527, 0.5165, 0.5244, 0.5481, 0.5282, 0.5208, 0.5214])
+        # fmt: on
+
+        generated_slice = generated_video.flatten()
+        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))

    @unittest.skip("Test not supported")
    def test_attention_slicing_forward_pass(self):
@@ -162,7 +165,25 @@ def test_inference_batch_single_identical(self):
        pass


-class WanFLFToVideoPipelineFastTests(WanImageToVideoPipelineFastTests):
+class WanFLFToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+    pipeline_class = WanImageToVideoPipeline
+    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "height", "width"}
+    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+    required_optional_params = frozenset(
+        [
+            "num_inference_steps",
+            "generator",
+            "latents",
+            "return_dict",
+            "callback_on_step_end",
+            "callback_on_step_end_tensor_inputs",
+        ]
+    )
+    test_xformers_attention = False
+    supports_dduf = False
+
    def get_dummy_components(self):
        torch.manual_seed(0)
        vae = AutoencoderKLWan(
@@ -247,3 +268,32 @@ def get_dummy_inputs(self, device, seed=0):
247268 "output_type" : "pt" ,
248269 }
249270 return inputs
+
+    def test_inference(self):
+        device = "cpu"
+
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        video = pipe(**inputs).frames
+        generated_video = video[0]
+        self.assertEqual(generated_video.shape, (9, 3, 16, 16))
+
+        # fmt: off
+        expected_slice = torch.tensor([0.4531, 0.4527, 0.4498, 0.4542, 0.4526, 0.4527, 0.4534, 0.4534, 0.5061, 0.5185, 0.5283, 0.5181, 0.5309, 0.5365, 0.5113, 0.5244])
+        # fmt: on
+
+        generated_slice = generated_video.flatten()
+        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
+
+    @unittest.skip("Test not supported")
+    def test_attention_slicing_forward_pass(self):
+        pass
+
+    @unittest.skip("TODO: revisit failing as it requires a very high threshold to pass")
+    def test_inference_batch_single_identical(self):
+        pass
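
The expected-slice assertion added in both tests follows a pattern common in fast test suites: instead of storing a full reference tensor (or, as the removed code did, comparing against torch.randn with a meaningless 1e10 tolerance), only the first and last eight values of the flattened output are pinned against hard-coded reference numbers. A minimal standalone sketch of that pattern; the helper name assert_matches_expected_slice is ours for illustration, not part of the repository:

import torch

def assert_matches_expected_slice(output, expected_slice, atol=1e-3):
    # Hypothetical helper illustrating the slice-comparison pattern above.
    # Only the first and last 8 elements of the flattened output are checked,
    # which keeps the hard-coded reference small while still catching
    # numerical drift at both ends of the tensor.
    flat = output.flatten()
    generated_slice = torch.cat([flat[:8], flat[-8:]])
    max_diff = (generated_slice - expected_slice).abs().max().item()
    assert torch.allclose(generated_slice, expected_slice, atol=atol), f"max diff: {max_diff:.5f}"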