
Commit b65412f

Merge branch 'main' into issue129
2 parents 714b969 + d8c617c commit b65412f

File tree

14 files changed: +994, -20 lines


docs/source/en/api/pipelines/wan.md

Lines changed: 44 additions & 4 deletions
@@ -133,6 +133,46 @@ output = pipe(
 export_to_video(output, "wan-i2v.mp4", fps=16)
 ```
 
+### Video to Video Generation
+
+```python
+import torch
+from diffusers.utils import load_video, export_to_video
+from diffusers import AutoencoderKLWan, WanVideoToVideoPipeline, UniPCMultistepScheduler
+
+# Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers
+model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
+vae = AutoencoderKLWan.from_pretrained(
+    model_id, subfolder="vae", torch_dtype=torch.float32
+)
+pipe = WanVideoToVideoPipeline.from_pretrained(
+    model_id, vae=vae, torch_dtype=torch.bfloat16
+)
+flow_shift = 3.0  # 5.0 for 720P, 3.0 for 480P
+pipe.scheduler = UniPCMultistepScheduler.from_config(
+    pipe.scheduler.config, flow_shift=flow_shift
+)
+# change to pipe.to("cuda") if you have sufficient VRAM
+pipe.enable_model_cpu_offload()
+
+prompt = "A robot standing on a mountain top. The sun is setting in the background"
+negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
+video = load_video(
+    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4"
+)
+output = pipe(
+    video=video,
+    prompt=prompt,
+    negative_prompt=negative_prompt,
+    height=480,
+    width=512,
+    guidance_scale=7.0,
+    strength=0.7,
+).frames[0]
+
+export_to_video(output, "wan-v2v.mp4", fps=16)
+```
+
 ## Memory Optimizations for Wan 2.1
 
 Base inference with the large 14B Wan 2.1 models can take up to 35GB of VRAM when generating videos at 720p resolution. We'll outline a few memory optimizations we can apply to reduce the VRAM required to run the model.
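The individual optimizations are elided between hunks; as a minimal sketch of the usual diffusers levers (assuming the Wan VAE supports tiling like other diffusers video VAEs):

```python
import torch
from diffusers import WanPipeline

pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-14B-Diffusers", torch_dtype=torch.bfloat16)

# offload idle submodules to CPU between forward passes instead of pipe.to("cuda")
pipe.enable_model_cpu_offload()
# decode latents in spatial tiles to cap peak VRAM during the VAE pass
pipe.vae.enable_tiling()
```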
@@ -323,7 +363,7 @@ import numpy as np
 from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanImageToVideoPipeline
 from diffusers.hooks.group_offloading import apply_group_offloading
 from diffusers.utils import export_to_video, load_image
-from transformers import UMT5EncoderModel, CLIPVisionMode
+from transformers import UMT5EncoderModel, CLIPVisionModel
 
 model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
 image_encoder = CLIPVisionModel.from_pretrained(
@@ -356,7 +396,7 @@ prompt = (
     "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
     "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
 )
-negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
+negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
 num_frames = 33
 
 output = pipe(
@@ -372,7 +412,7 @@ output = pipe(
 export_to_video(output, "wan-i2v.mp4", fps=16)
 ```
 
-### Using a Custom Scheduler
+## Using a Custom Scheduler
 
 Wan can be used with many different schedulers, each with its own benefits regarding speed and generation quality. By default, Wan uses the `UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0)` scheduler. You can use a different scheduler as follows:
 
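The snippet that "as follows" refers to sits outside this hunk; a minimal sketch of such a swap, assuming `FlowMatchEulerDiscreteScheduler` as the replacement:

```python
import torch
from diffusers import FlowMatchEulerDiscreteScheduler, WanPipeline

pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16)

# replace the default UniPC scheduler with a flow-matching Euler scheduler
pipe.scheduler = FlowMatchEulerDiscreteScheduler(shift=5.0)
```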
@@ -403,7 +443,7 @@ transformer = WanTransformer3DModel.from_single_file(ckpt_path, torch_dtype=torc
 pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", transformer=transformer)
 ```
 
-## Recommendations for Inference:
+## Recommendations for Inference
 - Keep `AutoencoderKLWan` in `torch.float32` for better decoding quality.
 - `num_frames` should satisfy the following constraint: `(num_frames - 1) % 4 == 0`
 - For smaller resolution videos, try lower values of `shift` (between `2.0` and `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution videos, try higher values (between `7.0` and `12.0`). The default value is `3.0` for Wan.
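A quick illustration of the frame-count rule in the list above:

```python
# num_frames must be of the form 4k + 1; valid examples: 33, 49, 81
num_frames = 81
assert (num_frames - 1) % 4 == 0, "num_frames must satisfy (num_frames - 1) % 4 == 0"
```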

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -509,6 +509,7 @@
     "VQDiffusionPipeline",
     "WanImageToVideoPipeline",
     "WanPipeline",
+    "WanVideoToVideoPipeline",
     "WuerstchenCombinedPipeline",
     "WuerstchenDecoderPipeline",
     "WuerstchenPriorPipeline",
@@ -1062,6 +1063,7 @@
     VQDiffusionPipeline,
     WanImageToVideoPipeline,
     WanPipeline,
+    WanVideoToVideoPipeline,
     WuerstchenCombinedPipeline,
     WuerstchenDecoderPipeline,
     WuerstchenPriorPipeline,

src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py

Lines changed: 3 additions & 2 deletions
@@ -105,6 +105,7 @@ def __init__(
         self.width_pad = width_pad
         self.time_pad = time_pad
         self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
+        self.const_padding_conv3d = (0, self.width_pad, self.height_pad)
 
         self.temporal_dim = 2
         self.time_kernel_size = time_kernel_size
@@ -117,6 +118,8 @@ def __init__(
             kernel_size=kernel_size,
             stride=stride,
             dilation=dilation,
+            padding=0 if self.pad_mode == "replicate" else self.const_padding_conv3d,
+            padding_mode="zeros",
         )
 
     def fake_context_parallel_forward(
@@ -137,9 +140,7 @@ def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = Non
         if self.pad_mode == "replicate":
             conv_cache = None
         else:
-            padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
             conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
-            inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
 
         output = self.conv(inputs)
         return output, conv_cache
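These hunks fold the explicit `F.pad` spatial zero-padding into the `Conv3d` itself via its `padding` argument, leaving time unpadded to preserve causality. A standalone sketch of why the two forms agree, using made-up shapes:

```python
import torch
import torch.nn.functional as F

# (batch, channels, frames, height, width) input and a (out, in, t, h, w) kernel
x = torch.randn(1, 4, 5, 16, 16)
weight = torch.randn(8, 4, 1, 3, 3)

# before: zero-pad height/width explicitly, then convolve with padding=0
padded = F.pad(x, (1, 1, 1, 1), mode="constant", value=0)  # (w_left, w_right, h_top, h_bottom)
out_before = F.conv3d(padded, weight, padding=0)

# after: let conv3d apply the same constant zero padding; tuple is (time, height, width)
out_after = F.conv3d(x, weight, padding=(0, 1, 1))

assert torch.allclose(out_before, out_after, atol=1e-5)
```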

src/diffusers/models/modeling_utils.py

Lines changed: 4 additions & 1 deletion
@@ -714,7 +714,10 @@ def save_pretrained(
             if safe_serialization:
                 # At some point we will need to deal better with save_function (used for TPU and other distributed
                 # joyfulness), but for now this is enough.
-                safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
+                try:
+                    safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
+                except RuntimeError:
+                    safetensors.torch.save_model(model_to_save, filepath, metadata={"format": "pt"})
             else:
                 torch.save(shard, filepath)
 
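`safetensors.torch.save_file` raises `RuntimeError` on state dicts whose tensors share storage (e.g. tied weights), which is the failure the new fallback catches; a minimal sketch of that behavior, with a hypothetical `TiedModel`:

```python
import torch
from safetensors.torch import save_file, save_model

class TiedModel(torch.nn.Module):
    """Two linear layers sharing one weight tensor, as in weight tying."""

    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Linear(8, 8, bias=False)
        self.head = torch.nn.Linear(8, 8, bias=False)
        self.head.weight = self.embed.weight  # shared storage

model = TiedModel()
try:
    # raises RuntimeError: both state-dict entries point at the same memory
    save_file(model.state_dict(), "tied.safetensors", metadata={"format": "pt"})
except RuntimeError:
    # save_model de-duplicates shared tensors before writing
    save_model(model, "tied.safetensors", metadata={"format": "pt"})
```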
src/diffusers/pipelines/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -356,7 +356,7 @@
         "WuerstchenDecoderPipeline",
         "WuerstchenPriorPipeline",
     ]
-    _import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline"]
+    _import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline", "WanVideoToVideoPipeline"]
     try:
         if not is_onnx_available():
             raise OptionalDependencyNotAvailable()
@@ -709,7 +709,7 @@
         UniDiffuserPipeline,
         UniDiffuserTextDecoder,
     )
-    from .wan import WanImageToVideoPipeline, WanPipeline
+    from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline
     from .wuerstchen import (
         WuerstchenCombinedPipeline,
         WuerstchenDecoderPipeline,

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 12 additions & 2 deletions
@@ -592,6 +592,11 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic
             loaded_sub_model = passed_class_obj[name]
 
         else:
+            sub_model_dtype = (
+                torch_dtype.get(name, torch_dtype.get("default", torch.float32))
+                if isinstance(torch_dtype, dict)
+                else torch_dtype
+            )
             loaded_sub_model = _load_empty_model(
                 library_name=library_name,
                 class_name=class_name,
@@ -600,7 +605,7 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic
                 is_pipeline_module=is_pipeline_module,
                 pipeline_class=pipeline_class,
                 name=name,
-                torch_dtype=torch_dtype,
+                torch_dtype=sub_model_dtype,
                 cached_folder=kwargs.get("cached_folder", None),
                 force_download=kwargs.get("force_download", None),
                 proxies=kwargs.get("proxies", None),
@@ -616,7 +621,12 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic
     # Obtain a sorted dictionary for mapping the model-level components
     # to their sizes.
     module_sizes = {
-        module_name: compute_module_sizes(module, dtype=torch_dtype)[""]
+        module_name: compute_module_sizes(
+            module,
+            dtype=torch_dtype.get(module_name, torch_dtype.get("default", torch.float32))
+            if isinstance(torch_dtype, dict)
+            else torch_dtype,
+        )[""]
         for module_name, module in init_empty_modules.items()
         if isinstance(module, torch.nn.Module)
     }
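Both hunks apply the same per-component dtype rule, mirrored again in `pipeline_utils.py` below. Factored into a standalone helper (hypothetical name, for illustration):

```python
import torch

def resolve_component_dtype(torch_dtype, name):
    # dict form: per-component dtype with an optional "default" fallback;
    # anything else (a plain torch.dtype or None) passes through unchanged
    if isinstance(torch_dtype, dict):
        return torch_dtype.get(name, torch_dtype.get("default", torch.float32))
    return torch_dtype

assert resolve_component_dtype({"vae": torch.float16}, "vae") is torch.float16
assert resolve_component_dtype({"vae": torch.float16}, "unet") is torch.float32
assert resolve_component_dtype(torch.bfloat16, "unet") is torch.bfloat16
```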

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 13 additions & 5 deletions
@@ -552,9 +552,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                   saved using
                   [`~DiffusionPipeline.save_pretrained`].
                 - A path to a *directory* (for example `./my_pipeline_directory/`) containing a dduf file
-            torch_dtype (`str` or `torch.dtype`, *optional*):
+            torch_dtype (`str` or `torch.dtype` or `dict[str, Union[str, torch.dtype]]`, *optional*):
                 Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
-                dtype is automatically derived from the model's weights.
+                dtype is automatically derived from the model's weights. To load submodels with different dtypes,
+                pass a `dict` (for example `{'transformer': torch.bfloat16, 'vae': torch.float16}`). Set the default
+                dtype for unspecified components with `default` (for example `{'transformer': torch.bfloat16,
+                'default': torch.float16}`). If a component is not specified and no default is set, `torch.float32`
+                is used.
             custom_pipeline (`str`, *optional*):
 
                 <Tip warning={true}>
@@ -703,7 +706,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         use_onnx = kwargs.pop("use_onnx", None)
         load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
 
-        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
+        if torch_dtype is not None and not isinstance(torch_dtype, dict) and not isinstance(torch_dtype, torch.dtype):
             torch_dtype = torch.float32
             logger.warning(
                 f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
@@ -950,14 +953,19 @@ def load_module(name, value):
                 loaded_sub_model = passed_class_obj[name]
             else:
                 # load sub model
+                sub_model_dtype = (
+                    torch_dtype.get(name, torch_dtype.get("default", torch.float32))
+                    if isinstance(torch_dtype, dict)
+                    else torch_dtype
+                )
                 loaded_sub_model = load_sub_model(
                     library_name=library_name,
                     class_name=class_name,
                     importable_classes=importable_classes,
                     pipelines=pipelines,
                     is_pipeline_module=is_pipeline_module,
                     pipeline_class=pipeline_class,
-                    torch_dtype=torch_dtype,
+                    torch_dtype=sub_model_dtype,
                     provider=provider,
                     sess_options=sess_options,
                     device_map=current_device_map,
@@ -998,7 +1006,7 @@ def load_module(name, value):
             for module in missing_modules:
                 init_kwargs[module] = passed_class_obj.get(module, None)
         elif len(missing_modules) > 0:
-            passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
+            passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - set(optional_kwargs)
             raise ValueError(
                 f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
             )
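How the new dict form is used, following the docstring's own example (the model id here is illustrative):

```python
import torch
from diffusers import DiffusionPipeline

# per-component dtype override: dict keys name pipeline components;
# "default" covers everything unspecified (else torch.float32 is used)
pipe = DiffusionPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
    torch_dtype={"transformer": torch.bfloat16, "default": torch.float32},
)
print(pipe.transformer.dtype)  # torch.bfloat16
print(pipe.vae.dtype)          # torch.float32 (covered by "default")
```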

src/diffusers/pipelines/wan/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -24,7 +24,7 @@
 else:
     _import_structure["pipeline_wan"] = ["WanPipeline"]
     _import_structure["pipeline_wan_i2v"] = ["WanImageToVideoPipeline"]
-
+    _import_structure["pipeline_wan_video2video"] = ["WanVideoToVideoPipeline"]
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     try:
         if not (is_transformers_available() and is_torch_available()):
@@ -35,6 +35,7 @@
     else:
         from .pipeline_wan import WanPipeline
         from .pipeline_wan_i2v import WanImageToVideoPipeline
+        from .pipeline_wan_video2video import WanVideoToVideoPipeline
 
 else:
     import sys
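Together with the `src/diffusers/__init__.py` and `src/diffusers/pipelines/__init__.py` hunks above, the new pipeline is exposed at every package level; a quick smoke test of that assumption:

```python
# all three import paths should now resolve to the same class
from diffusers import WanVideoToVideoPipeline
from diffusers.pipelines import WanVideoToVideoPipeline as FromPipelines
from diffusers.pipelines.wan import WanVideoToVideoPipeline as FromWan

assert WanVideoToVideoPipeline is FromPipelines is FromWan
```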
