@@ -94,9 +94,10 @@ class WanImageConditioningLatentEncodeProcessor(ProcessorMixin):
9494 - mask: The conditioning frame mask for the input image/video.
9595 """
9696
97- def __init__ (self , output_names : List [str ]):
97+ def __init__ (self , output_names : List [str ], * , use_last_frame : bool = False ):
9898 super ().__init__ ()
9999 self .output_names = output_names
100+ self .use_last_frame = use_last_frame
100101 assert len (self .output_names ) == 4
101102
102103 def forward (
@@ -117,8 +118,12 @@ def forward(
117118 video = video .permute (0 , 2 , 1 , 3 , 4 ).contiguous () # [B, F, C, H, W] -> [B, C, F, H, W]
118119
119120 num_frames = video .size (2 )
120- first_frame , remaining_frames = video [:, :, :1 ], video [:, :, 1 :]
121- video = torch .cat ([first_frame , torch .zeros_like (remaining_frames )], dim = 2 )
121+ if not self .use_last_frame :
122+ first_frame , remaining_frames = video [:, :, :1 ], video [:, :, 1 :]
123+ video = torch .cat ([first_frame , torch .zeros_like (remaining_frames )], dim = 2 )
124+ else :
125+ first_frame , remaining_frames , last_frame = video [:, :, :1 ], video [:, :, 1 :- 1 ], video [:, :, - 1 :]
126+ video = torch .cat ([first_frame , torch .zeros_like (remaining_frames ), last_frame ], dim = 2 )
122127
123128 # Image conditioning uses argmax sampling, so we use "mode" here
124129 if compute_posterior :
@@ -139,7 +144,10 @@ def forward(
139144
140145 temporal_downsample = 2 ** sum (vae .temperal_downsample ) if getattr (self , "vae" , None ) else 4
141146 mask = latents .new_ones (latents .shape [0 ], 1 , num_frames , latents .shape [3 ], latents .shape [4 ])
142- mask [:, :, 1 :] = 0
147+ if not self .use_last_frame :
148+ mask [:, :, 1 :] = 0
149+ else :
150+ mask [:, :, 1 :- 1 ] = 0
143151 first_frame_mask = mask [:, :, :1 ]
144152 first_frame_mask = torch .repeat_interleave (first_frame_mask , dim = 2 , repeats = temporal_downsample )
145153 mask = torch .cat ([first_frame_mask , mask [:, :, 1 :]], dim = 2 )
@@ -164,9 +172,10 @@ class WanImageEncodeProcessor(ProcessorMixin):
164172 - image_embeds: The CLIP vision model image embeddings of the input image.
165173 """
166174
167- def __init__ (self , output_names : List [str ]):
175+ def __init__ (self , output_names : List [str ], * , use_last_frame : bool = False ):
168176 super ().__init__ ()
169177 self .output_names = output_names
178+ self .use_last_frame = use_last_frame
170179 assert len (self .output_names ) == 1
171180
172181 def forward (
@@ -178,15 +187,19 @@ def forward(
178187 ) -> Dict [str , torch .Tensor ]:
179188 device = image_encoder .device
180189 dtype = image_encoder .dtype
181-
182- if video is not None :
183- image = video [:, 0 ] # [B, F, C, H, W] -> [B, C, H, W] (take first frame)
184-
185- assert image .ndim == 4 , f"Expected 4D tensor, got { image .ndim } D tensor"
190+ last_image = None
186191
187192 # We know the image here is in the range [-1, 1] (probably a little overshot if using bilinear interpolation), but
188193 # the processor expects it to be in the range [0, 1].
189- image = FF .normalize (image , min = 0.0 , max = 1.0 )
194+ image = image if video is None else video [:, 0 ] # [B, F, C, H, W] -> [B, C, H, W] (take first frame)
195+ image = FF .normalize (image , min = 0.0 , max = 1.0 , dim = 1 )
196+ assert image .ndim == 4 , f"Expected 4D tensor, got { image .ndim } D tensor"
197+
198+ if self .use_last_frame :
199+ last_image = image if video is None else video [:, - 1 ]
200+ last_image = FF .normalize (last_image , min = 0.0 , max = 1.0 , dim = 1 )
201+ image = torch .stack ([image , last_image ], dim = 0 )
202+
190203 image = image_processor (images = image .float (), do_rescale = False , do_convert_rgb = False , return_tensors = "pt" )
191204 image = image .to (device = device , dtype = dtype )
192205 image_embeds = image_encoder (** image , output_hidden_states = True )
@@ -224,18 +237,23 @@ def __init__(
224237 cache_dir = cache_dir ,
225238 )
226239
240+ use_last_frame = self .transformer_config .pos_embed_seq_len is not None
241+
227242 if condition_model_processors is None :
228- condition_model_processors = [T5Processor (["encoder_hidden_states" , "prompt_attention_mask " ])]
243+ condition_model_processors = [T5Processor (["encoder_hidden_states" , "__drop__ " ])]
229244 if latent_model_processors is None :
230245 latent_model_processors = [WanLatentEncodeProcessor (["latents" , "latents_mean" , "latents_std" ])]
231246
232247 if self .transformer_config .image_dim is not None :
233248 latent_model_processors .append (
234249 WanImageConditioningLatentEncodeProcessor (
235- ["latent_condition" , "__drop__" , "__drop__" , "latent_condition_mask" ]
250+ ["latent_condition" , "__drop__" , "__drop__" , "latent_condition_mask" ],
251+ use_last_frame = use_last_frame ,
236252 )
237253 )
238- latent_model_processors .append (WanImageEncodeProcessor (["encoder_hidden_states_image" ]))
254+ latent_model_processors .append (
255+ WanImageEncodeProcessor (["encoder_hidden_states_image" ], use_last_frame = use_last_frame )
256+ )
239257
240258 self .condition_model_processors = condition_model_processors
241259 self .latent_model_processors = latent_model_processors
@@ -380,7 +398,6 @@ def prepare_conditions(
380398 input_keys = set (conditions .keys ())
381399 conditions = super ().prepare_conditions (** conditions )
382400 conditions = {k : v for k , v in conditions .items () if k not in input_keys }
383- conditions .pop ("prompt_attention_mask" , None )
384401 return conditions
385402
386403 @torch .no_grad ()
@@ -480,6 +497,7 @@ def validation(
480497 pipeline : Union [WanPipeline , WanImageToVideoPipeline ],
481498 prompt : str ,
482499 image : Optional [PIL .Image .Image ] = None ,
500+ last_image : Optional [PIL .Image .Image ] = None ,
483501 video : Optional [List [PIL .Image .Image ]] = None ,
484502 height : Optional [int ] = None ,
485503 width : Optional [int ] = None ,
@@ -501,9 +519,11 @@ def validation(
501519 if self .transformer_config .image_dim is not None :
502520 if image is None and video is None :
503521 raise ValueError ("Either image or video must be provided for Wan I2V validation." )
504- if image is None :
505- image = video [0 ]
522+ image = image if image is not None else video [0 ]
506523 generation_kwargs ["image" ] = image
524+ if self .transformer_config .pos_embed_seq_len is not None :
525+ last_image = last_image if last_image is not None else image if video is None else video [- 1 ]
526+ generation_kwargs ["last_image" ] = last_image
507527 generation_kwargs = get_non_null_items (generation_kwargs )
508528 video = pipeline (** generation_kwargs ).frames [0 ]
509529 return [VideoArtifact (value = video )]