@@ -175,6 +175,7 @@ def __init__(
         image_encoder: CLIPVisionModel = None,
         transformer_2: WanTransformer3DModel = None,
         boundary_ratio: Optional[float] = None,
+        expand_timesteps: bool = False,
     ):
         super().__init__()

@@ -188,10 +189,10 @@ def __init__(
             image_processor=image_processor,
             transformer_2=transformer_2,
         )
-        self.register_to_config(boundary_ratio=boundary_ratio)
+        self.register_to_config(boundary_ratio=boundary_ratio, expand_timesteps=expand_timesteps)

-        self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
-        self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+        self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
+        self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
         self.image_processor = image_processor

@@ -419,8 +420,12 @@ def prepare_latents(
         else:
             latents = latents.to(device=device, dtype=dtype)

-        image = image.unsqueeze(2)
-        if last_image is None:
+        image = image.unsqueeze(2)  # [batch_size, channels, 1, height, width]
+
+        if self.config.expand_timesteps:
+            video_condition = image
+
+        elif last_image is None:
             video_condition = torch.cat(
                 [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
             )
@@ -453,6 +458,13 @@ def prepare_latents(
         latent_condition = latent_condition.to(dtype)
         latent_condition = (latent_condition - latents_mean) * latents_std

+        if self.config.expand_timesteps:
+            first_frame_mask = torch.ones(
+                1, 1, num_latent_frames, latent_height, latent_width, dtype=dtype, device=device
+            )
+            first_frame_mask[:, :, 0] = 0
+            return latents, latent_condition, first_frame_mask
+
         mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)

         if last_image is None:
@@ -662,7 +674,7 @@ def __call__(
         if negative_prompt_embeds is not None:
            negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)

-        if self.config.boundary_ratio is None:
+        if self.config.boundary_ratio is None and not self.config.expand_timesteps:
             if image_embeds is None:
                 if last_image is None:
                     image_embeds = self.encode_image(image, device)
@@ -682,7 +694,8 @@ def __call__(
             last_image = self.video_processor.preprocess(last_image, height=height, width=width).to(
                 device, dtype=torch.float32
             )
-        latents, condition = self.prepare_latents(
+
+        latents_outputs = self.prepare_latents(
             image,
             batch_size * num_videos_per_prompt,
             num_channels_latents,
@@ -695,6 +708,10 @@ def __call__(
             latents,
             last_image,
         )
+        if self.config.expand_timesteps:
+            latents, condition, first_frame_mask = latents_outputs
+        else:
+            latents, condition = latents_outputs

         # 6. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
@@ -721,8 +738,17 @@ def __call__(
                     current_model = self.transformer_2
                     current_guidance_scale = guidance_scale_2

-                latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
-                timestep = t.expand(latents.shape[0])
+                if self.config.expand_timesteps:
+                    latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents
+                    latent_model_input = latent_model_input.to(transformer_dtype)
+
+                    # seq_len: num_latent_frames * (latent_height // patch_size) * (latent_width // patch_size)
+                    temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
+                    # batch_size, seq_len
+                    timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
+                else:
+                    latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
+                    timestep = t.expand(latents.shape[0])

                 noise_pred = current_model(
                     hidden_states=latent_model_input,
@@ -766,6 +792,9 @@ def __call__(

         self._current_timestep = None

+        if self.config.expand_timesteps:
+            latents = (1 - first_frame_mask) * condition + first_frame_mask * latents
+
         if not output_type == "latent":
             latents = latents.to(self.vae.dtype)
             latents_mean = (