Skip to content

Commit 465e12e

Browse files
committed
fix wan i2v pipeline bugs
1 parent e031caf commit 465e12e

File tree

1 file changed

+17
-9
lines changed

1 file changed

+17
-9
lines changed

src/diffusers/pipelines/wan/pipeline_wan_i2v.py

Lines changed: 17 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@
1919
import PIL
2020
import regex as re
2121
import torch
22-
from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModelWithProjection, UMT5EncoderModel
22+
from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
2323

2424
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
2525
from ...image_processor import PipelineImageInput
@@ -46,29 +46,37 @@
4646
Examples:
4747
```python
4848
>>> import torch
49+
>>> import numpy as np
4950
>>> from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
5051
>>> from diffusers.utils import export_to_video, load_image
52+
>>> from transformers import CLIPVisionModel
5153
52-
>>> # Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-1.3B-720P-Diffusers
54+
>>> # Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
5355
>>> model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
56+
>>> image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
5457
>>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
55-
>>> pipe = WanImageToVideoPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
58+
>>> pipe = WanImageToVideoPipeline.from_pretrained(model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16)
5659
>>> pipe.to("cuda")
5760
58-
>>> height, width = 480, 832
5961
>>> image = load_image(
6062
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
61-
... ).resize((width, height))
63+
... )
64+
>>> max_area = 480 * 832
65+
>>> aspect_ratio = image.height / image.width
66+
>>> mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
67+
>>> height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
68+
>>> width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
69+
>>> image = image.resize((width, height))
6270
>>> prompt = (
6371
... "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
6472
... "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
6573
... )
6674
>>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
6775
6876
>>> output = pipe(
69-
... image=image, prompt=prompt, negative_prompt=negative_prompt, num_frames=81, guidance_scale=5.0
77+
... image=image, prompt=prompt, negative_prompt=negative_prompt, height=height, width=width, num_frames=81, guidance_scale=5.0
7078
... ).frames[0]
71-
>>> export_to_video(output, "output.mp4", fps=15)
79+
>>> export_to_video(output, "output.mp4", fps=16)
7280
```
7381
"""
7482

@@ -137,7 +145,7 @@ def __init__(
137145
self,
138146
tokenizer: AutoTokenizer,
139147
text_encoder: UMT5EncoderModel,
140-
image_encoder: CLIPVisionModelWithProjection,
148+
image_encoder: CLIPVisionModel,
141149
image_processor: CLIPImageProcessor,
142150
transformer: WanTransformer3DModel,
143151
vae: AutoencoderKLWan,
@@ -204,7 +212,7 @@ def _get_t5_prompt_embeds(
204212
def encode_image(self, image: PipelineImageInput):
205213
image = self.image_processor(images=image, return_tensors="pt").to(self.device)
206214
image_embeds = self.image_encoder(**image, output_hidden_states=True)
207-
return image_embeds.hidden_states[-1]
215+
return image_embeds.hidden_states[-2]
208216

209217
# Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
210218
def encode_prompt(

0 commit comments

Comments
 (0)