diff --git a/docs/source/en/api/pipelines/wan.md b/docs/source/en/api/pipelines/wan.md index dcc1b2b55e30..b16bf92a6370 100644 --- a/docs/source/en/api/pipelines/wan.md +++ b/docs/source/en/api/pipelines/wan.md @@ -45,6 +45,22 @@ pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", scheduler pipe.scheduler = ``` +### Using single file loading with Wan + +The `WanTransformer3DModel` and `AutoencoderKLWan` models support loading checkpoints in their original format via the `from_single_file` loading +method. + + +```python +import torch +from diffusers import WanPipeline, WanTransformer3DModel + +ckpt_path = "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors" +transformer = WanTransformer3DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16) + +pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", transformer=transformer) +``` + ## WanPipeline [[autodoc]] WanPipeline diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index e6b050833485..b7d61b3e8ff4 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -39,6 +39,8 @@ convert_mochi_transformer_checkpoint_to_diffusers, convert_sd3_transformer_checkpoint_to_diffusers, convert_stable_cascade_unet_single_file_to_diffusers, + convert_wan_transformer_to_diffusers, + convert_wan_vae_to_diffusers, create_controlnet_diffusers_config_from_ldm, create_unet_diffusers_config_from_ldm, create_vae_diffusers_config_from_ldm, @@ -117,6 +119,14 @@ "checkpoint_mapping_fn": convert_lumina2_to_diffusers, "default_subfolder": "transformer", }, + "WanTransformer3DModel": { + "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers, + "default_subfolder": "transformer", + }, + "AutoencoderKLWan": { + "checkpoint_mapping_fn": convert_wan_vae_to_diffusers, + "default_subfolder": "vae", + }, } diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index d16c418b290b..8ee7e14cb101 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -117,6 +117,8 @@ "hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias", "instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight", "lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"], + "wan": ["model.diffusion_model.head.modulation", "head.modulation"], + "wan_vae": "decoder.middle.0.residual.0.gamma", } DIFFUSERS_DEFAULT_PIPELINE_PATHS = { @@ -176,6 +178,9 @@ "hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"}, "instruct-pix2pix": {"pretrained_model_name_or_path": "timbrooks/instruct-pix2pix"}, "lumina2": {"pretrained_model_name_or_path": "Alpha-VLLM/Lumina-Image-2.0"}, + "wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"}, + "wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"}, + "wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"}, } # Use to configure model sample size when original config is provided @@ -664,6 +669,21 @@ def infer_diffusers_model_type(checkpoint): elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]): model_type = "lumina2" + elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["wan"]): + if "model.diffusion_model.patch_embedding.weight" in checkpoint: + 
target_key = "model.diffusion_model.patch_embedding.weight" + else: + target_key = "patch_embedding.weight" + + if checkpoint[target_key].shape[0] == 1536: + model_type = "wan-t2v-1.3B" + elif checkpoint[target_key].shape[0] == 5120 and checkpoint[target_key].shape[1] == 16: + model_type = "wan-t2v-14B" + else: + model_type = "wan-i2v-14B" + elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint: + # All Wan models use the same VAE so we can use the same default model repo to fetch the config + model_type = "wan-t2v-14B" else: model_type = "v1" @@ -2470,7 +2490,7 @@ def remap_proj_conv_(key: str, state_dict): def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): - new_state_dict = {} + converted_state_dict = {} # Comfy checkpoints add this prefix keys = list(checkpoint.keys()) @@ -2479,22 +2499,22 @@ def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) # Convert patch_embed - new_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight") - new_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias") + converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight") + converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias") # Convert time_embed - new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight") - new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias") - new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight") - new_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias") - new_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight") - new_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias") - new_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight") - new_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias") - new_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight") - new_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias") - new_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight") - new_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias") + converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight") + converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias") + converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight") + converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias") + converted_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight") + converted_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias") + converted_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight") + converted_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias") + converted_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight") + 
converted_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias") + converted_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight") + converted_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias") # Convert transformer blocks num_layers = 48 @@ -2503,68 +2523,84 @@ def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): old_prefix = f"blocks.{i}." # norm1 - new_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight") - new_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias") + converted_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight") + converted_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias") if i < num_layers - 1: - new_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop(old_prefix + "mod_y.weight") - new_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop(old_prefix + "mod_y.bias") + converted_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop( + old_prefix + "mod_y.weight" + ) + converted_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop( + old_prefix + "mod_y.bias" + ) else: - new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop( + converted_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop( old_prefix + "mod_y.weight" ) - new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop(old_prefix + "mod_y.bias") + converted_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop( + old_prefix + "mod_y.bias" + ) # Visual attention qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_x.weight") q, k, v = qkv_weight.chunk(3, dim=0) - new_state_dict[block_prefix + "attn1.to_q.weight"] = q - new_state_dict[block_prefix + "attn1.to_k.weight"] = k - new_state_dict[block_prefix + "attn1.to_v.weight"] = v - new_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop(old_prefix + "attn.q_norm_x.weight") - new_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop(old_prefix + "attn.k_norm_x.weight") - new_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop(old_prefix + "attn.proj_x.weight") - new_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias") + converted_state_dict[block_prefix + "attn1.to_q.weight"] = q + converted_state_dict[block_prefix + "attn1.to_k.weight"] = k + converted_state_dict[block_prefix + "attn1.to_v.weight"] = v + converted_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop( + old_prefix + "attn.q_norm_x.weight" + ) + converted_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop( + old_prefix + "attn.k_norm_x.weight" + ) + converted_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop( + old_prefix + "attn.proj_x.weight" + ) + converted_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias") # Context attention qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_y.weight") q, k, v = qkv_weight.chunk(3, dim=0) - new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q - new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k - new_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v - new_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop( + 
converted_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q + converted_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k + converted_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v + converted_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop( old_prefix + "attn.q_norm_y.weight" ) - new_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop( + converted_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop( old_prefix + "attn.k_norm_y.weight" ) if i < num_layers - 1: - new_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop( + converted_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop( old_prefix + "attn.proj_y.weight" ) - new_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop(old_prefix + "attn.proj_y.bias") + converted_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop( + old_prefix + "attn.proj_y.bias" + ) # MLP - new_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate( + converted_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate( checkpoint.pop(old_prefix + "mlp_x.w1.weight") ) - new_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight") + converted_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight") if i < num_layers - 1: - new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate( + converted_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate( checkpoint.pop(old_prefix + "mlp_y.w1.weight") ) - new_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_y.w2.weight") + converted_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop( + old_prefix + "mlp_y.w2.weight" + ) # Output layers - new_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0) - new_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0) - new_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight") - new_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias") + converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0) + converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0) + converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight") + converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias") - new_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies") + converted_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies") - return new_state_dict + return converted_state_dict def convert_hunyuan_video_transformer_to_diffusers(checkpoint, **kwargs): @@ -2859,3 +2895,252 @@ def convert_lumina_attn_to_diffusers(tensor, diffusers_key): converted_state_dict[diffusers_key] = checkpoint.pop(key) return converted_state_dict + + +def convert_wan_transformer_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {} + + keys = list(checkpoint.keys()) + for k in keys: + if "model.diffusion_model." 
in k: + checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) + + TRANSFORMER_KEYS_RENAME_DICT = { + "time_embedding.0": "condition_embedder.time_embedder.linear_1", + "time_embedding.2": "condition_embedder.time_embedder.linear_2", + "text_embedding.0": "condition_embedder.text_embedder.linear_1", + "text_embedding.2": "condition_embedder.text_embedder.linear_2", + "time_projection.1": "condition_embedder.time_proj", + "cross_attn": "attn2", + "self_attn": "attn1", + ".o.": ".to_out.0.", + ".q.": ".to_q.", + ".k.": ".to_k.", + ".v.": ".to_v.", + ".k_img.": ".add_k_proj.", + ".v_img.": ".add_v_proj.", + ".norm_k_img.": ".norm_added_k.", + "head.modulation": "scale_shift_table", + "head.head": "proj_out", + "modulation": "scale_shift_table", + "ffn.0": "ffn.net.0.proj", + "ffn.2": "ffn.net.2", + # Hack to swap the layer names + # The original model calls the norms in following order: norm1, norm3, norm2 + # We convert it to: norm1, norm2, norm3 + "norm2": "norm__placeholder", + "norm3": "norm2", + "norm__placeholder": "norm3", + # For the I2V model + "img_emb.proj.0": "condition_embedder.image_embedder.norm1", + "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj", + "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2", + "img_emb.proj.4": "condition_embedder.image_embedder.norm2", + } + + for key in list(checkpoint.keys()): + new_key = key[:] + for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + + converted_state_dict[new_key] = checkpoint.pop(key) + + return converted_state_dict + + +def convert_wan_vae_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {} + + # Create mappings for specific components + middle_key_mapping = { + # Encoder middle block + "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma", + "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias", + "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight", + "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma", + "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias", + "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight", + "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma", + "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias", + "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight", + "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma", + "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias", + "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight", + # Decoder middle block + "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma", + "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias", + "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight", + "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma", + "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias", + "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight", + "decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma", + "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias", + "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight", + 
"decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma", + "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias", + "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight", + } + + # Create a mapping for attention blocks + attention_mapping = { + # Encoder middle attention + "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma", + "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight", + "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias", + "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight", + "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias", + # Decoder middle attention + "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma", + "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight", + "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias", + "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight", + "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias", + } + + # Create a mapping for the head components + head_mapping = { + # Encoder head + "encoder.head.0.gamma": "encoder.norm_out.gamma", + "encoder.head.2.bias": "encoder.conv_out.bias", + "encoder.head.2.weight": "encoder.conv_out.weight", + # Decoder head + "decoder.head.0.gamma": "decoder.norm_out.gamma", + "decoder.head.2.bias": "decoder.conv_out.bias", + "decoder.head.2.weight": "decoder.conv_out.weight", + } + + # Create a mapping for the quant components + quant_mapping = { + "conv1.weight": "quant_conv.weight", + "conv1.bias": "quant_conv.bias", + "conv2.weight": "post_quant_conv.weight", + "conv2.bias": "post_quant_conv.bias", + } + + # Process each key in the state dict + for key, value in checkpoint.items(): + # Handle middle block keys using the mapping + if key in middle_key_mapping: + new_key = middle_key_mapping[key] + converted_state_dict[new_key] = value + # Handle attention blocks using the mapping + elif key in attention_mapping: + new_key = attention_mapping[key] + converted_state_dict[new_key] = value + # Handle head keys using the mapping + elif key in head_mapping: + new_key = head_mapping[key] + converted_state_dict[new_key] = value + # Handle quant keys using the mapping + elif key in quant_mapping: + new_key = quant_mapping[key] + converted_state_dict[new_key] = value + # Handle encoder conv1 + elif key == "encoder.conv1.weight": + converted_state_dict["encoder.conv_in.weight"] = value + elif key == "encoder.conv1.bias": + converted_state_dict["encoder.conv_in.bias"] = value + # Handle decoder conv1 + elif key == "decoder.conv1.weight": + converted_state_dict["decoder.conv_in.weight"] = value + elif key == "decoder.conv1.bias": + converted_state_dict["decoder.conv_in.bias"] = value + # Handle encoder downsamples + elif key.startswith("encoder.downsamples."): + # Convert to down_blocks + new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.") + + # Convert residual block naming but keep the original structure + if ".residual.0.gamma" in new_key: + new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma") + elif ".residual.2.bias" in new_key: + new_key = new_key.replace(".residual.2.bias", ".conv1.bias") + elif ".residual.2.weight" in new_key: + new_key = new_key.replace(".residual.2.weight", ".conv1.weight") + elif ".residual.3.gamma" in new_key: + new_key = new_key.replace(".residual.3.gamma", 
".norm2.gamma") + elif ".residual.6.bias" in new_key: + new_key = new_key.replace(".residual.6.bias", ".conv2.bias") + elif ".residual.6.weight" in new_key: + new_key = new_key.replace(".residual.6.weight", ".conv2.weight") + elif ".shortcut.bias" in new_key: + new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias") + elif ".shortcut.weight" in new_key: + new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight") + + converted_state_dict[new_key] = value + + # Handle decoder upsamples + elif key.startswith("decoder.upsamples."): + # Convert to up_blocks + parts = key.split(".") + block_idx = int(parts[2]) + + # Group residual blocks + if "residual" in key: + if block_idx in [0, 1, 2]: + new_block_idx = 0 + resnet_idx = block_idx + elif block_idx in [4, 5, 6]: + new_block_idx = 1 + resnet_idx = block_idx - 4 + elif block_idx in [8, 9, 10]: + new_block_idx = 2 + resnet_idx = block_idx - 8 + elif block_idx in [12, 13, 14]: + new_block_idx = 3 + resnet_idx = block_idx - 12 + else: + # Keep as is for other blocks + converted_state_dict[key] = value + continue + + # Convert residual block naming + if ".residual.0.gamma" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma" + elif ".residual.2.bias" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias" + elif ".residual.2.weight" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight" + elif ".residual.3.gamma" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma" + elif ".residual.6.bias" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias" + elif ".residual.6.weight" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight" + else: + new_key = key + + converted_state_dict[new_key] = value + + # Handle shortcut connections + elif ".shortcut." in key: + if block_idx == 4: + new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.") + new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1") + else: + new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") + new_key = new_key.replace(".shortcut.", ".conv_shortcut.") + + converted_state_dict[new_key] = value + + # Handle upsamplers + elif ".resample." in key or ".time_conv." 
in key: + if block_idx == 3: + new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0") + elif block_idx == 7: + new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0") + elif block_idx == 11: + new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0") + else: + new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") + + converted_state_dict[new_key] = value + else: + new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") + converted_state_dict[new_key] = value + else: + # Keep other keys unchanged + converted_state_dict[key] = value + + return converted_state_dict diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 819a1d6ba390..b45cb2a7950d 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -284,8 +284,9 @@ def __init__( self.norm_added_q = RMSNorm(dim_head, eps=eps) self.norm_added_k = RMSNorm(dim_head, eps=eps) elif qk_norm == "rms_norm_across_heads": - # Wanx applies qk norm across all heads - self.norm_added_q = RMSNorm(dim_head * heads, eps=eps) + # Wan applies qk norm across all heads + # Wan also doesn't apply a q norm + self.norm_added_q = None self.norm_added_k = RMSNorm(dim_head * kv_heads, eps=eps) else: raise ValueError( diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py index 513afa3dfaee..fdc7b3ffe58a 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -20,6 +20,7 @@ import torch.utils.checkpoint from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin from ...utils import logging from ...utils.accelerate_utils import apply_forward_hook from ..activations import get_activation @@ -655,7 +656,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): return x -class AutoencoderKLWan(ModelMixin, ConfigMixin): +class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin): r""" A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Introduced in [Wan 2.1]. diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index 259afa547bc5..66cdda388c06 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -20,7 +20,7 @@ import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import PeftAdapterMixin +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ..attention import FeedForward from ..attention_processor import Attention @@ -288,7 +288,7 @@ def forward( return hidden_states -class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): +class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): r""" A Transformer model for video-like data used in the Wan model. 
@@ -329,6 +329,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"] _no_split_modules = ["WanTransformerBlock"] _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] + _keys_to_ignore_on_load_unexpected = ["norm_added_q"] @register_to_config def __init__( diff --git a/tests/single_file/test_model_wan_autoencoder_single_file.py b/tests/single_file/test_model_wan_autoencoder_single_file.py new file mode 100644 index 000000000000..f5720ddd3964 --- /dev/null +++ b/tests/single_file/test_model_wan_autoencoder_single_file.py @@ -0,0 +1,61 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +from diffusers import ( + AutoencoderKLWan, +) +from diffusers.utils.testing_utils import ( + backend_empty_cache, + enable_full_determinism, + require_torch_accelerator, + torch_device, +) + + +enable_full_determinism() + + +@require_torch_accelerator +class AutoencoderKLWanSingleFileTests(unittest.TestCase): + model_class = AutoencoderKLWan + ckpt_path = ( + "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/vae/wan_2.1_vae.safetensors" + ) + repo_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" + + def setUp(self): + super().setUp() + gc.collect() + backend_empty_cache(torch_device) + + def tearDown(self): + super().tearDown() + gc.collect() + backend_empty_cache(torch_device) + + def test_single_file_components(self): + model = self.model_class.from_pretrained(self.repo_id, subfolder="vae") + model_single_file = self.model_class.from_single_file(self.ckpt_path) + + PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"] + for param_name, param_value in model_single_file.config.items(): + if param_name in PARAMS_TO_IGNORE: + continue + assert ( + model.config[param_name] == param_value + ), f"{param_name} differs between single file loading and pretrained loading" diff --git a/tests/single_file/test_model_wan_transformer3d_single_file.py b/tests/single_file/test_model_wan_transformer3d_single_file.py new file mode 100644 index 000000000000..9b938aa1754c --- /dev/null +++ b/tests/single_file/test_model_wan_transformer3d_single_file.py @@ -0,0 +1,93 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import gc +import unittest + +import torch + +from diffusers import ( + WanTransformer3DModel, +) +from diffusers.utils.testing_utils import ( + backend_empty_cache, + enable_full_determinism, + require_big_gpu_with_torch_cuda, + require_torch_accelerator, + torch_device, +) + + +enable_full_determinism() + + +@require_torch_accelerator +class WanTransformer3DModelText2VideoSingleFileTest(unittest.TestCase): + model_class = WanTransformer3DModel + ckpt_path = "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors" + repo_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" + + def setUp(self): + super().setUp() + gc.collect() + backend_empty_cache(torch_device) + + def tearDown(self): + super().tearDown() + gc.collect() + backend_empty_cache(torch_device) + + def test_single_file_components(self): + model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer") + model_single_file = self.model_class.from_single_file(self.ckpt_path) + + PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"] + for param_name, param_value in model_single_file.config.items(): + if param_name in PARAMS_TO_IGNORE: + continue + assert ( + model.config[param_name] == param_value + ), f"{param_name} differs between single file loading and pretrained loading" + + +@require_big_gpu_with_torch_cuda +@require_torch_accelerator +class WanTransformer3DModelImage2VideoSingleFileTest(unittest.TestCase): + model_class = WanTransformer3DModel + ckpt_path = "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_i2v_480p_14B_fp8_e4m3fn.safetensors" + repo_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers" + torch_dtype = torch.float8_e4m3fn + + def setUp(self): + super().setUp() + gc.collect() + backend_empty_cache(torch_device) + + def tearDown(self): + super().tearDown() + gc.collect() + backend_empty_cache(torch_device) + + def test_single_file_components(self): + model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer", torch_dtype=self.torch_dtype) + model_single_file = self.model_class.from_single_file(self.ckpt_path, torch_dtype=self.torch_dtype) + + PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"] + for param_name, param_value in model_single_file.config.items(): + if param_name in PARAMS_TO_IGNORE: + continue + assert ( + model.config[param_name] == param_value + ), f"{param_name} differs between single file loading and pretrained loading"
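
Usage note (not part of the patch): a minimal sketch of how the pieces above fit together, combining the transformer example from the `wan.md` change with the VAE checkpoint used in the new single-file test. `from_single_file` routes through the new `WanTransformer3DModel` / `AutoencoderKLWan` entries in `SINGLE_FILE_LOADABLE_CLASSES` and the two `convert_wan_*_to_diffusers` helpers added in `single_file_utils.py`. The dtype choices below are illustrative assumptions (bf16 transformer, fp32 VAE, following the existing Wan docs convention), not requirements of the patch.

```python
import torch

from diffusers import AutoencoderKLWan, WanPipeline, WanTransformer3DModel

# Checkpoints referenced in the docs and tests of this diff
transformer_ckpt = "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors"
vae_ckpt = "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/vae/wan_2.1_vae.safetensors"

# Both single-file checkpoints are remapped to the Diffusers layout by the new
# convert_wan_transformer_to_diffusers / convert_wan_vae_to_diffusers functions.
transformer = WanTransformer3DModel.from_single_file(transformer_ckpt, torch_dtype=torch.bfloat16)
vae = AutoencoderKLWan.from_single_file(vae_ckpt, torch_dtype=torch.float32)

# The remaining components (text encoder, tokenizer, scheduler) still come from the Hub repo.
pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
    transformer=transformer,
    vae=vae,
    torch_dtype=torch.bfloat16,
)
```

Loading the 14B T2V or I2V checkpoints works the same way: `infer_diffusers_model_type` selects the matching default repo from the shape of `patch_embedding.weight`, as added in the `CHECKPOINT_KEY_NAMES` / `DIFFUSERS_DEFAULT_PIPELINE_PATHS` hunks above.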