@@ -20,14 +20,14 @@
 from PIL import Image
 from transformers import AutoTokenizer, T5EncoderModel

-from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanTransformer3DModel, WanVideoToVideoPipeline
+from diffusers import AutoencoderKLWan, UniPCMultistepScheduler, WanTransformer3DModel, WanVideoToVideoPipeline
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
     require_torch_accelerator,
     slow,
 )

-from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..pipeline_params import TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
 from ..test_pipelines_common import (
     PipelineTesterMixin,
 )
@@ -39,8 +39,7 @@
 class WanVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     pipeline_class = WanVideoToVideoPipeline
     params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
-    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
-    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+    batch_params = frozenset(["video", "prompt", "negative_prompt"])
     image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
     required_optional_params = frozenset(
         [
@@ -66,8 +65,7 @@ def get_dummy_components(self):
         )

         torch.manual_seed(0)
-        # TODO: impl FlowDPMSolverMultistepScheduler
-        scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
+        scheduler = UniPCMultistepScheduler(flow_shift=3.0)
         text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
         tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
@@ -102,7 +100,7 @@ def get_dummy_inputs(self, device, seed=0):
         else:
             generator = torch.Generator(device=device).manual_seed(seed)

-        video = [Image.new("RGB", (16, 16))] * 19
+        video = [Image.new("RGB", (16, 16))] * 17
         inputs = {
             "video": video,
             "prompt": "dance monkey",
@@ -112,7 +110,6 @@ def get_dummy_inputs(self, device, seed=0):
             "guidance_scale": 6.0,
             "height": 16,
             "width": 16,
-            "num_frames": 9,
             "max_sequence_length": 16,
             "output_type": "pt",
         }
@@ -130,15 +127,27 @@ def test_inference(self):
         video = pipe(**inputs).frames
         generated_video = video[0]

-        self.assertEqual(generated_video.shape, (9, 3, 16, 16))
-        expected_video = torch.randn(9, 3, 16, 16)
+        self.assertEqual(generated_video.shape, (17, 3, 16, 16))
+        expected_video = torch.randn(17, 3, 16, 16)
         max_diff = np.abs(generated_video - expected_video).max()
         self.assertLessEqual(max_diff, 1e10)

     @unittest.skip("Test not supported")
     def test_attention_slicing_forward_pass(self):
         pass

+    @unittest.skip(
+        "WanVideoToVideoPipeline has to run in mixed precision. Casting the entire pipeline will result in errors"
+    )
+    def test_float16_inference(self):
+        pass
+
+    @unittest.skip(
+        "WanVideoToVideoPipeline has to run in mixed precision. Saving/loading the entire pipeline in FP16 will result in errors"
+    )
+    def test_save_load_float16(self):
+        pass
+

 @slow
 @require_torch_accelerator