Wan Pipeline scaling fix, type hint warning, multi generator fix #11007
Changes from 1 commit
`src/diffusers/pipelines/wan/pipeline_wan_i2v.py`:

```diff
@@ -19,7 +19,7 @@
 import PIL
 import regex as re
 import torch
-from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
+from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModelWithProjection, UMT5EncoderModel

 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput
```
```diff
@@ -49,11 +49,11 @@
         >>> import numpy as np
         >>> from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
         >>> from diffusers.utils import export_to_video, load_image
-        >>> from transformers import CLIPVisionModel
+        >>> from transformers import CLIPVisionModelWithProjection

         >>> # Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
         >>> model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
-        >>> image_encoder = CLIPVisionModel.from_pretrained(
+        >>> image_encoder = CLIPVisionModelWithProjection.from_pretrained(
         ...     model_id, subfolder="image_encoder", torch_dtype=torch.float32
         ... )
         >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
```
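For context, the docstring example these lines belong to continues by assembling and running the pipeline. A sketch of the remainder, with the component-override pattern taken from diffusers' usual image-to-video examples; the dtype choices, input image, and prompt below are placeholders, not part of this diff:

```python
import torch
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModelWithProjection

model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    model_id, subfolder="image_encoder", torch_dtype=torch.float32
)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)

# Passing vae= and image_encoder= overrides the checkpoint's components,
# following the standard diffusers from_pretrained pattern.
pipe = WanImageToVideoPipeline.from_pretrained(
    model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
).to("cuda")

image = load_image("first_frame.png")  # placeholder conditioning image
frames = pipe(image=image, prompt="a description of the motion").frames[0]
export_to_video(frames, "output.mp4", fps=16)
```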
```diff
@@ -109,14 +109,30 @@ def prompt_clean(text):
 def retrieve_latents(
-    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+    encoder_output: torch.Tensor,
+    latents_mean: torch.Tensor,
+    latents_std: torch.Tensor,
+    generator: Optional[torch.Generator] = None,
+    sample_mode: str = "sample",
 ):
     if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std
+        encoder_output.latent_dist.logvar = torch.clamp(
+            (encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0
+        )
+        encoder_output.latent_dist.std = torch.exp(0.5 * encoder_output.latent_dist.logvar)
+        encoder_output.latent_dist.var = torch.exp(encoder_output.latent_dist.logvar)
         return encoder_output.latent_dist.sample(generator)
     elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std
+        encoder_output.latent_dist.logvar = torch.clamp(
+            (encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0
+        )
+        encoder_output.latent_dist.std = torch.exp(0.5 * encoder_output.latent_dist.logvar)
+        encoder_output.latent_dist.var = torch.exp(encoder_output.latent_dist.logvar)
         return encoder_output.latent_dist.mode()
     elif hasattr(encoder_output, "latents"):
-        return encoder_output.latents
+        return (encoder_output.latents - latents_mean) * latents_std
     else:
         raise AttributeError("Could not access latents of provided encoder_output")
```

> **Review comment** (on lines +119 to +124, marked resolved by hlky): Can move this out of `src/diffusers/pipelines/wan/pipeline_wan_i2v.py`, lines 413 to 420 in e461b61.
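For intuition, the change above folds the VAE's per-channel latent normalization into `retrieve_latents`. A minimal sketch of the arithmetic for the plain `latents` branch, assuming (as the multiplication rather than division suggests) that `latents_std` is passed in as a reciprocal standard deviation; all shapes here are illustrative:

```python
import torch

z_dim = 16  # assumed latent channel count; the real value comes from vae.config.z_dim
latents = torch.randn(1, z_dim, 4, 8, 8)  # (batch, channels, frames, height, width)

# Per-channel statistics broadcast over batch and spatio-temporal dims,
# mirroring the (1, z_dim, 1, 1, 1) views built in prepare_latents.
latents_mean = torch.randn(1, z_dim, 1, 1, 1)
latents_std = 1.0 / (torch.rand(1, z_dim, 1, 1, 1) + 0.5)  # reciprocal std, kept away from zero

normalized = (latents - latents_mean) * latents_std  # matches the new "latents" branch
print(normalized.shape)  # torch.Size([1, 16, 4, 8, 8])
```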
```diff
@@ -155,7 +171,7 @@ def __init__(
         self,
         tokenizer: AutoTokenizer,
         text_encoder: UMT5EncoderModel,
-        image_encoder: CLIPVisionModel,
+        image_encoder: CLIPVisionModelWithProjection,
         image_processor: CLIPImageProcessor,
         transformer: WanTransformer3DModel,
         vae: AutoencoderKLWan,
```
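This annotation change is the "type hint warning" fix from the PR title: the Wan I2V checkpoints ship a `CLIPVisionModelWithProjection` image encoder, and diffusers warns when a loaded component does not match the `__init__` type hint. A small sketch of why the old hint could never match (the warning mechanism itself is paraphrased, not reproduced here):

```python
from transformers import CLIPVisionModel, CLIPVisionModelWithProjection

# The two classes are siblings (both extend CLIPPreTrainedModel), not a
# parent/child pair, so an instance of the projection variant never
# satisfies a check against CLIPVisionModel, hence the warning.
print(issubclass(CLIPVisionModelWithProjection, CLIPVisionModel))  # False
```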
```diff
@@ -385,13 +401,6 @@ def prepare_latents(
         )
         video_condition = video_condition.to(device=device, dtype=dtype)

-        if isinstance(generator, list):
-            latent_condition = [retrieve_latents(self.vae.encode(video_condition), g) for g in generator]
-            latents = latent_condition = torch.cat(latent_condition)
-        else:
-            latent_condition = retrieve_latents(self.vae.encode(video_condition), generator)
-            latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
-
         latents_mean = (
             torch.tensor(self.vae.config.latents_mean)
             .view(1, self.vae.config.z_dim, 1, 1, 1)
```
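The removed list branch appears to carry the "multi generator fix" from the PR title: the chained assignment rebinds `latents` to the concatenated conditioning latents, so the name no longer refers to the noise latents prepared earlier in the method. A tiny illustration of that rebinding, with stand-in tensors:

```python
import torch

latents = torch.zeros(2, 4)  # stand-in for the scheduler noise latents
latent_condition = [torch.ones(1, 4), torch.ones(1, 4)]

# The removed line: chained assignment makes both names point at the
# freshly concatenated condition, silently discarding the noise latents.
latents = latent_condition = torch.cat(latent_condition)
print(torch.equal(latents, latent_condition))  # True: `latents` was clobbered
```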
```diff
@@ -401,7 +410,14 @@ def prepare_latents(
             latents.device, latents.dtype
         )

-        latent_condition = (latent_condition - latents_mean) * latents_std
+        if isinstance(generator, list):
+            latent_condition = [
+                retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, g) for g in generator
+            ]
+            latent_condition = torch.cat(latent_condition)
+        else:
+            latent_condition = retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, generator)
+            latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)

         mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
         mask_lat_size[:, :, list(range(1, num_frames))] = 0
```
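With the encode moved after the `latents_mean`/`latents_std` setup, the list branch gives each batch element its own `torch.Generator`, keeping per-sample results reproducible. A minimal sketch of that pattern, where the encode function is a stand-in for `self.vae.encode` followed by `retrieve_latents`:

```python
import torch

def encode_per_generator(encode_fn, generators):
    # One encode per generator, concatenated along the batch dimension,
    # mirroring the `isinstance(generator, list)` branch above.
    return torch.cat([encode_fn(g) for g in generators])

gens = [torch.Generator().manual_seed(s) for s in (0, 1, 2)]
fake_encode = lambda g: torch.randn(1, 16, 4, 8, 8, generator=g)  # stand-in for the VAE path
latent_condition = encode_per_generator(fake_encode, gens)
print(latent_condition.shape)  # torch.Size([3, 16, 4, 8, 8])
```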