
Commit d8854b8

yiyixuxu and a-r-r-o-w authored

[wan2.2] add 5b i2v (#12006)

* add 5b ti2v
* remove a copy
* Update src/diffusers/pipelines/wan/pipeline_wan_i2v.py
  Co-authored-by: Aryan <[email protected]>
* Apply suggestions from code review

---------

Co-authored-by: Aryan <[email protected]>
1 parent 327e251 commit d8854b8

File tree

2 files changed: +38 −10 lines
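For orientation before the diff: a minimal usage sketch of the image-to-video path this commit extends. The model id, resolution, and generation settings below are illustrative assumptions, not part of the commit; `expand_timesteps` is a pipeline config field, so a compatible checkpoint would ship with it already set rather than it being passed by the caller.

# Hedged usage sketch; the checkpoint id below is an assumption for illustration.
import torch
from diffusers import WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

pipe = WanImageToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.2-TI2V-5B-Diffusers",  # assumed 5B TI2V checkpoint id
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

image = load_image("input.png")
frames = pipe(
    image=image,
    prompt="a cat walking through a garden",
    num_frames=81,            # assumed; must match the VAE's temporal stride
    num_inference_steps=50,
).frames[0]
export_to_video(frames, "output.mp4", fps=16)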

src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py

Lines changed: 0 additions & 1 deletion

@@ -370,7 +370,6 @@ def check_inputs(
         ):
             raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
 
-    # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.prepare_latents
     def prepare_latents(
         self,
         image: PipelineImageInput,

src/diffusers/pipelines/wan/pipeline_wan_i2v.py

Lines changed: 38 additions & 9 deletions
@@ -175,6 +175,7 @@ def __init__(
         image_encoder: CLIPVisionModel = None,
         transformer_2: WanTransformer3DModel = None,
         boundary_ratio: Optional[float] = None,
+        expand_timesteps: bool = False,
     ):
         super().__init__()

@@ -188,10 +189,10 @@ def __init__(
             image_processor=image_processor,
             transformer_2=transformer_2,
         )
-        self.register_to_config(boundary_ratio=boundary_ratio)
+        self.register_to_config(boundary_ratio=boundary_ratio, expand_timesteps=expand_timesteps)
 
-        self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
-        self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+        self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
+        self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
         self.image_processor = image_processor
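The old derivation computed both scale factors from the VAE's `temperal_downsample` flag list (the attribute really is spelled that way in the Wan VAE config); in particular, the spatial factor used `len(...)` of the temporal list, which held only by coincidence for the existing Wan VAE and cannot generalize to a VAE with different compression, such as the new 5B one. Reading the values the VAE config declares avoids both issues. A rough sketch of the arithmetic, with the flag values assumed for illustration:

# Illustrative only; the flag values are assumed, not read from a checkpoint.
temperal_downsample = [False, True, True]  # per-block temporal downsample flags

# Old derivation: both factors inferred from the *temporal* flags.
old_temporal = 2 ** sum(temperal_downsample)  # 2**2 = 4
old_spatial = 2 ** len(temperal_downsample)   # 2**3 = 8, correct only by accident

# New derivation: take vae.config.scale_factor_temporal / scale_factor_spatial
# directly, so a VAE with different compression still reports the right values.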

@@ -419,8 +420,12 @@ def prepare_latents(
         else:
             latents = latents.to(device=device, dtype=dtype)
 
-        image = image.unsqueeze(2)
-        if last_image is None:
+        image = image.unsqueeze(2)  # [batch_size, channels, 1, height, width]
+
+        if self.config.expand_timesteps:
+            video_condition = image
+
+        elif last_image is None:
             video_condition = torch.cat(
                 [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
             )
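In the default i2v path the conditioning video is the input frame padded with `num_frames - 1` zero frames, because its encoding is later concatenated with the noise latents along the channel dimension. In `expand_timesteps` mode only the single frame is kept: its latents are blended into the noise through a mask instead (see the next hunk). A quick shape sketch, with sizes assumed for illustration:

import torch

batch, channels, height, width = 1, 3, 480, 832  # assumed sizes
num_frames = 81
image = torch.randn(batch, channels, height, width).unsqueeze(2)  # [1, 3, 1, 480, 832]

# Default path: first frame followed by zero frames, one conditioning video.
video_condition = torch.cat(
    [image, image.new_zeros(batch, channels, num_frames - 1, height, width)], dim=2
)
print(video_condition.shape)  # torch.Size([1, 3, 81, 480, 832])

# expand_timesteps path: just the single image frame.
video_condition = image
print(video_condition.shape)  # torch.Size([1, 3, 1, 480, 832])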
@@ -453,6 +458,13 @@
         latent_condition = latent_condition.to(dtype)
         latent_condition = (latent_condition - latents_mean) * latents_std
 
+        if self.config.expand_timesteps:
+            first_frame_mask = torch.ones(
+                1, 1, num_latent_frames, latent_height, latent_width, dtype=dtype, device=device
+            )
+            first_frame_mask[:, :, 0] = 0
+            return latents, latent_condition, first_frame_mask
+
         mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
 
         if last_image is None:
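`first_frame_mask` is zero on the first latent frame and one everywhere else, so the blend `(1 - mask) * condition + mask * latents` used later pins the encoded input image at frame 0 while every other frame stays noise to be denoised. A small sketch with assumed latent sizes (the channel count is illustrative):

import torch

num_latent_frames, latent_height, latent_width = 21, 60, 104  # assumed sizes
first_frame_mask = torch.ones(1, 1, num_latent_frames, latent_height, latent_width)
first_frame_mask[:, :, 0] = 0  # keep the conditioning frame, denoise the rest

latents = torch.randn(1, 48, num_latent_frames, latent_height, latent_width)
latent_condition = torch.randn(1, 48, 1, latent_height, latent_width)

# Broadcasting pastes the clean first-frame latents over the noise:
blended = (1 - first_frame_mask) * latent_condition + first_frame_mask * latents
assert torch.equal(blended[:, :, 0], latent_condition[:, :, 0])
assert torch.equal(blended[:, :, 1:], latents[:, :, 1:])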
@@ -662,7 +674,7 @@ def __call__(
         if negative_prompt_embeds is not None:
             negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
 
-        if self.config.boundary_ratio is None:
+        if self.config.boundary_ratio is None and not self.config.expand_timesteps:
             if image_embeds is None:
                 if last_image is None:
                     image_embeds = self.encode_image(image, device)
@@ -682,7 +694,8 @@
             last_image = self.video_processor.preprocess(last_image, height=height, width=width).to(
                 device, dtype=torch.float32
             )
-        latents, condition = self.prepare_latents(
+
+        latents_outputs = self.prepare_latents(
            image,
            batch_size * num_videos_per_prompt,
            num_channels_latents,
@@ -695,6 +708,10 @@
             latents,
             last_image,
         )
+        if self.config.expand_timesteps:
+            latents, condition, first_frame_mask = latents_outputs
+        else:
+            latents, condition = latents_outputs
 
         # 6. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
@@ -721,8 +738,17 @@
                     current_model = self.transformer_2
                     current_guidance_scale = guidance_scale_2
 
-                latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
-                timestep = t.expand(latents.shape[0])
+                if self.config.expand_timesteps:
+                    latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents
+                    latent_model_input = latent_model_input.to(transformer_dtype)
+
+                    # seq_len: num_latent_frames * (latent_height // patch_size) * (latent_width // patch_size)
+                    temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
+                    # batch_size, seq_len
+                    timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
+                else:
+                    latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
+                    timestep = t.expand(latents.shape[0])
 
                 noise_pred = current_model(
                     hidden_states=latent_model_input,
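With the first frame already clean, the model must see timestep 0 for its tokens and the scheduler timestep t for everything else; striding the mask by the spatial patch size (the `::2` implies patch size 2) yields exactly one timestep entry per transformer token. A shape sketch under assumed latent dimensions:

import torch

num_latent_frames, latent_height, latent_width = 21, 60, 104  # assumed sizes
first_frame_mask = torch.ones(1, 1, num_latent_frames, latent_height, latent_width)
first_frame_mask[:, :, 0] = 0

t = torch.tensor(999.0)  # current scheduler timestep

# One entry per token after 2x2 spatial patching:
temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
seq_len = num_latent_frames * (latent_height // 2) * (latent_width // 2)
assert temp_ts.shape == (seq_len,)

# First-frame tokens get timestep 0, all others get t:
tokens_per_frame = (latent_height // 2) * (latent_width // 2)
assert torch.all(temp_ts[:tokens_per_frame] == 0)
assert torch.all(temp_ts[tokens_per_frame:] == t)

batch_size = 2
timestep = temp_ts.unsqueeze(0).expand(batch_size, -1)  # [batch_size, seq_len]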
@@ -766,6 +792,9 @@ def __call__(
 
         self._current_timestep = None
 
+        if self.config.expand_timesteps:
+            latents = (1 - first_frame_mask) * condition + first_frame_mask * latents
+
         if not output_type == "latent":
             latents = latents.to(self.vae.dtype)
             latents_mean = (
