44 changes: 30 additions & 14 deletions src/diffusers/pipelines/wan/pipeline_wan_i2v.py
@@ -19,7 +19,7 @@
 import PIL
 import regex as re
 import torch
-from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
+from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModelWithProjection, UMT5EncoderModel

 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput
@@ -49,11 +49,11 @@
 >>> import numpy as np
 >>> from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
 >>> from diffusers.utils import export_to_video, load_image
->>> from transformers import CLIPVisionModel
+>>> from transformers import CLIPVisionModelWithProjection

 >>> # Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
 >>> model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
->>> image_encoder = CLIPVisionModel.from_pretrained(
+>>> image_encoder = CLIPVisionModelWithProjection.from_pretrained(

Review comment (Collaborator): why this change? The authors just changed to CLIPVisionModel in https://github.com/huggingface/diffusers/pull/10975/files

 ... model_id, subfolder="image_encoder", torch_dtype=torch.float32
 ... )
 >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
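
Note on the comment above: the two transformers classes differ mainly in their outputs. CLIPVisionModel returns last_hidden_state and pooler_output, while CLIPVisionModelWithProjection adds CLIP's projection head and also returns image_embeds. A minimal sketch below illustrates this; the checkpoint is a generic CLIP checkpoint used purely for illustration, not the Wan image encoder, and the shapes in the comments are symbolic.

import torch
from transformers import CLIPVisionModel, CLIPVisionModelWithProjection

checkpoint = "openai/clip-vit-base-patch32"  # illustrative stand-in, not the Wan encoder
pixel_values = torch.randn(1, 3, 224, 224)   # dummy preprocessed image batch

# Plain vision transformer: per-patch hidden states plus a pooled embedding.
vision = CLIPVisionModel.from_pretrained(checkpoint)
with torch.no_grad():
    out = vision(pixel_values=pixel_values)
print(out.last_hidden_state.shape)  # (1, num_patches + 1, hidden_size)
print(out.pooler_output.shape)      # (1, hidden_size)

# Same backbone plus the projection head: adds image_embeds in the shared CLIP space.
vision_proj = CLIPVisionModelWithProjection.from_pretrained(checkpoint)
with torch.no_grad():
    out_proj = vision_proj(pixel_values=pixel_values)
print(out_proj.image_embeds.shape)  # (1, projection_dim)
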
@@ -109,14 +109,30 @@ def prompt_clean(text):


 def retrieve_latents(
-    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+    encoder_output: torch.Tensor,
+    latents_mean: torch.Tensor,
+    latents_std: torch.Tensor,
+    generator: Optional[torch.Generator] = None,
+    sample_mode: str = "sample",
 ):
     if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std
+        encoder_output.latent_dist.logvar = torch.clamp(
+            (encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0
+        )
+        encoder_output.latent_dist.std = torch.exp(0.5 * encoder_output.latent_dist.logvar)
+        encoder_output.latent_dist.var = torch.exp(encoder_output.latent_dist.logvar)

Review comment on lines +119 to +124 (Contributor Author): Can move this out of retrieve_latents if you'd prefer, it would affect this section more:

    if isinstance(generator, list):
        latent_condition = [
            retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, g) for g in generator
        ]
        latent_condition = torch.cat(latent_condition)
    else:
        latent_condition = retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, generator)
        latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
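
The alternative this comment describes, normalizing outside retrieve_latents, is roughly what the code removed further down in this diff did. A small self-contained sketch of that caller-side placement follows; FakeVAE and FakeOutput are invented stand-ins, not diffusers classes, and the shapes are arbitrary.

import torch

class FakeOutput:
    def __init__(self, latents):
        self.latents = latents  # pretend the encoder already returned latents

class FakeVAE:
    z_dim = 16
    def encode(self, video):
        # stand-in for a VAE encode call; shape is (batch, z_dim, frames, h, w)
        return FakeOutput(torch.randn(video.shape[0], self.z_dim, 6, 8, 8))

def retrieve_latents(encoder_output, generator=None, sample_mode="sample"):
    # helper left untouched: it never sees latents_mean / latents_std
    return encoder_output.latents

vae = FakeVAE()
batch_size = 2
video_condition = torch.randn(1, 3, 21, 64, 64)
latents_mean = torch.zeros(1, vae.z_dim, 1, 1, 1)
latents_std = torch.ones(1, vae.z_dim, 1, 1, 1)  # applied as a multiplier, as in the pipeline code above

latent_condition = retrieve_latents(vae.encode(video_condition))
latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
latent_condition = (latent_condition - latents_mean) * latents_std  # normalization stays in the caller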

         return encoder_output.latent_dist.sample(generator)
     elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std
+        encoder_output.latent_dist.logvar = torch.clamp(
+            (encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0
+        )
+        encoder_output.latent_dist.std = torch.exp(0.5 * encoder_output.latent_dist.logvar)
+        encoder_output.latent_dist.var = torch.exp(encoder_output.latent_dist.logvar)
         return encoder_output.latent_dist.mode()
     elif hasattr(encoder_output, "latents"):
-        return encoder_output.latents
+        return (encoder_output.latents - latents_mean) * latents_std
     else:
         raise AttributeError("Could not access latents of provided encoder_output")

@@ -155,7 +171,7 @@ def __init__(
         self,
         tokenizer: AutoTokenizer,
         text_encoder: UMT5EncoderModel,
-        image_encoder: CLIPVisionModel,
+        image_encoder: CLIPVisionModelWithProjection,
         image_processor: CLIPImageProcessor,
         transformer: WanTransformer3DModel,
         vae: AutoencoderKLWan,
@@ -385,13 +401,6 @@ def prepare_latents(
         )
         video_condition = video_condition.to(device=device, dtype=dtype)

-        if isinstance(generator, list):
-            latent_condition = [retrieve_latents(self.vae.encode(video_condition), g) for g in generator]
-            latents = latent_condition = torch.cat(latent_condition)
-        else:
-            latent_condition = retrieve_latents(self.vae.encode(video_condition), generator)
-            latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)

         latents_mean = (
             torch.tensor(self.vae.config.latents_mean)
             .view(1, self.vae.config.z_dim, 1, 1, 1)
@@ -401,7 +410,14 @@
             latents.device, latents.dtype
         )

-        latent_condition = (latent_condition - latents_mean) * latents_std
+        if isinstance(generator, list):
+            latent_condition = [
+                retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, g) for g in generator
+            ]
+            latent_condition = torch.cat(latent_condition)
+        else:
+            latent_condition = retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, generator)
+            latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)

         mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
         mask_lat_size[:, :, list(range(1, num_frames))] = 0