@@ -370,7 +370,6 @@ def check_inputs(
         ):
             raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
 
-    # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.prepare_latents
     def prepare_latents(
         self,
         image: PipelineImageInput,
82 changes: 61 additions & 21 deletions src/diffusers/pipelines/wan/pipeline_wan_i2v.py
@@ -175,6 +175,7 @@ def __init__(
         image_encoder: CLIPVisionModel = None,
         transformer_2: WanTransformer3DModel = None,
         boundary_ratio: Optional[float] = None,
+        expand_timesteps: bool = False,
     ):
         super().__init__()
 
@@ -188,10 +189,10 @@ def __init__(
             image_processor=image_processor,
             transformer_2=transformer_2,
         )
-        self.register_to_config(boundary_ratio=boundary_ratio)
+        self.register_to_config(boundary_ratio=boundary_ratio, expand_timesteps=expand_timesteps)
 
-        self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
-        self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+        self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
+        self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
         self.image_processor = image_processor
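
Annotation (not part of the diff): the scale factors now read from the VAE config determine the latent dimensions (`num_latent_frames`, `latent_height`, `latent_width`) that the new `first_frame_mask` below is built with. A minimal sketch, assuming the Wan defaults of 4 (temporal) and 8 (spatial) and an illustrative 81×480×832 video:

```python
# Illustrative only: map pixel-space video dimensions to latent dimensions
# using the scale factors read from the VAE config (defaults 4 and 8).
vae_scale_factor_temporal = 4
vae_scale_factor_spatial = 8

num_frames, height, width = 81, 480, 832  # example values, not mandated by the PR
num_latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1  # 21
latent_height = height // vae_scale_factor_spatial  # 60
latent_width = width // vae_scale_factor_spatial  # 104
print(num_latent_frames, latent_height, latent_width)  # 21 60 104
```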

@@ -419,8 +420,12 @@ def prepare_latents(
         else:
             latents = latents.to(device=device, dtype=dtype)
 
-        image = image.unsqueeze(2)
-        if last_image is None:
+        image = image.unsqueeze(2)  # [batch_size, channels, 1, height, width]
+
+        if self.config.expand_timesteps:
+            video_condition = image
+
+        elif last_image is None:
             video_condition = torch.cat(
                 [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
             )
@@ -453,6 +458,13 @@ def prepare_latents(
         latent_condition = latent_condition.to(dtype)
         latent_condition = (latent_condition - latents_mean) * latents_std
 
+        if self.config.expand_timesteps:
+            first_frame_mask = torch.ones(
+                1, 1, num_latent_frames, latent_height, latent_width, dtype=dtype, device=device
+            )
+            first_frame_mask[:, :, 0] = 0
+            return latents, latent_condition, first_frame_mask
+
         mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
 
         if last_image is None:
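
Annotation (not part of the diff): `first_frame_mask` is 0 on the first latent frame and 1 everywhere else, so the blend used in the denoising loop (and again after it) pins frame 0 to the encoded conditioning image while every other frame is denoised; in this mode the encoded condition covers just the first frame and broadcasts over the frame axis. A toy-shaped sketch of that blend:

```python
import torch

# Toy shapes only; the real tensors come from prepare_latents above.
batch, channels, frames, h, w = 1, 4, 5, 6, 6
latents = torch.randn(batch, channels, frames, h, w)  # noise to be denoised
condition = torch.randn(batch, channels, 1, h, w)     # encoded first frame (single latent frame)

first_frame_mask = torch.ones(1, 1, frames, h, w)
first_frame_mask[:, :, 0] = 0  # 0 -> keep condition, 1 -> keep latents

# Same blend as in the loop: condition broadcasts across the frame axis,
# but only frame 0 survives because (1 - mask) is zero everywhere else.
blended = (1 - first_frame_mask) * condition + first_frame_mask * latents
assert torch.allclose(blended[:, :, 0], condition[:, :, 0])
assert torch.allclose(blended[:, :, 1:], latents[:, :, 1:])
```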
@@ -662,7 +674,7 @@ def __call__(
         if negative_prompt_embeds is not None:
             negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
 
-        if self.config.boundary_ratio is None:
+        if self.config.boundary_ratio is None and not self.config.expand_timesteps:
             if image_embeds is None:
                 if last_image is None:
                     image_embeds = self.encode_image(image, device)
@@ -682,19 +694,35 @@ def __call__(
             last_image = self.video_processor.preprocess(last_image, height=height, width=width).to(
                 device, dtype=torch.float32
             )
-        latents, condition = self.prepare_latents(
-            image,
-            batch_size * num_videos_per_prompt,
-            num_channels_latents,
-            height,
-            width,
-            num_frames,
-            torch.float32,
-            device,
-            generator,
-            latents,
-            last_image,
-        )
+
+        if self.config.expand_timesteps:
+            latents, condition, first_frame_mask = self.prepare_latents(
+                image,
+                batch_size * num_videos_per_prompt,
+                num_channels_latents,
+                height,
+                width,
+                num_frames,
+                torch.float32,
+                device,
+                generator,
+                latents,
+                last_image,
+            )
+        else:
+            latents, condition = self.prepare_latents(
+                image,
+                batch_size * num_videos_per_prompt,
+                num_channels_latents,
+                height,
+                width,
+                num_frames,
+                torch.float32,
+                device,
+                generator,
+                latents,
+                last_image,
+            )
 
         # 6. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
@@ -721,8 +749,17 @@ def __call__(
                     current_model = self.transformer_2
                     current_guidance_scale = guidance_scale_2
 
-                latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
-                timestep = t.expand(latents.shape[0])
+                if self.config.expand_timesteps:
+                    latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents
+                    latent_model_input = latent_model_input.to(transformer_dtype)
+
+                    # seq_len: num_latent_frames * (latent_height // patch_size) * (latent_width // patch_size)
+                    temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
+                    # batch_size, seq_len
+                    timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
+                else:
+                    latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
+                    timestep = t.expand(latents.shape[0])
 
                 noise_pred = current_model(
                     hidden_states=latent_model_input,
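
Annotation (not part of the diff): in `expand_timesteps` mode the transformer receives one timestep per token instead of one per sample. The mask is downsampled by the `::2` strides (matching an assumed spatial patch size of 2), scaled by the current `t`, and flattened, so tokens belonging to the pinned first frame get timestep 0 while all remaining tokens carry `t`. A small sketch of that expansion with toy shapes:

```python
import torch

# Sketch of the per-token timestep built in expand_timesteps mode
# (toy shapes; the ::2 stride matches an assumed spatial patch size of 2).
num_latent_frames, latent_height, latent_width = 4, 6, 8
t = torch.tensor(999.0)  # current scheduler timestep (illustrative value)

first_frame_mask = torch.ones(1, 1, num_latent_frames, latent_height, latent_width)
first_frame_mask[:, :, 0] = 0

# One entry per transformer token: frames x (height // 2) x (width // 2)
temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
print(temp_ts.shape)   # torch.Size([48]) -> 4 * 3 * 4 tokens
print(temp_ts[:12])    # first-frame tokens are 0, i.e. treated as already clean
print(temp_ts[12:24])  # remaining tokens carry the current timestep t

batch_size = 2
timestep = temp_ts.unsqueeze(0).expand(batch_size, -1)  # [batch_size, seq_len]
```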
@@ -766,6 +803,9 @@ def __call__(
 
         self._current_timestep = None
 
+        if self.config.expand_timesteps:
+            latents = (1 - first_frame_mask) * condition + first_frame_mask * latents
+
         if not output_type == "latent":
             latents = latents.to(self.vae.dtype)
             latents_mean = (
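
For orientation, a hypothetical end-to-end call exercising the new branch. The checkpoint id, image URL, and generation settings are illustrative assumptions, not something this diff specifies; any Wan image-to-video checkpoint whose pipeline config registers `expand_timesteps=True` would follow the masked path above:

```python
import torch
from diffusers import WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

# Assumed checkpoint id, for illustration only.
pipe = WanImageToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.2-TI2V-5B-Diffusers", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

image = load_image("https://example.com/first_frame.png")  # placeholder URL
video = pipe(
    image=image,
    prompt="a cat surfing a wave at sunset",
    height=480,
    width=832,
    num_frames=81,
    num_inference_steps=40,
).frames[0]
export_to_video(video, "output.mp4", fps=16)
```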