@@ -175,6 +175,7 @@ def __init__(
         image_encoder: CLIPVisionModel = None,
         transformer_2: WanTransformer3DModel = None,
         boundary_ratio: Optional[float] = None,
+        expand_timesteps: bool = False,
     ):
         super().__init__()
 
@@ -188,10 +189,10 @@ def __init__(
             image_processor=image_processor,
             transformer_2=transformer_2,
         )
-        self.register_to_config(boundary_ratio=boundary_ratio)
+        self.register_to_config(boundary_ratio=boundary_ratio, expand_timesteps=expand_timesteps)
 
-        self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
-        self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+        self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
+        self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
         self.image_processor = image_processor
 
@@ -419,8 +420,12 @@ def prepare_latents(
         else:
             latents = latents.to(device=device, dtype=dtype)
 
-        image = image.unsqueeze(2)
-        if last_image is None:
+        image = image.unsqueeze(2)  # [batch_size, channels, 1, height, width]
+
+        if self.config.expand_timesteps:
+            video_condition = image
+
+        elif last_image is None:
             video_condition = torch.cat(
                 [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
             )
@@ -453,6 +458,13 @@ def prepare_latents(
         latent_condition = latent_condition.to(dtype)
         latent_condition = (latent_condition - latents_mean) * latents_std
 
+        if self.config.expand_timesteps:
+            first_frame_mask = torch.ones(
+                1, 1, num_latent_frames, latent_height, latent_width, dtype=dtype, device=device
+            )
+            first_frame_mask[:, :, 0] = 0
+            return latents, latent_condition, first_frame_mask
+
         mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
 
         if last_image is None:
@@ -662,7 +674,7 @@ def __call__(
         if negative_prompt_embeds is not None:
             negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
 
-        if self.config.boundary_ratio is None:
+        if self.config.boundary_ratio is None and not self.config.expand_timesteps:
             if image_embeds is None:
                 if last_image is None:
                     image_embeds = self.encode_image(image, device)
@@ -682,7 +694,8 @@ def __call__(
             last_image = self.video_processor.preprocess(last_image, height=height, width=width).to(
                 device, dtype=torch.float32
             )
-        latents, condition = self.prepare_latents(
+
+        latents_outputs = self.prepare_latents(
             image,
             batch_size * num_videos_per_prompt,
             num_channels_latents,
@@ -695,6 +708,10 @@ def __call__(
             latents,
             last_image,
         )
+        if self.config.expand_timesteps:
+            latents, condition, first_frame_mask = latents_outputs
+        else:
+            latents, condition = latents_outputs
 
         # 6. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
@@ -721,8 +738,17 @@ def __call__(
                     current_model = self.transformer_2
                     current_guidance_scale = guidance_scale_2
 
-                latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
-                timestep = t.expand(latents.shape[0])
+                if self.config.expand_timesteps:
+                    latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents
+                    latent_model_input = latent_model_input.to(transformer_dtype)
+
+                    # seq_len: num_latent_frames * (latent_height // patch_size) * (latent_width // patch_size)
+                    temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
+                    # batch_size, seq_len
+                    timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
+                else:
+                    latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
+                    timestep = t.expand(latents.shape[0])
 
                 noise_pred = current_model(
                     hidden_states=latent_model_input,
@@ -766,6 +792,9 @@ def __call__(
 
         self._current_timestep = None
 
+        if self.config.expand_timesteps:
+            latents = (1 - first_frame_mask) * condition + first_frame_mask * latents
+
         if not output_type == "latent":
             latents = latents.to(self.vae.dtype)
             latents_mean = (
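
Note: below is a minimal standalone sketch (plain torch, with assumed latent sizes and channel count, not the pipeline itself) of what the new expand_timesteps branches compute each step. The first latent frame carries the clean conditioning image, so its tokens are blended back in and receive timestep 0, while every other token keeps the current scheduler timestep t.

import torch

batch_size = 2
num_latent_frames, latent_height, latent_width = 21, 60, 104  # assumed sizes
num_channels_latents = 16                                      # assumed channel count
t = torch.tensor(999.0)                                        # current scheduler timestep
latents = torch.randn(batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
condition = torch.randn_like(latents)                          # stands in for the encoded conditioning frame

# mask is 0 on the first latent frame (kept clean), 1 everywhere else (denoised)
first_frame_mask = torch.ones(1, 1, num_latent_frames, latent_height, latent_width)
first_frame_mask[:, :, 0] = 0

# per-step model input: the conditioning frame stays fixed, the rest comes from the noisy latents
latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents

# one timestep entry per transformer token (2x2 spatial patches), 0 for conditioning tokens
temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
timestep = temp_ts.unsqueeze(0).expand(batch_size, -1)
print(timestep.shape)  # torch.Size([2, 32760]) == (batch_size, 21 * 30 * 52)

The blend added after the denoising loop applies the same expression once more, so the conditioning frame reaches the VAE decoder untouched.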