diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 9f76be91339a..919268b0b558 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -314,6 +314,8 @@
title: Transformer2DModel
- local: api/models/transformer_temporal
title: TransformerTemporalModel
+ - local: api/models/wan_transformer_3d
+ title: WanTransformer3DModel
title: Transformers
- sections:
- local: api/models/stable_cascade_unet
@@ -344,6 +346,8 @@
title: AutoencoderKLLTXVideo
- local: api/models/autoencoderkl_mochi
title: AutoencoderKLMochi
+ - local: api/models/autoencoder_kl_wan
+ title: AutoencoderKLWan
- local: api/models/asymmetricautoencoderkl
title: AsymmetricAutoencoderKL
- local: api/models/autoencoder_dc
@@ -534,6 +538,8 @@
title: UniDiffuser
- local: api/pipelines/value_guided_sampling
title: Value-guided sampling
+ - local: api/pipelines/wan
+ title: Wan
- local: api/pipelines/wuerstchen
title: Wuerstchen
title: Pipelines
diff --git a/docs/source/en/api/models/autoencoder_kl_wan.md b/docs/source/en/api/models/autoencoder_kl_wan.md
new file mode 100644
index 000000000000..43165c8edf7a
--- /dev/null
+++ b/docs/source/en/api/models/autoencoder_kl_wan.md
@@ -0,0 +1,32 @@
+
+
+# AutoencoderKLWan
+
+The 3D variational autoencoder (VAE) model with KL loss used in [Wan 2.1](https://github.com/Wan-Video/Wan2.1) by the Alibaba Wan Team.
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import AutoencoderKLWan
+
+vae = AutoencoderKLWan.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32)
+```
+
+## AutoencoderKLWan
+
+[[autodoc]] AutoencoderKLWan
+ - decode
+ - all
+
+## DecoderOutput
+
+[[autodoc]] models.autoencoders.vae.DecoderOutput
diff --git a/docs/source/en/api/models/wan_transformer_3d.md b/docs/source/en/api/models/wan_transformer_3d.md
new file mode 100644
index 000000000000..56015c4c07f1
--- /dev/null
+++ b/docs/source/en/api/models/wan_transformer_3d.md
@@ -0,0 +1,30 @@
+
+
+# WanTransformer3DModel
+
+A Diffusion Transformer model for 3D video-like data was introduced in [Wan 2.1](https://github.com/Wan-Video/Wan2.1) by the Alibaba Wan Team.
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import WanTransformer3DModel
+
+transformer = WanTransformer3DModel.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
+```
+
+## WanTransformer3DModel
+
+[[autodoc]] WanTransformer3DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
diff --git a/docs/source/en/api/pipelines/wan.md b/docs/source/en/api/pipelines/wan.md
new file mode 100644
index 000000000000..dcc1b2b55e30
--- /dev/null
+++ b/docs/source/en/api/pipelines/wan.md
@@ -0,0 +1,62 @@
+
+
+# Wan
+
+[Wan 2.1](https://github.com/Wan-Video/Wan2.1) by the Alibaba Wan Team.
+
+
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+Recommendations for inference:
+- VAE in `torch.float32` for better decoding quality.
+- `num_frames` should be of the form `4 * k + 1`, for example `49` or `81`.
+- For smaller resolution videos, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution videos, try higher values (between `7.0` and `12.0`). The default value is `3.0` for Wan.
+
+### Using a custom scheduler
+
+Wan can be used with many different schedulers, each with their own benefits regarding speed and generation quality. By default, Wan uses the `UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0)` scheduler. You can use a different scheduler as follows:
+
+```python
+from diffusers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler, WanPipeline
+
+scheduler_a = FlowMatchEulerDiscreteScheduler(shift=5.0)
+scheduler_b = UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=4.0)
+
+pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", scheduler=)
+
+# or,
+pipe.scheduler =
+```
+
+## WanPipeline
+
+[[autodoc]] WanPipeline
+ - all
+ - __call__
+
+## WanImageToVideoPipeline
+
+[[autodoc]] WanImageToVideoPipeline
+ - all
+ - __call__
+
+## WanPipelineOutput
+
+[[autodoc]] pipelines.wan.pipeline_output.WanPipelineOutput
diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py
new file mode 100644
index 000000000000..0b2fa872487e
--- /dev/null
+++ b/scripts/convert_wan_to_diffusers.py
@@ -0,0 +1,423 @@
+import argparse
+import pathlib
+from typing import Any, Dict
+
+import torch
+from accelerate import init_empty_weights
+from huggingface_hub import hf_hub_download, snapshot_download
+from safetensors.torch import load_file
+from transformers import AutoProcessor, AutoTokenizer, CLIPVisionModelWithProjection, UMT5EncoderModel
+
+from diffusers import (
+ AutoencoderKLWan,
+ UniPCMultistepScheduler,
+ WanImageToVideoPipeline,
+ WanPipeline,
+ WanTransformer3DModel,
+)
+
+
+TRANSFORMER_KEYS_RENAME_DICT = {
+ "time_embedding.0": "condition_embedder.time_embedder.linear_1",
+ "time_embedding.2": "condition_embedder.time_embedder.linear_2",
+ "text_embedding.0": "condition_embedder.text_embedder.linear_1",
+ "text_embedding.2": "condition_embedder.text_embedder.linear_2",
+ "time_projection.1": "condition_embedder.time_proj",
+ "head.modulation": "scale_shift_table",
+ "head.head": "proj_out",
+ "modulation": "scale_shift_table",
+ "ffn.0": "ffn.net.0.proj",
+ "ffn.2": "ffn.net.2",
+ # Hack to swap the layer names
+ # The original model calls the norms in following order: norm1, norm3, norm2
+ # We convert it to: norm1, norm2, norm3
+ "norm2": "norm__placeholder",
+ "norm3": "norm2",
+ "norm__placeholder": "norm3",
+ # For the I2V model
+ "img_emb.proj.0": "condition_embedder.image_embedder.norm1",
+ "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
+ "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
+ "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
+}
+
+TRANSFORMER_SPECIAL_KEYS_REMAP = {}
+
+
+def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
+ state_dict[new_key] = state_dict.pop(old_key)
+
+
+def load_sharded_safetensors(dir: pathlib.Path):
+ file_paths = list(dir.glob("diffusion_pytorch_model*.safetensors"))
+ state_dict = {}
+ for path in file_paths:
+ state_dict.update(load_file(path))
+ return state_dict
+
+
+def get_transformer_config(model_type: str) -> Dict[str, Any]:
+ if model_type == "Wan-T2V-1.3B":
+ config = {
+ "model_id": "StevenZhang/Wan2.1-T2V-1.3B-Diff",
+ "diffusers_config": {
+ "added_kv_proj_dim": None,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 8960,
+ "freq_dim": 256,
+ "in_channels": 16,
+ "num_attention_heads": 12,
+ "num_layers": 30,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ },
+ }
+ elif model_type == "Wan-T2V-14B":
+ config = {
+ "model_id": "StevenZhang/Wan2.1-T2V-14B-Diff",
+ "diffusers_config": {
+ "added_kv_proj_dim": None,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "in_channels": 16,
+ "num_attention_heads": 40,
+ "num_layers": 40,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ },
+ }
+ elif model_type == "Wan-I2V-14B-480p":
+ config = {
+ "model_id": "StevenZhang/Wan2.1-I2V-14B-480P-Diff",
+ "diffusers_config": {
+ "image_dim": 1280,
+ "added_kv_proj_dim": 5120,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "in_channels": 36,
+ "num_attention_heads": 40,
+ "num_layers": 40,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ },
+ }
+ elif model_type == "Wan-I2V-14B-720p":
+ config = {
+ "model_id": "StevenZhang/Wan2.1-I2V-14B-720P-Diff",
+ "diffusers_config": {
+ "image_dim": 1280,
+ "added_kv_proj_dim": 5120,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "in_channels": 36,
+ "num_attention_heads": 40,
+ "num_layers": 40,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ },
+ }
+ return config
+
+
+def convert_transformer(model_type: str):
+ config = get_transformer_config(model_type)
+ diffusers_config = config["diffusers_config"]
+ model_id = config["model_id"]
+ model_dir = pathlib.Path(snapshot_download(model_id, repo_type="model"))
+
+ original_state_dict = load_sharded_safetensors(model_dir)
+
+ with init_empty_weights():
+ transformer = WanTransformer3DModel.from_config(diffusers_config)
+
+ for key in list(original_state_dict.keys()):
+ new_key = key[:]
+ for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+ new_key = new_key.replace(replace_key, rename_key)
+ update_state_dict_(original_state_dict, key, new_key)
+
+ for key in list(original_state_dict.keys()):
+ for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, original_state_dict)
+
+ transformer.load_state_dict(original_state_dict, strict=True, assign=True)
+ return transformer
+
+
+def convert_vae():
+ vae_ckpt_path = hf_hub_download("Wan-AI/Wan2.1-T2V-14B", "Wan2.1_VAE.pth")
+ old_state_dict = torch.load(vae_ckpt_path, weights_only=True)
+ new_state_dict = {}
+
+ # Create mappings for specific components
+ middle_key_mapping = {
+ # Encoder middle block
+ "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma",
+ "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias",
+ "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight",
+ "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma",
+ "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias",
+ "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight",
+ "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma",
+ "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias",
+ "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight",
+ "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma",
+ "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias",
+ "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight",
+ # Decoder middle block
+ "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma",
+ "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias",
+ "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight",
+ "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma",
+ "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias",
+ "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight",
+ "decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma",
+ "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias",
+ "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight",
+ "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma",
+ "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias",
+ "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight",
+ }
+
+ # Create a mapping for attention blocks
+ attention_mapping = {
+ # Encoder middle attention
+ "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma",
+ "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight",
+ "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias",
+ "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight",
+ "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias",
+ # Decoder middle attention
+ "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma",
+ "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight",
+ "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias",
+ "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight",
+ "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias",
+ }
+
+ # Create a mapping for the head components
+ head_mapping = {
+ # Encoder head
+ "encoder.head.0.gamma": "encoder.norm_out.gamma",
+ "encoder.head.2.bias": "encoder.conv_out.bias",
+ "encoder.head.2.weight": "encoder.conv_out.weight",
+ # Decoder head
+ "decoder.head.0.gamma": "decoder.norm_out.gamma",
+ "decoder.head.2.bias": "decoder.conv_out.bias",
+ "decoder.head.2.weight": "decoder.conv_out.weight",
+ }
+
+ # Create a mapping for the quant components
+ quant_mapping = {
+ "conv1.weight": "quant_conv.weight",
+ "conv1.bias": "quant_conv.bias",
+ "conv2.weight": "post_quant_conv.weight",
+ "conv2.bias": "post_quant_conv.bias",
+ }
+
+ # Process each key in the state dict
+ for key, value in old_state_dict.items():
+ # Handle middle block keys using the mapping
+ if key in middle_key_mapping:
+ new_key = middle_key_mapping[key]
+ new_state_dict[new_key] = value
+ # Handle attention blocks using the mapping
+ elif key in attention_mapping:
+ new_key = attention_mapping[key]
+ new_state_dict[new_key] = value
+ # Handle head keys using the mapping
+ elif key in head_mapping:
+ new_key = head_mapping[key]
+ new_state_dict[new_key] = value
+ # Handle quant keys using the mapping
+ elif key in quant_mapping:
+ new_key = quant_mapping[key]
+ new_state_dict[new_key] = value
+ # Handle encoder conv1
+ elif key == "encoder.conv1.weight":
+ new_state_dict["encoder.conv_in.weight"] = value
+ elif key == "encoder.conv1.bias":
+ new_state_dict["encoder.conv_in.bias"] = value
+ # Handle decoder conv1
+ elif key == "decoder.conv1.weight":
+ new_state_dict["decoder.conv_in.weight"] = value
+ elif key == "decoder.conv1.bias":
+ new_state_dict["decoder.conv_in.bias"] = value
+ # Handle encoder downsamples
+ elif key.startswith("encoder.downsamples."):
+ # Convert to down_blocks
+ new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")
+
+ # Convert residual block naming but keep the original structure
+ if ".residual.0.gamma" in new_key:
+ new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
+ elif ".residual.2.bias" in new_key:
+ new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
+ elif ".residual.2.weight" in new_key:
+ new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
+ elif ".residual.3.gamma" in new_key:
+ new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
+ elif ".residual.6.bias" in new_key:
+ new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
+ elif ".residual.6.weight" in new_key:
+ new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
+ elif ".shortcut.bias" in new_key:
+ new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
+ elif ".shortcut.weight" in new_key:
+ new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")
+
+ new_state_dict[new_key] = value
+
+ # Handle decoder upsamples
+ elif key.startswith("decoder.upsamples."):
+ # Convert to up_blocks
+ parts = key.split(".")
+ block_idx = int(parts[2])
+
+ # Group residual blocks
+ if "residual" in key:
+ if block_idx in [0, 1, 2]:
+ new_block_idx = 0
+ resnet_idx = block_idx
+ elif block_idx in [4, 5, 6]:
+ new_block_idx = 1
+ resnet_idx = block_idx - 4
+ elif block_idx in [8, 9, 10]:
+ new_block_idx = 2
+ resnet_idx = block_idx - 8
+ elif block_idx in [12, 13, 14]:
+ new_block_idx = 3
+ resnet_idx = block_idx - 12
+ else:
+ # Keep as is for other blocks
+ new_state_dict[key] = value
+ continue
+
+ # Convert residual block naming
+ if ".residual.0.gamma" in key:
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma"
+ elif ".residual.2.bias" in key:
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias"
+ elif ".residual.2.weight" in key:
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight"
+ elif ".residual.3.gamma" in key:
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma"
+ elif ".residual.6.bias" in key:
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias"
+ elif ".residual.6.weight" in key:
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight"
+ else:
+ new_key = key
+
+ new_state_dict[new_key] = value
+
+ # Handle shortcut connections
+ elif ".shortcut." in key:
+ if block_idx == 4:
+ new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.")
+ new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1")
+ else:
+ new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
+ new_key = new_key.replace(".shortcut.", ".conv_shortcut.")
+
+ new_state_dict[new_key] = value
+
+ # Handle upsamplers
+ elif ".resample." in key or ".time_conv." in key:
+ if block_idx == 3:
+ new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0")
+ elif block_idx == 7:
+ new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0")
+ elif block_idx == 11:
+ new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0")
+ else:
+ new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
+
+ new_state_dict[new_key] = value
+ else:
+ new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
+ new_state_dict[new_key] = value
+ else:
+ # Keep other keys unchanged
+ new_state_dict[key] = value
+
+ with init_empty_weights():
+ vae = AutoencoderKLWan()
+ vae.load_state_dict(new_state_dict, strict=True, assign=True)
+ return vae
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model_type", type=str, default=None)
+ parser.add_argument("--output_path", type=str, required=True)
+ parser.add_argument("--dtype", default="fp32")
+ return parser.parse_args()
+
+
+DTYPE_MAPPING = {
+ "fp32": torch.float32,
+ "fp16": torch.float16,
+ "bf16": torch.bfloat16,
+}
+
+
+if __name__ == "__main__":
+ args = get_args()
+
+ transformer = None
+ dtype = DTYPE_MAPPING[args.dtype]
+
+ transformer = convert_transformer(args.model_type).to(dtype=dtype)
+ vae = convert_vae()
+ text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl")
+ tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
+ scheduler = UniPCMultistepScheduler(
+ prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=3.0
+ )
+
+ if "I2V" in args.model_type:
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+ "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
+ )
+ image_processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
+ pipe = WanImageToVideoPipeline(
+ transformer=transformer,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ vae=vae,
+ scheduler=scheduler,
+ image_encoder=image_encoder,
+ image_processor=image_processor,
+ )
+ else:
+ pipe = WanPipeline(
+ transformer=transformer,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ vae=vae,
+ scheduler=scheduler,
+ )
+
+ pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 71dd49886f6f..6262ab802de0 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -96,6 +96,7 @@
"AutoencoderKLLTXVideo",
"AutoencoderKLMochi",
"AutoencoderKLTemporalDecoder",
+ "AutoencoderKLWan",
"AutoencoderOobleck",
"AutoencoderTiny",
"CacheMixin",
@@ -148,6 +149,7 @@
"UNetSpatioTemporalConditionModel",
"UVit2DModel",
"VQModel",
+ "WanTransformer3DModel",
]
)
_import_structure["optimization"] = [
@@ -438,6 +440,8 @@
"VersatileDiffusionTextToImagePipeline",
"VideoToVideoSDPipeline",
"VQDiffusionPipeline",
+ "WanImageToVideoPipeline",
+ "WanPipeline",
"WuerstchenCombinedPipeline",
"WuerstchenDecoderPipeline",
"WuerstchenPriorPipeline",
@@ -618,6 +622,7 @@
AutoencoderKLLTXVideo,
AutoencoderKLMochi,
AutoencoderKLTemporalDecoder,
+ AutoencoderKLWan,
AutoencoderOobleck,
AutoencoderTiny,
CacheMixin,
@@ -669,6 +674,7 @@
UNetSpatioTemporalConditionModel,
UVit2DModel,
VQModel,
+ WanTransformer3DModel,
)
from .optimization import (
get_constant_schedule,
@@ -938,6 +944,8 @@
VersatileDiffusionTextToImagePipeline,
VideoToVideoSDPipeline,
VQDiffusionPipeline,
+ WanImageToVideoPipeline,
+ WanPipeline,
WuerstchenCombinedPipeline,
WuerstchenDecoderPipeline,
WuerstchenPriorPipeline,
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 853f149fe01c..60b9f1e230f2 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -35,6 +35,7 @@
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
+ _import_structure["autoencoders.autoencoder_kl_wan"] = ["AutoencoderKLWan"]
_import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
@@ -79,6 +80,7 @@
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
+ _import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"]
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
_import_structure["unets.unet_2d"] = ["UNet2DModel"]
_import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
@@ -109,6 +111,7 @@
AutoencoderKLLTXVideo,
AutoencoderKLMochi,
AutoencoderKLTemporalDecoder,
+ AutoencoderKLWan,
AutoencoderOobleck,
AutoencoderTiny,
ConsistencyDecoderVAE,
@@ -158,6 +161,7 @@
T5FilmDecoder,
Transformer2DModel,
TransformerTemporalModel,
+ WanTransformer3DModel,
)
from .unets import (
I2VGenXLUNet,
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index fe126c46dfef..b19851aa3e7c 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -280,6 +280,10 @@ def __init__(
elif qk_norm == "rms_norm":
self.norm_added_q = RMSNorm(dim_head, eps=eps)
self.norm_added_k = RMSNorm(dim_head, eps=eps)
+ elif qk_norm == "rms_norm_across_heads":
+ # Wanx applies qk norm across all heads
+ self.norm_added_q = RMSNorm(dim_head * heads, eps=eps)
+ self.norm_added_k = RMSNorm(dim_head * kv_heads, eps=eps)
else:
raise ValueError(
f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`"
diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py
index bb750a4410f2..f1cbbdf8a10d 100644
--- a/src/diffusers/models/autoencoders/__init__.py
+++ b/src/diffusers/models/autoencoders/__init__.py
@@ -7,6 +7,7 @@
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
from .autoencoder_kl_mochi import AutoencoderKLMochi
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
+from .autoencoder_kl_wan import AutoencoderKLWan
from .autoencoder_oobleck import AutoencoderOobleck
from .autoencoder_tiny import AutoencoderTiny
from .consistency_decoder_vae import ConsistencyDecoderVAE
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
new file mode 100644
index 000000000000..513afa3dfaee
--- /dev/null
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
@@ -0,0 +1,865 @@
+# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import logging
+from ...utils.accelerate_utils import apply_forward_hook
+from ..activations import get_activation
+from ..modeling_outputs import AutoencoderKLOutput
+from ..modeling_utils import ModelMixin
+from .vae import DecoderOutput, DiagonalGaussianDistribution
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+CACHE_T = 2
+
+
+class WanCausalConv3d(nn.Conv3d):
+ r"""
+ A custom 3D causal convolution layer with feature caching support.
+
+ This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature
+ caching for efficient inference.
+
+ Args:
+ in_channels (int): Number of channels in the input image
+ out_channels (int): Number of channels produced by the convolution
+ kernel_size (int or tuple): Size of the convolving kernel
+ stride (int or tuple, optional): Stride of the convolution. Default: 1
+ padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int, int]],
+ stride: Union[int, Tuple[int, int, int]] = 1,
+ padding: Union[int, Tuple[int, int, int]] = 0,
+ ) -> None:
+ super().__init__(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ )
+
+ # Set up causal padding
+ self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
+ self.padding = (0, 0, 0)
+
+ def forward(self, x, cache_x=None):
+ padding = list(self._padding)
+ if cache_x is not None and self._padding[4] > 0:
+ cache_x = cache_x.to(x.device)
+ x = torch.cat([cache_x, x], dim=2)
+ padding[4] -= cache_x.shape[2]
+ x = F.pad(x, padding)
+ return super().forward(x)
+
+
+class WanRMS_norm(nn.Module):
+ r"""
+ A custom RMS normalization layer.
+
+ Args:
+ dim (int): The number of dimensions to normalize over.
+ channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
+ Default is True.
+ images (bool, optional): Whether the input represents image data. Default is True.
+ bias (bool, optional): Whether to include a learnable bias term. Default is False.
+ """
+
+ def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
+ super().__init__()
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
+
+ self.channel_first = channel_first
+ self.scale = dim**0.5
+ self.gamma = nn.Parameter(torch.ones(shape))
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
+
+ def forward(self, x):
+ return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
+
+
+class WanUpsample(nn.Upsample):
+ r"""
+ Perform upsampling while ensuring the output tensor has the same data type as the input.
+
+ Args:
+ x (torch.Tensor): Input tensor to be upsampled.
+
+ Returns:
+ torch.Tensor: Upsampled tensor with the same data type as the input.
+ """
+
+ def forward(self, x):
+ return super().forward(x.float()).type_as(x)
+
+
+class WanResample(nn.Module):
+ r"""
+ A custom resampling module for 2D and 3D data.
+
+ Args:
+ dim (int): The number of input/output channels.
+ mode (str): The resampling mode. Must be one of:
+ - 'none': No resampling (identity operation).
+ - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution.
+ - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution.
+ - 'downsample2d': 2D downsampling with zero-padding and convolution.
+ - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
+ """
+
+ def __init__(self, dim: int, mode: str) -> None:
+ super().__init__()
+ self.dim = dim
+ self.mode = mode
+
+ # layers
+ if mode == "upsample2d":
+ self.resample = nn.Sequential(
+ WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
+ )
+ elif mode == "upsample3d":
+ self.resample = nn.Sequential(
+ WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
+ )
+ self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
+
+ elif mode == "downsample2d":
+ self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
+ elif mode == "downsample3d":
+ self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
+ self.time_conv = WanCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
+
+ else:
+ self.resample = nn.Identity()
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ b, c, t, h, w = x.size()
+ if self.mode == "upsample3d":
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ if feat_cache[idx] is None:
+ feat_cache[idx] = "Rep"
+ feat_idx[0] += 1
+ else:
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
+ # cache last frame of last two chunk
+ cache_x = torch.cat(
+ [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
+ )
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
+ cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
+ if feat_cache[idx] == "Rep":
+ x = self.time_conv(x)
+ else:
+ x = self.time_conv(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+
+ x = x.reshape(b, 2, c, t, h, w)
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
+ x = x.reshape(b, c, t * 2, h, w)
+ t = x.shape[2]
+ x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
+ x = self.resample(x)
+ x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)
+
+ if self.mode == "downsample3d":
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ if feat_cache[idx] is None:
+ feat_cache[idx] = x.clone()
+ feat_idx[0] += 1
+ else:
+ cache_x = x[:, :, -1:, :, :].clone()
+ x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ return x
+
+
+class WanResidualBlock(nn.Module):
+ r"""
+ A custom residual block module.
+
+ Args:
+ in_dim (int): Number of input channels.
+ out_dim (int): Number of output channels.
+ dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0.
+ non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
+ """
+
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: int,
+ dropout: float = 0.0,
+ non_linearity: str = "silu",
+ ) -> None:
+ super().__init__()
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+ self.nonlinearity = get_activation(non_linearity)
+
+ # layers
+ self.norm1 = WanRMS_norm(in_dim, images=False)
+ self.conv1 = WanCausalConv3d(in_dim, out_dim, 3, padding=1)
+ self.norm2 = WanRMS_norm(out_dim, images=False)
+ self.dropout = nn.Dropout(dropout)
+ self.conv2 = WanCausalConv3d(out_dim, out_dim, 3, padding=1)
+ self.conv_shortcut = WanCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ # Apply shortcut connection
+ h = self.conv_shortcut(x)
+
+ # First normalization and activation
+ x = self.norm1(x)
+ x = self.nonlinearity(x)
+
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+
+ x = self.conv1(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv1(x)
+
+ # Second normalization and activation
+ x = self.norm2(x)
+ x = self.nonlinearity(x)
+
+ # Dropout
+ x = self.dropout(x)
+
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+
+ x = self.conv2(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv2(x)
+
+ # Add residual connection
+ return x + h
+
+
+class WanAttentionBlock(nn.Module):
+ r"""
+ Causal self-attention with a single head.
+
+ Args:
+ dim (int): The number of channels in the input tensor.
+ """
+
+ def __init__(self, dim):
+ super().__init__()
+ self.dim = dim
+
+ # layers
+ self.norm = WanRMS_norm(dim)
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
+ self.proj = nn.Conv2d(dim, dim, 1)
+
+ def forward(self, x):
+ identity = x
+ batch_size, channels, time, height, width = x.size()
+
+ x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width)
+ x = self.norm(x)
+
+ # compute query, key, value
+ qkv = self.to_qkv(x)
+ qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
+ qkv = qkv.permute(0, 1, 3, 2).contiguous()
+ q, k, v = qkv.chunk(3, dim=-1)
+
+ # apply attention
+ x = F.scaled_dot_product_attention(q, k, v)
+
+ x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width)
+
+ # output projection
+ x = self.proj(x)
+
+ # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w]
+ x = x.view(batch_size, time, channels, height, width)
+ x = x.permute(0, 2, 1, 3, 4)
+
+ return x + identity
+
+
+class WanMidBlock(nn.Module):
+ """
+ Middle block for WanVAE encoder and decoder.
+
+ Args:
+ dim (int): Number of input/output channels.
+ dropout (float): Dropout rate.
+ non_linearity (str): Type of non-linearity to use.
+ """
+
+ def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1):
+ super().__init__()
+ self.dim = dim
+
+ # Create the components
+ resnets = [WanResidualBlock(dim, dim, dropout, non_linearity)]
+ attentions = []
+ for _ in range(num_layers):
+ attentions.append(WanAttentionBlock(dim))
+ resnets.append(WanResidualBlock(dim, dim, dropout, non_linearity))
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ # First residual block
+ x = self.resnets[0](x, feat_cache, feat_idx)
+
+ # Process through attention and residual blocks
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ if attn is not None:
+ x = attn(x)
+
+ x = resnet(x, feat_cache, feat_idx)
+
+ return x
+
+
+class WanEncoder3d(nn.Module):
+ r"""
+ A 3D encoder module.
+
+ Args:
+ dim (int): The base number of channels in the first layer.
+ z_dim (int): The dimensionality of the latent space.
+ dim_mult (list of int): Multipliers for the number of channels in each block.
+ num_res_blocks (int): Number of residual blocks in each block.
+ attn_scales (list of float): Scales at which to apply attention mechanisms.
+ temperal_downsample (list of bool): Whether to downsample temporally in each block.
+ dropout (float): Dropout rate for the dropout layers.
+ non_linearity (str): Type of non-linearity to use.
+ """
+
+ def __init__(
+ self,
+ dim=128,
+ z_dim=4,
+ dim_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ attn_scales=[],
+ temperal_downsample=[True, True, False],
+ dropout=0.0,
+ non_linearity: str = "silu",
+ ):
+ super().__init__()
+ self.dim = dim
+ self.z_dim = z_dim
+ self.dim_mult = dim_mult
+ self.num_res_blocks = num_res_blocks
+ self.attn_scales = attn_scales
+ self.temperal_downsample = temperal_downsample
+ self.nonlinearity = get_activation(non_linearity)
+
+ # dimensions
+ dims = [dim * u for u in [1] + dim_mult]
+ scale = 1.0
+
+ # init block
+ self.conv_in = WanCausalConv3d(3, dims[0], 3, padding=1)
+
+ # downsample blocks
+ self.down_blocks = nn.ModuleList([])
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
+ # residual (+attention) blocks
+ for _ in range(num_res_blocks):
+ self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
+ if scale in attn_scales:
+ self.down_blocks.append(WanAttentionBlock(out_dim))
+ in_dim = out_dim
+
+ # downsample block
+ if i != len(dim_mult) - 1:
+ mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
+ self.down_blocks.append(WanResample(out_dim, mode=mode))
+ scale /= 2.0
+
+ # middle blocks
+ self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1)
+
+ # output blocks
+ self.norm_out = WanRMS_norm(out_dim, images=False)
+ self.conv_out = WanCausalConv3d(out_dim, z_dim, 3, padding=1)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ # cache last frame of last two chunk
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+ x = self.conv_in(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv_in(x)
+
+ ## downsamples
+ for layer in self.down_blocks:
+ if feat_cache is not None:
+ x = layer(x, feat_cache, feat_idx)
+ else:
+ x = layer(x)
+
+ ## middle
+ x = self.mid_block(x, feat_cache, feat_idx)
+
+ ## head
+ x = self.norm_out(x)
+ x = self.nonlinearity(x)
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ # cache last frame of last two chunk
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+ x = self.conv_out(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv_out(x)
+ return x
+
+
+class WanUpBlock(nn.Module):
+ """
+ A block that handles upsampling for the WanVAE decoder.
+
+ Args:
+ in_dim (int): Input dimension
+ out_dim (int): Output dimension
+ num_res_blocks (int): Number of residual blocks
+ dropout (float): Dropout rate
+ upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d')
+ non_linearity (str): Type of non-linearity to use
+ """
+
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: int,
+ num_res_blocks: int,
+ dropout: float = 0.0,
+ upsample_mode: Optional[str] = None,
+ non_linearity: str = "silu",
+ ):
+ super().__init__()
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+
+ # Create layers list
+ resnets = []
+ # Add residual blocks and attention if needed
+ current_dim = in_dim
+ for _ in range(num_res_blocks + 1):
+ resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity))
+ current_dim = out_dim
+
+ self.resnets = nn.ModuleList(resnets)
+
+ # Add upsampling layer if needed
+ self.upsamplers = None
+ if upsample_mode is not None:
+ self.upsamplers = nn.ModuleList([WanResample(out_dim, mode=upsample_mode)])
+
+ self.gradient_checkpointing = False
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ """
+ Forward pass through the upsampling block.
+
+ Args:
+ x (torch.Tensor): Input tensor
+ feat_cache (list, optional): Feature cache for causal convolutions
+ feat_idx (list, optional): Feature index for cache management
+
+ Returns:
+ torch.Tensor: Output tensor
+ """
+ for resnet in self.resnets:
+ if feat_cache is not None:
+ x = resnet(x, feat_cache, feat_idx)
+ else:
+ x = resnet(x)
+
+ if self.upsamplers is not None:
+ if feat_cache is not None:
+ x = self.upsamplers[0](x, feat_cache, feat_idx)
+ else:
+ x = self.upsamplers[0](x)
+ return x
+
+
+class WanDecoder3d(nn.Module):
+ r"""
+ A 3D decoder module.
+
+ Args:
+ dim (int): The base number of channels in the first layer.
+ z_dim (int): The dimensionality of the latent space.
+ dim_mult (list of int): Multipliers for the number of channels in each block.
+ num_res_blocks (int): Number of residual blocks in each block.
+ attn_scales (list of float): Scales at which to apply attention mechanisms.
+ temperal_upsample (list of bool): Whether to upsample temporally in each block.
+ dropout (float): Dropout rate for the dropout layers.
+ non_linearity (str): Type of non-linearity to use.
+ """
+
+ def __init__(
+ self,
+ dim=128,
+ z_dim=4,
+ dim_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ attn_scales=[],
+ temperal_upsample=[False, True, True],
+ dropout=0.0,
+ non_linearity: str = "silu",
+ ):
+ super().__init__()
+ self.dim = dim
+ self.z_dim = z_dim
+ self.dim_mult = dim_mult
+ self.num_res_blocks = num_res_blocks
+ self.attn_scales = attn_scales
+ self.temperal_upsample = temperal_upsample
+
+ self.nonlinearity = get_activation(non_linearity)
+
+ # dimensions
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
+ scale = 1.0 / 2 ** (len(dim_mult) - 2)
+
+ # init block
+ self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1)
+
+ # middle blocks
+ self.mid_block = WanMidBlock(dims[0], dropout, non_linearity, num_layers=1)
+
+ # upsample blocks
+ self.up_blocks = nn.ModuleList([])
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
+ # residual (+attention) blocks
+ if i > 0:
+ in_dim = in_dim // 2
+
+ # Determine if we need upsampling
+ upsample_mode = None
+ if i != len(dim_mult) - 1:
+ upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
+
+ # Create and add the upsampling block
+ up_block = WanUpBlock(
+ in_dim=in_dim,
+ out_dim=out_dim,
+ num_res_blocks=num_res_blocks,
+ dropout=dropout,
+ upsample_mode=upsample_mode,
+ non_linearity=non_linearity,
+ )
+ self.up_blocks.append(up_block)
+
+ # Update scale for next iteration
+ if upsample_mode is not None:
+ scale *= 2.0
+
+ # output blocks
+ self.norm_out = WanRMS_norm(out_dim, images=False)
+ self.conv_out = WanCausalConv3d(out_dim, 3, 3, padding=1)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ ## conv1
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ # cache last frame of last two chunk
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+ x = self.conv_in(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv_in(x)
+
+ ## middle
+ x = self.mid_block(x, feat_cache, feat_idx)
+
+ ## upsamples
+ for up_block in self.up_blocks:
+ x = up_block(x, feat_cache, feat_idx)
+
+ ## head
+ x = self.norm_out(x)
+ x = self.nonlinearity(x)
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ # cache last frame of last two chunk
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+ x = self.conv_out(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv_out(x)
+ return x
+
+
+class AutoencoderKLWan(ModelMixin, ConfigMixin):
+ r"""
+ A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
+ Introduced in [Wan 2.1].
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+ for all models (such as downloading or saving).
+ """
+
+ _supports_gradient_checkpointing = False
+
+ @register_to_config
+ def __init__(
+ self,
+ base_dim: int = 96,
+ z_dim: int = 16,
+ dim_mult: Tuple[int] = [1, 2, 4, 4],
+ num_res_blocks: int = 2,
+ attn_scales: List[float] = [],
+ temperal_downsample: List[bool] = [False, True, True],
+ dropout: float = 0.0,
+ latents_mean: List[float] = [
+ -0.7571,
+ -0.7089,
+ -0.9113,
+ 0.1075,
+ -0.1745,
+ 0.9653,
+ -0.1517,
+ 1.5508,
+ 0.4134,
+ -0.0715,
+ 0.5517,
+ -0.3632,
+ -0.1922,
+ -0.9497,
+ 0.2503,
+ -0.2921,
+ ],
+ latents_std: List[float] = [
+ 2.8184,
+ 1.4541,
+ 2.3275,
+ 2.6558,
+ 1.2196,
+ 1.7708,
+ 2.6052,
+ 2.0743,
+ 3.2687,
+ 2.1526,
+ 2.8652,
+ 1.5579,
+ 1.6382,
+ 1.1253,
+ 2.8251,
+ 1.9160,
+ ],
+ ) -> None:
+ super().__init__()
+
+ # Store normalization parameters as tensors
+ self.mean = torch.tensor(latents_mean)
+ self.std = torch.tensor(latents_std)
+ self.scale = torch.stack([self.mean, 1.0 / self.std]) # Shape: [2, C]
+
+ self.z_dim = z_dim
+ self.temperal_downsample = temperal_downsample
+ self.temperal_upsample = temperal_downsample[::-1]
+
+ self.encoder = WanEncoder3d(
+ base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
+ )
+ self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1)
+ self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1)
+
+ self.decoder = WanDecoder3d(
+ base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
+ )
+
+ def clear_cache(self):
+ def _count_conv3d(model):
+ count = 0
+ for m in model.modules():
+ if isinstance(m, WanCausalConv3d):
+ count += 1
+ return count
+
+ self._conv_num = _count_conv3d(self.decoder)
+ self._conv_idx = [0]
+ self._feat_map = [None] * self._conv_num
+ # cache encode
+ self._enc_conv_num = _count_conv3d(self.encoder)
+ self._enc_conv_idx = [0]
+ self._enc_feat_map = [None] * self._enc_conv_num
+
+ def _encode(self, x: torch.Tensor) -> torch.Tensor:
+ scale = self.scale.type_as(x)
+ self.clear_cache()
+ ## cache
+ t = x.shape[2]
+ iter_ = 1 + (t - 1) // 4
+ for i in range(iter_):
+ self._enc_conv_idx = [0]
+ if i == 0:
+ out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
+ else:
+ out_ = self.encoder(
+ x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
+ feat_cache=self._enc_feat_map,
+ feat_idx=self._enc_conv_idx,
+ )
+ out = torch.cat([out, out_], 2)
+
+ enc = self.quant_conv(out)
+ mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :]
+ mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
+ logvar = (logvar - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
+ enc = torch.cat([mu, logvar], dim=1)
+ self.clear_cache()
+ return enc
+
+ @apply_forward_hook
+ def encode(
+ self, x: torch.Tensor, return_dict: bool = True
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+ r"""
+ Encode a batch of images into latents.
+
+ Args:
+ x (`torch.Tensor`): Input batch of images.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+ Returns:
+ The latent representations of the encoded videos. If `return_dict` is True, a
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+ """
+ h = self._encode(x)
+ posterior = DiagonalGaussianDistribution(h)
+ if not return_dict:
+ return (posterior,)
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def _decode(self, z: torch.Tensor, scale, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ self.clear_cache()
+ # z: [b,c,t,h,w]
+ z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
+
+ iter_ = z.shape[2]
+ x = self.post_quant_conv(z)
+ for i in range(iter_):
+ self._conv_idx = [0]
+ if i == 0:
+ out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
+ else:
+ out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
+ out = torch.cat([out, out_], 2)
+
+ out = torch.clamp(out, min=-1.0, max=1.0)
+ self.clear_cache()
+ if not return_dict:
+ return (out,)
+
+ return DecoderOutput(sample=out)
+
+ @apply_forward_hook
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ r"""
+ Decode a batch of images.
+
+ Args:
+ z (`torch.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+ scale = self.scale.type_as(z)
+ decoded = self._decode(z, scale).sample
+ if not return_dict:
+ return (decoded,)
+
+ return DecoderOutput(sample=decoded)
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ sample_posterior: bool = False,
+ return_dict: bool = True,
+ generator: Optional[torch.Generator] = None,
+ ) -> Union[DecoderOutput, torch.Tensor]:
+ """
+ Args:
+ sample (`torch.Tensor`): Input sample.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+ """
+ x = sample
+ posterior = self.encode(x).latent_dist
+ if sample_posterior:
+ z = posterior.sample(generator=generator)
+ else:
+ z = posterior.mode()
+ dec = self.decode(z, return_dict=return_dict)
+ return dec
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 4fbbd78667e3..6983940f139b 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -166,8 +166,12 @@ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
# 2. If no dtype modifying hooks are attached, return the dtype of the first floating point parameter/buffer
last_dtype = None
- for param in parameter.parameters():
+
+ for name, param in parameter.named_parameters():
last_dtype = param.dtype
+ if parameter._keep_in_fp32_modules and any(m in name for m in parameter._keep_in_fp32_modules):
+ continue
+
if param.is_floating_point():
return param.dtype
diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py
index f32c30ceff3c..ee317051dff9 100644
--- a/src/diffusers/models/transformers/__init__.py
+++ b/src/diffusers/models/transformers/__init__.py
@@ -27,3 +27,4 @@
from .transformer_omnigen import OmniGenTransformer2DModel
from .transformer_sd3 import SD3Transformer2DModel
from .transformer_temporal import TransformerTemporalModel
+ from .transformer_wan import WanTransformer3DModel
diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py
new file mode 100644
index 000000000000..33e9daf70fe4
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_wan.py
@@ -0,0 +1,438 @@
+# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import logging
+from ..attention import FeedForward
+from ..attention_processor import Attention
+from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import FP32LayerNorm
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class WanAttnProcessor2_0:
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ encoder_hidden_states_img = None
+ if attn.add_k_proj is not None:
+ encoder_hidden_states_img = encoder_hidden_states[:, :257]
+ encoder_hidden_states = encoder_hidden_states[:, 257:]
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+ if rotary_emb is not None:
+
+ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
+ x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
+ x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
+ return x_out.type_as(hidden_states)
+
+ query = apply_rotary_emb(query, rotary_emb)
+ key = apply_rotary_emb(key, rotary_emb)
+
+ # I2V task
+ hidden_states_img = None
+ if encoder_hidden_states_img is not None:
+ key_img = attn.add_k_proj(encoder_hidden_states_img)
+ key_img = attn.norm_added_k(key_img)
+ value_img = attn.add_v_proj(encoder_hidden_states_img)
+
+ key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+ hidden_states_img = F.scaled_dot_product_attention(
+ query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
+ )
+ hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
+ hidden_states_img = hidden_states_img.type_as(query)
+
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+ hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+ hidden_states = hidden_states.type_as(query)
+
+ if hidden_states_img is not None:
+ hidden_states = hidden_states + hidden_states_img
+
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ return hidden_states
+
+
+class WanImageEmbedding(torch.nn.Module):
+ def __init__(self, in_features: int, out_features: int):
+ super().__init__()
+
+ self.norm1 = nn.LayerNorm(in_features)
+ self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
+ self.norm2 = nn.LayerNorm(out_features)
+
+ def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.norm1(encoder_hidden_states_image)
+ hidden_states = self.ff(hidden_states)
+ hidden_states = self.norm2(hidden_states)
+ return hidden_states
+
+
+class WanTimeTextImageEmbedding(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ time_freq_dim: int,
+ time_proj_dim: int,
+ text_embed_dim: int,
+ image_embed_dim: Optional[int] = None,
+ ):
+ super().__init__()
+
+ self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
+ self.act_fn = nn.SiLU()
+ self.time_proj = nn.Linear(dim, time_proj_dim)
+ self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
+
+ self.image_embedder = None
+ if image_embed_dim is not None:
+ self.image_embedder = WanImageEmbedding(image_embed_dim, dim)
+
+ def forward(
+ self,
+ timestep: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_hidden_states_image: Optional[torch.Tensor] = None,
+ ):
+ timestep = self.timesteps_proj(timestep)
+
+ time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
+ if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
+ timestep = timestep.to(time_embedder_dtype)
+ temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
+ timestep_proj = self.time_proj(self.act_fn(temb))
+
+ encoder_hidden_states = self.text_embedder(encoder_hidden_states)
+ if encoder_hidden_states_image is not None:
+ encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)
+
+ return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
+
+
+class WanRotaryPosEmbed(nn.Module):
+ def __init__(
+ self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0
+ ):
+ super().__init__()
+
+ self.attention_head_dim = attention_head_dim
+ self.patch_size = patch_size
+ self.max_seq_len = max_seq_len
+
+ h_dim = w_dim = 2 * (attention_head_dim // 6)
+ t_dim = attention_head_dim - h_dim - w_dim
+
+ freqs = []
+ for dim in [t_dim, h_dim, w_dim]:
+ freq = get_1d_rotary_pos_embed(
+ dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64
+ )
+ freqs.append(freq)
+ self.freqs = torch.cat(freqs, dim=1)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ p_t, p_h, p_w = self.patch_size
+ ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
+
+ self.freqs = self.freqs.to(hidden_states.device)
+ freqs = self.freqs.split_with_sizes(
+ [
+ self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
+ self.attention_head_dim // 6,
+ self.attention_head_dim // 6,
+ ],
+ dim=1,
+ )
+
+ freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+ freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
+ return freqs
+
+
+class WanTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ ffn_dim: int,
+ num_heads: int,
+ qk_norm: str = "rms_norm_across_heads",
+ cross_attn_norm: bool = False,
+ eps: float = 1e-6,
+ added_kv_proj_dim: Optional[int] = None,
+ ):
+ super().__init__()
+
+ # 1. Self-attention
+ self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_heads,
+ kv_heads=num_heads,
+ dim_head=dim // num_heads,
+ qk_norm=qk_norm,
+ eps=eps,
+ bias=True,
+ cross_attention_dim=None,
+ out_bias=True,
+ processor=WanAttnProcessor2_0(),
+ )
+
+ # 2. Cross-attention
+ self.attn2 = Attention(
+ query_dim=dim,
+ heads=num_heads,
+ kv_heads=num_heads,
+ dim_head=dim // num_heads,
+ qk_norm=qk_norm,
+ eps=eps,
+ bias=True,
+ cross_attention_dim=None,
+ out_bias=True,
+ added_kv_proj_dim=added_kv_proj_dim,
+ added_proj_bias=True,
+ processor=WanAttnProcessor2_0(),
+ )
+ self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
+
+ # 3. Feed-forward
+ self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
+ self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+
+ self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ rotary_emb: torch.Tensor,
+ ) -> torch.Tensor:
+ shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+ self.scale_shift_table + temb.float()
+ ).chunk(6, dim=1)
+
+ # 1. Self-attention
+ norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
+ attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
+ hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
+
+ # 2. Cross-attention
+ norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
+ attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+ hidden_states = hidden_states + attn_output
+
+ # 3. Feed-forward
+ norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
+ hidden_states
+ )
+ ff_output = self.ffn(norm_hidden_states)
+ hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
+
+ return hidden_states
+
+
+class WanTransformer3DModel(ModelMixin, ConfigMixin):
+ r"""
+ A Transformer model for video-like data used in the Wan model.
+
+ Args:
+ patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
+ num_attention_heads (`int`, defaults to `40`):
+ Fixed length for text embeddings.
+ attention_head_dim (`int`, defaults to `128`):
+ The number of channels in each head.
+ in_channels (`int`, defaults to `16`):
+ The number of channels in the input.
+ out_channels (`int`, defaults to `16`):
+ The number of channels in the output.
+ text_dim (`int`, defaults to `512`):
+ Input dimension for text embeddings.
+ freq_dim (`int`, defaults to `256`):
+ Dimension for sinusoidal time embeddings.
+ ffn_dim (`int`, defaults to `13824`):
+ Intermediate dimension in feed-forward network.
+ num_layers (`int`, defaults to `40`):
+ The number of layers of transformer blocks to use.
+ window_size (`Tuple[int]`, defaults to `(-1, -1)`):
+ Window size for local attention (-1 indicates global attention).
+ cross_attn_norm (`bool`, defaults to `True`):
+ Enable cross-attention normalization.
+ qk_norm (`bool`, defaults to `True`):
+ Enable query/key normalization.
+ eps (`float`, defaults to `1e-6`):
+ Epsilon value for normalization layers.
+ add_img_emb (`bool`, defaults to `False`):
+ Whether to use img_emb.
+ added_kv_proj_dim (`int`, *optional*, defaults to `None`):
+ The number of channels to use for the added key and value projections. If `None`, no projection is used.
+ """
+
+ _supports_gradient_checkpointing = True
+ _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
+ _no_split_modules = ["WanTransformerBlock"]
+ _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: Tuple[int] = (1, 2, 2),
+ num_attention_heads: int = 40,
+ attention_head_dim: int = 128,
+ in_channels: int = 16,
+ out_channels: int = 16,
+ text_dim: int = 4096,
+ freq_dim: int = 256,
+ ffn_dim: int = 13824,
+ num_layers: int = 40,
+ cross_attn_norm: bool = True,
+ qk_norm: Optional[str] = "rms_norm_across_heads",
+ eps: float = 1e-6,
+ image_dim: Optional[int] = None,
+ added_kv_proj_dim: Optional[int] = None,
+ rope_max_seq_len: int = 1024,
+ ) -> None:
+ super().__init__()
+
+ inner_dim = num_attention_heads * attention_head_dim
+ out_channels = out_channels or in_channels
+
+ # 1. Patch & position embedding
+ self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
+ self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
+
+ # 2. Condition embeddings
+ # image_embedding_dim=1280 for I2V model
+ self.condition_embedder = WanTimeTextImageEmbedding(
+ dim=inner_dim,
+ time_freq_dim=freq_dim,
+ time_proj_dim=inner_dim * 6,
+ text_embed_dim=text_dim,
+ image_embed_dim=image_dim,
+ )
+
+ # 3. Transformer blocks
+ self.blocks = nn.ModuleList(
+ [
+ WanTransformerBlock(
+ inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ # 4. Output norm & projection
+ self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
+ self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
+ self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: torch.LongTensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_hidden_states_image: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ p_t, p_h, p_w = self.config.patch_size
+ post_patch_num_frames = num_frames // p_t
+ post_patch_height = height // p_h
+ post_patch_width = width // p_w
+
+ rotary_emb = self.rope(hidden_states)
+
+ hidden_states = self.patch_embedding(hidden_states)
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
+
+ temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
+ timestep, encoder_hidden_states, encoder_hidden_states_image
+ )
+ timestep_proj = timestep_proj.unflatten(1, (6, -1))
+
+ if encoder_hidden_states_image is not None:
+ encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
+
+ # 4. Transformer blocks
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for block in self.blocks:
+ hidden_states = self._gradient_checkpointing_func(
+ block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
+ )
+ else:
+ for block in self.blocks:
+ hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
+
+ # 5. Output norm, projection & unpatchify
+ shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+ hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
+ hidden_states = self.proj_out(hidden_states)
+
+ hidden_states = hidden_states.reshape(
+ batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
+ )
+ hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
+ output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 8e7f9d68a5d4..a15e1db64e4f 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -347,6 +347,7 @@
"WuerstchenDecoderPipeline",
"WuerstchenPriorPipeline",
]
+ _import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline"]
try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()
@@ -690,6 +691,7 @@
UniDiffuserPipeline,
UniDiffuserTextDecoder,
)
+ from .wan import WanImageToVideoPipeline, WanPipeline
from .wuerstchen import (
WuerstchenCombinedPipeline,
WuerstchenDecoderPipeline,
diff --git a/src/diffusers/pipelines/wan/__init__.py b/src/diffusers/pipelines/wan/__init__.py
new file mode 100644
index 000000000000..84ec62b577e1
--- /dev/null
+++ b/src/diffusers/pipelines/wan/__init__.py
@@ -0,0 +1,50 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_wan"] = ["WanPipeline"]
+ _import_structure["pipeline_wan_i2v"] = ["WanImageToVideoPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_wan import WanPipeline
+ from .pipeline_wan_i2v import WanImageToVideoPipeline
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/wan/pipeline_output.py b/src/diffusers/pipelines/wan/pipeline_output.py
new file mode 100644
index 000000000000..88907ad0f0a1
--- /dev/null
+++ b/src/diffusers/pipelines/wan/pipeline_output.py
@@ -0,0 +1,20 @@
+from dataclasses import dataclass
+
+import torch
+
+from diffusers.utils import BaseOutput
+
+
+@dataclass
+class WanPipelineOutput(BaseOutput):
+ r"""
+ Output class for Wan pipelines.
+
+ Args:
+ frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
+ `(batch_size, num_frames, channels, height, width)`.
+ """
+
+ frames: torch.Tensor
diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py
new file mode 100644
index 000000000000..062a2c21fd09
--- /dev/null
+++ b/src/diffusers/pipelines/wan/pipeline_wan.py
@@ -0,0 +1,562 @@
+# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+from typing import Callable, Dict, List, Optional, Union
+
+import ftfy
+import regex as re
+import torch
+from transformers import AutoTokenizer, UMT5EncoderModel
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...models import AutoencoderKLWan, WanTransformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import WanPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> from diffusers import AutoencoderKLWan, WanPipeline
+ >>> from diffusers.utils import export_to_video
+
+ >>> # Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers
+ >>> model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
+ >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+ >>> pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+
+ >>> prompt = "A cat walks on the grass, realistic"
+ >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
+
+ >>> output = pipe(
+ ... prompt=prompt,
+ ... negative_prompt=negative_prompt,
+ ... height=480,
+ ... width=832,
+ ... num_frames=81,
+ ... guidance_scale=5.0,
+ ... ).frames[0]
+ >>> export_to_video(output, "output.mp4", fps=15)
+ ```
+"""
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+class WanPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-video generation using Wan.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ tokenizer ([`T5Tokenizer`]):
+ Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
+ specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ text_encoder ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ transformer ([`WanTransformer3DModel`]):
+ Conditional Transformer to denoise the input latents.
+ scheduler ([`UniPCMultistepScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLWan`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: UMT5EncoderModel,
+ transformer: WanTransformer3DModel,
+ vae: AutoencoderKLWan,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+
+ self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt = [prompt_clean(u) for u in prompt]
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
+ seq_lens = mask.gt(0).sum(dim=1).long()
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif negative_prompt is not None and (
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ def prepare_latents(
+ self,
+ batch_size: int,
+ num_channels_latents: 16,
+ height: int = 720,
+ width: int = 1280,
+ num_latent_frames: int = 21,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_latent_frames,
+ int(height) // self.vae_scale_factor_spatial,
+ int(width) // self.vae_scale_factor_spatial,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ height: int = 720,
+ width: int = 1280,
+ num_frames: int = 81,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 5.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ height (`int`, defaults to `720`):
+ The height in pixels of the generated image.
+ width (`int`, defaults to `1280`):
+ The width in pixels of the generated image.
+ num_frames (`int`, defaults to `129`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, defaults to `5.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
+ The dtype to use for the torch.amp.autocast.
+
+ Examples:
+
+ Returns:
+ [`~WanPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where
+ the first element is a list with the generated images and the second element is a list of `bool`s
+ indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ transformer_dtype = self.transformer.dtype
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_latent_frames,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ latent_model_input = latents.to(transformer_dtype)
+ timestep = t.expand(latents.shape[0])
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ return_dict=False,
+ )[0]
+
+ if self.do_classifier_free_guidance:
+ noise_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype)
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return WanPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py
new file mode 100644
index 000000000000..eff63efe5197
--- /dev/null
+++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py
@@ -0,0 +1,642 @@
+# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import ftfy
+import numpy as np
+import PIL
+import regex as re
+import torch
+from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PipelineImageInput
+from ...models import AutoencoderKLWan, WanTransformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import WanPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
+ >>> from diffusers.utils import export_to_video, load_image
+
+ >>> # Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-1.3B-720P-Diffusers
+ >>> model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
+ >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+ >>> pipe = WanImageToVideoPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+
+ >>> height, width = 480, 832
+ >>> image = load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
+ ... ).resize((width, height))
+ >>> prompt = (
+ ... "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
+ ... "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
+ ... )
+ >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
+
+ >>> output = pipe(
+ ... image=image, prompt=prompt, negative_prompt=negative_prompt, num_frames=81, guidance_scale=5.0
+ ... ).frames[0]
+ >>> export_to_video(output, "output.mp4", fps=15)
+ ```
+"""
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class WanImageToVideoPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for image-to-video generation using Wan.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ tokenizer ([`T5Tokenizer`]):
+ Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
+ specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ text_encoder ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ image_encoder ([`CLIPVisionModel`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel), specifically
+ the
+ [clip-vit-huge-patch14](https://github.com/mlfoundations/open_clip/blob/main/docs/PRETRAINED.md#vit-h14-xlm-roberta-large)
+ variant.
+ transformer ([`WanTransformer3DModel`]):
+ Conditional Transformer to denoise the input latents.
+ scheduler ([`UniPCMultistepScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLWan`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ """
+
+ model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: UMT5EncoderModel,
+ image_encoder: CLIPVisionModel,
+ image_processor: CLIPImageProcessor,
+ transformer: WanTransformer3DModel,
+ vae: AutoencoderKLWan,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ image_encoder=image_encoder,
+ transformer=transformer,
+ scheduler=scheduler,
+ image_processor=image_processor,
+ )
+
+ self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+ self.image_processor = image_processor
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt = [prompt_clean(u) for u in prompt]
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
+ seq_lens = mask.gt(0).sum(dim=1).long()
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ def encode_image(self, image: PipelineImageInput):
+ image = self.image_processor(images=image, return_tensors="pt").to(self.device)
+ image_embeds = self.image_encoder(**image, output_hidden_states=True)
+ return image_embeds.hidden_states[-1]
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def check_inputs(
+ self,
+ prompt,
+ image,
+ max_area,
+ prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
+ raise ValueError("`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is" f" {type(image)}")
+ if max_area < 0:
+ raise ValueError(f"`max_area` has to be positive but are {max_area}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ def prepare_latents(
+ self,
+ image: PipelineImageInput,
+ batch_size: int,
+ num_channels_latents: 32,
+ height: int = 720,
+ width: int = 1280,
+ max_area: int = 720 * 1280,
+ num_frames: int = 81,
+ num_latent_frames: int = 21,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ aspect_ratio = height / width
+ mod_value = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
+ height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
+ width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
+
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_latent_frames,
+ int(height) // self.vae_scale_factor_spatial,
+ int(width) // self.vae_scale_factor_spatial,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ image = self.video_processor.preprocess(image, height=height, width=width)[:, :, None]
+ video_condition = torch.cat(
+ [image, torch.zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
+ )
+ video_condition = video_condition.to(device=device, dtype=dtype)
+ if isinstance(generator, list):
+ latent_condition = [retrieve_latents(self.vae.encode(video_condition), g) for g in generator]
+ latents = latent_condition = torch.cat(latent_condition)
+ else:
+ latent_condition = retrieve_latents(self.vae.encode(video_condition), generator)
+ latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
+ mask_lat_size = torch.ones(
+ batch_size,
+ 1,
+ num_frames,
+ int(height) // self.vae_scale_factor_spatial,
+ int(width) // self.vae_scale_factor_spatial,
+ )
+ mask_lat_size[:, :, list(range(1, num_frames))] = 0
+ first_frame_mask = mask_lat_size[:, :, 0:1]
+ first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal)
+ mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
+ mask_lat_size = mask_lat_size.view(
+ batch_size,
+ -1,
+ self.vae_scale_factor_temporal,
+ int(height) // self.vae_scale_factor_spatial,
+ int(width) // self.vae_scale_factor_spatial,
+ )
+ mask_lat_size = mask_lat_size.transpose(1, 2)
+ mask_lat_size = mask_lat_size.to(latent_condition.device)
+
+ return latents, torch.concat([mask_lat_size, latent_condition], dim=1)
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: PipelineImageInput,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ max_area: int = 720 * 1280,
+ num_frames: int = 81,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 5.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ image (`PipelineImageInput`):
+ The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ max_area (`int`, defaults to `1280 * 720`):
+ The maximum area in pixels of the generated image.
+ num_frames (`int`, defaults to `129`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, defaults to `5.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, *optional*, defaults to `512`):
+ The maximum sequence length of the prompt.
+ shift (`float`, *optional*, defaults to `5.0`):
+ The shift of the flow.
+ autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
+ The dtype to use for the torch.amp.autocast.
+ Examples:
+
+ Returns:
+ [`~WanPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where
+ the first element is a list with the generated images and the second element is a list of `bool`s
+ indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ image,
+ max_area,
+ prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ # Encode image embedding
+ image_embeds = self.encode_image(image)
+ image_embeds = image_embeds.repeat(batch_size, 1, 1)
+
+ transformer_dtype = self.transformer.dtype
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+ image_embeds = image_embeds.to(transformer_dtype)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ if isinstance(image, torch.Tensor):
+ height, width = image.shape[-2:]
+ else:
+ width, height = image.size
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.vae.config.z_dim
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ latents, condition = self.prepare_latents(
+ image,
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ max_area,
+ num_frames,
+ num_latent_frames,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
+ timestep = t.expand(latents.shape[0])
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ encoder_hidden_states_image=image_embeds,
+ return_dict=False,
+ )[0]
+
+ if self.do_classifier_free_guidance:
+ noise_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ encoder_hidden_states_image=image_embeds,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype)
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return WanPipelineOutput(frames=video)
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 9dd1e690742f..10827978bc99 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -201,6 +201,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class AutoencoderKLWan(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class AutoencoderOobleck(metaclass=DummyObject):
_backends = ["torch"]
@@ -966,6 +981,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class WanTransformer3DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
def get_constant_schedule(*args, **kwargs):
requires_backends(get_constant_schedule, ["torch"])
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index 8bb9ec1cb321..1ab4f4ba4f5a 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -2597,6 +2597,36 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class WanImageToVideoPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class WanPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class WuerstchenCombinedPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
diff --git a/tests/models/autoencoders/test_models_autoencoder_wan.py b/tests/models/autoencoders/test_models_autoencoder_wan.py
new file mode 100644
index 000000000000..ffc474039889
--- /dev/null
+++ b/tests/models/autoencoders/test_models_autoencoder_wan.py
@@ -0,0 +1,79 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from diffusers import AutoencoderKLWan
+from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device
+
+from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+
+
+enable_full_determinism()
+
+
+class AutoencoderKLWanTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+ model_class = AutoencoderKLWan
+ main_input_name = "sample"
+ base_precision = 1e-2
+
+ def get_autoencoder_kl_wan_config(self):
+ return {
+ "base_dim": 3,
+ "z_dim": 16,
+ "dim_mult": [1, 1, 1, 1],
+ "num_res_blocks": 1,
+ "temperal_downsample": [False, True, True],
+ }
+
+ @property
+ def dummy_input(self):
+ batch_size = 2
+ num_frames = 9
+ num_channels = 3
+ sizes = (16, 16)
+
+ image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
+
+ return {"sample": image}
+
+ @property
+ def input_shape(self):
+ return (3, 9, 16, 16)
+
+ @property
+ def output_shape(self):
+ return (3, 9, 16, 16)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = self.get_autoencoder_kl_wan_config()
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ @unittest.skip("Gradient checkpointing has not been implemented yet")
+ def test_gradient_checkpointing_is_applied(self):
+ pass
+
+ @unittest.skip("Test not supported")
+ def test_forward_with_norm_groups(self):
+ pass
+
+ @unittest.skip("RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
+ def test_layerwise_casting_inference(self):
+ pass
+
+ @unittest.skip("RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
+ def test_layerwise_casting_training(self):
+ pass
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index b917efe0850f..8754d2073e35 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -739,8 +739,14 @@ def test_from_save_pretrained_dtype(self):
model.save_pretrained(tmpdirname, safe_serialization=False)
new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True, torch_dtype=dtype)
assert new_model.dtype == dtype
- new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype)
- assert new_model.dtype == dtype
+ if (
+ hasattr(self.model_class, "_keep_in_fp32_modules")
+ and self.model_class._keep_in_fp32_modules is None
+ ):
+ new_model = self.model_class.from_pretrained(
+ tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype
+ )
+ assert new_model.dtype == dtype
def test_determinism(self, expected_max_diff=1e-5):
if self.forward_requires_fresh_args:
diff --git a/tests/models/transformers/test_models_transformer_wan.py b/tests/models/transformers/test_models_transformer_wan.py
new file mode 100644
index 000000000000..3ac64c628988
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_wan.py
@@ -0,0 +1,81 @@
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import WanTransformer3DModel
+from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+
+from ..test_modeling_common import ModelTesterMixin
+
+
+enable_full_determinism()
+
+
+class WanTransformer3DTests(ModelTesterMixin, unittest.TestCase):
+ model_class = WanTransformer3DModel
+ main_input_name = "hidden_states"
+ uses_custom_attn_processor = True
+
+ @property
+ def dummy_input(self):
+ batch_size = 1
+ num_channels = 4
+ num_frames = 2
+ height = 16
+ width = 16
+ text_encoder_embedding_dim = 16
+ sequence_length = 12
+
+ hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
+ timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "timestep": timestep,
+ }
+
+ @property
+ def input_shape(self):
+ return (4, 1, 16, 16)
+
+ @property
+ def output_shape(self):
+ return (4, 1, 16, 16)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "patch_size": (1, 2, 2),
+ "num_attention_heads": 2,
+ "attention_head_dim": 12,
+ "in_channels": 4,
+ "out_channels": 4,
+ "text_dim": 16,
+ "freq_dim": 256,
+ "ffn_dim": 32,
+ "num_layers": 2,
+ "cross_attn_norm": True,
+ "qk_norm": "rms_norm_across_heads",
+ "rope_max_seq_len": 32,
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"WanTransformer3DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
diff --git a/tests/pipelines/wan/__init__.py b/tests/pipelines/wan/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/wan/test_wan.py b/tests/pipelines/wan/test_wan.py
new file mode 100644
index 000000000000..a162e6841d2d
--- /dev/null
+++ b/tests/pipelines/wan/test_wan.py
@@ -0,0 +1,156 @@
+# Copyright 2024 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanPipeline, WanTransformer3DModel
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ require_torch_accelerator,
+ slow,
+)
+
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import (
+ PipelineTesterMixin,
+)
+
+
+enable_full_determinism()
+
+
+class WanPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = WanPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ # TODO: impl FlowDPMSolverMultistepScheduler
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ transformer = WanTransformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=16,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ )
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "dance monkey",
+ "negative_prompt": "negative", # TODO
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "height": 16,
+ "width": 16,
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+
+ self.assertEqual(generated_video.shape, (9, 3, 16, 16))
+ expected_video = torch.randn(9, 3, 16, 16)
+ max_diff = np.abs(generated_video - expected_video).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ @unittest.skip("Test not supported")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+
+@slow
+@require_torch_accelerator
+class WanPipelineIntegrationTests(unittest.TestCase):
+ prompt = "A painting of a squirrel eating a burger."
+
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ @unittest.skip("TODO: test needs to be implemented")
+ def test_Wanx(self):
+ pass
diff --git a/tests/pipelines/wan/test_wan_image_to_video.py b/tests/pipelines/wan/test_wan_image_to_video.py
new file mode 100644
index 000000000000..b898545c147b
--- /dev/null
+++ b/tests/pipelines/wan/test_wan_image_to_video.py
@@ -0,0 +1,161 @@
+# Copyright 2024 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import (
+ AutoTokenizer,
+ CLIPImageProcessor,
+ CLIPVisionConfig,
+ CLIPVisionModelWithProjection,
+ T5EncoderModel,
+)
+
+from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanImageToVideoPipeline, WanTransformer3DModel
+from diffusers.utils.testing_utils import enable_full_determinism
+
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class WanImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = WanImageToVideoPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "height", "width"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ # TODO: impl FlowDPMSolverMultistepScheduler
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ transformer = WanTransformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=36,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ image_dim=4,
+ )
+
+ torch.manual_seed(0)
+ image_encoder_config = CLIPVisionConfig(
+ hidden_size=4,
+ projection_dim=4,
+ num_hidden_layers=2,
+ num_attention_heads=2,
+ image_size=32,
+ intermediate_size=16,
+ patch_size=1,
+ )
+ image_encoder = CLIPVisionModelWithProjection(image_encoder_config)
+
+ torch.manual_seed(0)
+ image_processor = CLIPImageProcessor(crop_size=32, size=32)
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "image_encoder": image_encoder,
+ "image_processor": image_processor,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ image_height = 16
+ image_width = 16
+ image = Image.new("RGB", (image_width, image_height))
+ inputs = {
+ "image": image,
+ "prompt": "dance monkey",
+ "negative_prompt": "negative", # TODO
+ "max_area": 1024,
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+
+ self.assertEqual(generated_video.shape, (9, 3, 32, 32))
+ expected_video = torch.randn(9, 3, 32, 32)
+ max_diff = np.abs(generated_video - expected_video).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ @unittest.skip("Test not supported")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+ @unittest.skip("TODO: revisit failing as it requires a very high threshold to pass")
+ def test_inference_batch_single_identical(self):
+ pass