Commit 62101a4

update example
1 parent 9ad3e31 commit 62101a4

File tree

1 file changed: +41 -11 lines changed

src/diffusers/pipelines/wan/pipeline_wan_vace.py

Lines changed: 41 additions & 11 deletions
@@ -49,28 +49,58 @@
     Examples:
         ```python
         >>> import torch
-        >>> from diffusers.utils import export_to_video
-        >>> from diffusers import AutoencoderKLWan, WanPipeline
+        >>> import PIL.Image
+        >>> from diffusers import AutoencoderKLWan, WanVACEPipeline
         >>> from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
-
-        >>> # Available models: Wan-AI/Wan2.1-T2V-14B-diffusers, Wan-AI/Wan2.1-T2V-1.3B-diffusers
-        >>> model_id = "Wan-AI/Wan2.1-T2V-14B-diffusers"
+        >>> from diffusers.utils import export_to_video, load_image
+        >>> def prepare_video_and_mask(first_img: PIL.Image.Image, last_img: PIL.Image.Image, height: int, width: int, num_frames: int):
+        ...     first_img = first_img.resize((width, height))
+        ...     last_img = last_img.resize((width, height))
+        ...     frames = []
+        ...     frames.append(first_img)
+        ...     # Ideally, this would be 127.5 to match the original code, but the original computes on
+        ...     # numpy arrays whereas we pass PIL images. If you pass numpy arrays instead, you can set
+        ...     # it to 127.5 to match the original code.
+        ...     frames.extend([PIL.Image.new("RGB", (width, height), (128, 128, 128))] * (num_frames - 2))
+        ...     frames.append(last_img)
+        ...     mask_black = PIL.Image.new("L", (width, height), 0)
+        ...     mask_white = PIL.Image.new("L", (width, height), 255)
+        ...     mask = [mask_black, *[mask_white] * (num_frames - 2), mask_black]
+        ...     return frames, mask
+
+        >>> # Available checkpoints: Wan-AI/Wan2.1-VACE-1.3B-diffusers, Wan-AI/Wan2.1-VACE-14B-diffusers
+        >>> model_id = "Wan-AI/Wan2.1-VACE-1.3B-diffusers"
         >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
-        >>> pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
-        >>> flow_shift = 5.0  # 5.0 for 720P, 3.0 for 480P
+        >>> pipe = WanVACEPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
+        >>> flow_shift = 3.0  # 5.0 for 720P, 3.0 for 480P
         >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
         >>> pipe.to("cuda")
 
-        >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
+        >>> prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
         >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
+        >>> first_frame = load_image(
+        ...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png"
+        ... )
+        >>> last_frame = load_image(
+        ...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png"
+        ... )
+
+        >>> height = 512
+        >>> width = 512
+        >>> num_frames = 81
+        >>> video, mask = prepare_video_and_mask(first_frame, last_frame, height, width, num_frames)
 
         >>> output = pipe(
+        ...     video=video,
+        ...     mask=mask,
         ...     prompt=prompt,
         ...     negative_prompt=negative_prompt,
-        ...     height=720,
-        ...     width=1280,
-        ...     num_frames=81,
+        ...     height=height,
+        ...     width=width,
+        ...     num_frames=num_frames,
+        ...     num_inference_steps=30,
         ...     guidance_scale=5.0,
+        ...     generator=torch.Generator().manual_seed(42),
         ... ).frames[0]
         >>> export_to_video(output, "output.mp4", fps=16)
         ```
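The comment inside prepare_video_and_mask points at a real subtlety: 8-bit PIL images cannot store the exact mid-gray value 127.5, so the example settles for 128. Below is a minimal sketch of the numpy-array variant the comment alludes to; the function name, dtype, and 0-255 value range are assumptions for illustration, not taken from this commit:

```python
import numpy as np
import PIL.Image


# Hypothetical numpy counterpart of the example's prepare_video_and_mask helper.
def prepare_video_and_mask_np(first_img, last_img, height, width, num_frames):
    first = np.asarray(first_img.resize((width, height)), dtype=np.float32)
    last = np.asarray(last_img.resize((width, height)), dtype=np.float32)
    # Float arrays can hold the exact neutral value 127.5 that uint8 images cannot.
    gray = np.full((height, width, 3), 127.5, dtype=np.float32)
    video = [first, *([gray] * (num_frames - 2)), last]
    # Same mask convention as the PIL version above: black (0) keeps the supplied
    # frame, white (255) marks frames for the model to generate.
    black = np.zeros((height, width), dtype=np.float32)
    white = np.full((height, width), 255.0, dtype=np.float32)
    mask = [black, *([white] * (num_frames - 2)), black]
    return video, mask
```

Whether WanVACEPipeline normalizes float numpy inputs exactly like PIL inputs is not shown in this diff, so treat the dtype handling above as illustrative.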

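Two of the smaller value changes are related: the output size drops from 1280x720 to 512x512, and flow_shift drops from 5.0 to 3.0, matching the inline comment "5.0 for 720P, 3.0 for 480P". A hypothetical helper encoding that rule of thumb; the 720-pixel threshold is an assumption, since the diff only gives the two reference points:

```python
def flow_shift_for(height: int, width: int) -> float:
    # Rule of thumb from the example's comment: 5.0 for 720P-class outputs,
    # 3.0 for lower resolutions such as 480P. The threshold is an assumption.
    return 5.0 if min(height, width) >= 720 else 3.0
```

The unchanged num_frames=81 also fits the usual 4k + 1 frame-count pattern expected by Wan's temporally compressing VAE.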