@@ -175,6 +175,7 @@ def __init__(
         image_encoder: CLIPVisionModel = None,
         transformer_2: WanTransformer3DModel = None,
         boundary_ratio: Optional[float] = None,
+        expand_timesteps: bool = False,
     ):
         super().__init__()

@@ -188,10 +189,10 @@ def __init__(
             image_processor=image_processor,
             transformer_2=transformer_2,
         )
-        self.register_to_config(boundary_ratio=boundary_ratio)
+        self.register_to_config(boundary_ratio=boundary_ratio, expand_timesteps=expand_timesteps)

-        self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
-        self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+        self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
+        self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
         self.image_processor = image_processor

@@ -419,8 +420,12 @@ def prepare_latents(
         else:
             latents = latents.to(device=device, dtype=dtype)

-        image = image.unsqueeze(2)
-        if last_image is None:
+        image = image.unsqueeze(2)  # [batch_size, channels, 1, height, width]
+
+        if self.config.expand_timesteps:
+            video_condition = image
+
+        elif last_image is None:
             video_condition = torch.cat(
                 [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
             )
@@ -453,6 +458,13 @@ def prepare_latents(
         latent_condition = latent_condition.to(dtype)
         latent_condition = (latent_condition - latents_mean) * latents_std

+        if self.config.expand_timesteps:
+            first_frame_mask = torch.ones(
+                1, 1, num_latent_frames, latent_height, latent_width, dtype=dtype, device=device
+            )
+            first_frame_mask[:, :, 0] = 0
+            return latents, latent_condition, first_frame_mask
+
         mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)

         if last_image is None:
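A brief aside (a sketch, not part of this diff): the `first_frame_mask` built above covers the latent-space video grid, whose shape follows from the `vae_scale_factor_temporal` / `vae_scale_factor_spatial` values now read from the VAE config earlier in this change. Assuming the usual Wan defaults of 4 and 8 and the standard `(num_frames - 1) // temporal + 1` frame count, the mask is zero only on the clean, conditioned first latent frame:

```python
import torch

# Rough sketch: derive the latent grid that first_frame_mask covers,
# assuming vae_scale_factor_temporal = 4 and vae_scale_factor_spatial = 8.
num_frames, height, width = 81, 480, 832
vae_scale_factor_temporal, vae_scale_factor_spatial = 4, 8

num_latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1  # 21
latent_height = height // vae_scale_factor_spatial                     # 60
latent_width = width // vae_scale_factor_spatial                       # 104

# 0 on the conditioned first latent frame, 1 on every frame that gets denoised.
first_frame_mask = torch.ones(1, 1, num_latent_frames, latent_height, latent_width)
first_frame_mask[:, :, 0] = 0
print(first_frame_mask.shape)  # torch.Size([1, 1, 21, 60, 104])
```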
@@ -662,7 +674,7 @@ def __call__(
         if negative_prompt_embeds is not None:
             negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)

-        if self.config.boundary_ratio is None:
+        if self.config.boundary_ratio is None and not self.config.expand_timesteps:
             if image_embeds is None:
                 if last_image is None:
                     image_embeds = self.encode_image(image, device)
@@ -682,19 +694,35 @@ def __call__(
             last_image = self.video_processor.preprocess(last_image, height=height, width=width).to(
                 device, dtype=torch.float32
             )
-        latents, condition = self.prepare_latents(
-            image,
-            batch_size * num_videos_per_prompt,
-            num_channels_latents,
-            height,
-            width,
-            num_frames,
-            torch.float32,
-            device,
-            generator,
-            latents,
-            last_image,
-        )
+
+        if self.config.expand_timesteps:
+            latents, condition, first_frame_mask = self.prepare_latents(
+                image,
+                batch_size * num_videos_per_prompt,
+                num_channels_latents,
+                height,
+                width,
+                num_frames,
+                torch.float32,
+                device,
+                generator,
+                latents,
+                last_image,
+            )
+        else:
+            latents, condition = self.prepare_latents(
+                image,
+                batch_size * num_videos_per_prompt,
+                num_channels_latents,
+                height,
+                width,
+                num_frames,
+                torch.float32,
+                device,
+                generator,
+                latents,
+                last_image,
+            )

         # 6. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
@@ -721,8 +749,17 @@ def __call__(
                     current_model = self.transformer_2
                     current_guidance_scale = guidance_scale_2

-                latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
-                timestep = t.expand(latents.shape[0])
+                if self.config.expand_timesteps:
+                    latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents
+                    latent_model_input = latent_model_input.to(transformer_dtype)
+
+                    # seq_len: num_latent_frames * (latent_height // patch_size) * (latent_width // patch_size)
+                    temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
+                    # batch_size, seq_len
+                    timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
+                else:
+                    latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
+                    timestep = t.expand(latents.shape[0])

                 noise_pred = current_model(
                     hidden_states=latent_model_input,
@@ -766,6 +803,9 @@ def __call__(

         self._current_timestep = None

+        if self.config.expand_timesteps:
+            latents = (1 - first_frame_mask) * condition + first_frame_mask * latents
+
         if not output_type == "latent":
             latents = latents.to(self.vae.dtype)
             latents_mean = (
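To make the `expand_timesteps` path concrete, here is a minimal standalone sketch (assumptions, not code from this diff: the latent channel count of 16 and a spatial patch size of 2 are illustrative). The first latent frame is always replaced by the clean condition, and every transformer token belonging to that frame receives timestep 0 while all other tokens receive the current timestep `t`:

```python
import torch

# Sketch of the mask-based conditioning used in the expand_timesteps branch above.
batch_size, channels = 1, 16                        # assumed latent channel count
num_latent_frames, latent_height, latent_width = 21, 60, 104
patch_size = 2                                      # assumed spatial patch size

latents = torch.randn(batch_size, channels, num_latent_frames, latent_height, latent_width)
condition = torch.randn_like(latents)               # VAE-encoded, normalized first-frame condition

first_frame_mask = torch.ones(1, 1, num_latent_frames, latent_height, latent_width)
first_frame_mask[:, :, 0] = 0

t = torch.tensor(999.0)
# Keep the conditioned first frame clean, denoise only the masked frames.
latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents

# Per-token timesteps: downsample the mask by the patch size, scale by t, flatten.
temp_ts = (first_frame_mask[0][0][:, ::patch_size, ::patch_size] * t).flatten()
timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
print(timestep.shape)  # (batch_size, 21 * 30 * 52) == (batch_size, seq_len)
```

The same blend is reapplied once after the loop (the final hunk above), so the decoded video's first frame comes from the condition latents rather than from a denoised approximation.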