@@ -162,17 +162,17 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
162162
163163    model_cpu_offload_seq  =  "text_encoder->image_encoder->transformer->transformer_2->vae" 
164164    _callback_tensor_inputs  =  ["latents" , "prompt_embeds" , "negative_prompt_embeds" ]
165-     _optional_components  =  ["transformer_2" , "image_encoder" , "image_processor" ]
165+     _optional_components  =  ["transformer" , "transformer_2" , "image_encoder" , "image_processor" ]
166166
167167    def  __init__ (
168168        self ,
169169        tokenizer : AutoTokenizer ,
170170        text_encoder : UMT5EncoderModel ,
171-         transformer : WanTransformer3DModel ,
172171        vae : AutoencoderKLWan ,
173172        scheduler : FlowMatchEulerDiscreteScheduler ,
174173        image_processor : CLIPImageProcessor  =  None ,
175174        image_encoder : CLIPVisionModel  =  None ,
175+         transformer : WanTransformer3DModel  =  None ,
176176        transformer_2 : WanTransformer3DModel  =  None ,
177177        boundary_ratio : Optional [float ] =  None ,
178178        expand_timesteps : bool  =  False ,
@@ -669,12 +669,13 @@ def __call__(
669669        )
670670
671671        # Encode image embedding 
672-         transformer_dtype  =  self .transformer .dtype 
672+         transformer_dtype  =  self .transformer .dtype   if   self . transformer   is   not   None   else   self . transformer_2 . dtype 
673673        prompt_embeds  =  prompt_embeds .to (transformer_dtype )
674674        if  negative_prompt_embeds  is  not None :
675675            negative_prompt_embeds  =  negative_prompt_embeds .to (transformer_dtype )
676676
677-         if  self .config .boundary_ratio  is  None  and  not  self .config .expand_timesteps :
677+         # only wan 2.1 i2v transformer accepts image_embeds 
678+         if  self .transformer  is  not None  and  self .transformer .config .added_kv_proj_dim  is  not None :
678679            if  image_embeds  is  None :
679680                if  last_image  is  None :
680681                    image_embeds  =  self .encode_image (image , device )
@@ -709,6 +710,7 @@ def __call__(
709710            last_image ,
710711        )
711712        if  self .config .expand_timesteps :
713+             # wan 2.2 5b i2v uses first_frame_mask to mask timesteps 
712714            latents , condition , first_frame_mask  =  latents_outputs 
713715        else :
714716            latents , condition  =  latents_outputs 
0 commit comments