|
| 1 | +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +import torch |
| 16 | +from diffusers import WanTransformer3DModel |
| 17 | +from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry |
| 18 | +from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge |
| 19 | +from megatron.bridge.models.conversion.param_mapping import ( |
| 20 | + AutoMapping, |
| 21 | + KVMapping, |
| 22 | + QKVMapping, |
| 23 | + ReplicatedMapping, |
| 24 | +) |
| 25 | +from megatron.bridge.models.conversion.utils import get_module_and_param_from_name |
| 26 | + |
| 27 | +from dfm.src.megatron.model.wan.conversion.wan_hf_pretrained import PreTrainedWAN |
| 28 | +from dfm.src.megatron.model.wan.wan_model import WanModel |
| 29 | +from dfm.src.megatron.model.wan.wan_provider import WanModelProvider |
| 30 | + |
| 31 | + |
| 32 | +@MegatronModelBridge.register_bridge(source=WanTransformer3DModel, target=WanModel) |
| 33 | +class WanBridge(MegatronModelBridge): |
| 34 | + """ |
| 35 | + Megatron Bridge for WAN model. |
| 36 | +
|
| 37 | + As a user you would not use this bridge directly, but through `AutoBridge`. |
| 38 | +
|
| 39 | + Example: |
| 40 | + >>> from megatron.bridge import AutoBridge |
| 41 | + >>> bridge = AutoBridge.from_hf_pretrained("WAN-3D-1.3B-v1") |
| 42 | + >>> provider = bridge.to_megatron_provider() |
| 43 | + """ |
| 44 | + |
| 45 | + def provider_bridge(self, hf_pretrained: PreTrainedWAN) -> WanModelProvider: |
| 46 | + hf_config = hf_pretrained.config |
| 47 | + |
| 48 | + cls = WanModelProvider |
| 49 | + |
| 50 | + provider = cls( |
| 51 | + num_layers=hf_config.num_layers, |
| 52 | + hidden_size=hf_config.num_attention_heads * hf_config.attention_head_dim, |
| 53 | + kv_channels=hf_config.attention_head_dim, |
| 54 | + num_query_groups=hf_config.num_attention_heads, |
| 55 | + crossattn_emb_size=hf_config.num_attention_heads * hf_config.attention_head_dim, |
| 56 | + ffn_hidden_size=hf_config.ffn_dim, |
| 57 | + num_attention_heads=hf_config.num_attention_heads, |
| 58 | + in_channels=hf_config.in_channels, |
| 59 | + out_channels=hf_config.out_channels, |
| 60 | + text_dim=hf_config.text_dim, |
| 61 | + patch_spatial=hf_config.patch_size[1], |
| 62 | + patch_temporal=hf_config.patch_size[0], |
| 63 | + layernorm_epsilon=hf_config.eps, |
| 64 | + hidden_dropout=0, |
| 65 | + attention_dropout=0, |
| 66 | + use_cpu_initialization=True, |
| 67 | + freq_dim=hf_config.freq_dim, |
| 68 | + bf16=False, |
| 69 | + params_dtype=torch.float32, |
| 70 | + ) |
| 71 | + |
| 72 | + return provider |
| 73 | + |
| 74 | + def mapping_registry(self) -> MegatronMappingRegistry: |
| 75 | + """Return MegatronMappingRegistry containing parameter mappings from HF to Megatron format. |
| 76 | +
|
| 77 | + Returns: |
| 78 | + MegatronMappingRegistry: Registry of parameter mappings |
| 79 | + """ |
| 80 | + # Dictionary maps HF parameter names -> Megatron parameter names |
| 81 | + # Supports wildcard (*) patterns for layer-specific parameters |
| 82 | + param_mappings = { |
| 83 | + "scale_shift_table": "head.modulation", |
| 84 | + "patch_embedding.weight": "patch_embedding.weight", |
| 85 | + "patch_embedding.bias": "patch_embedding.bias", |
| 86 | + "condition_embedder.time_embedder.linear_1.weight": "time_embedder.linear_1.weight", |
| 87 | + "condition_embedder.time_embedder.linear_1.bias": "time_embedder.linear_1.bias", |
| 88 | + "condition_embedder.time_embedder.linear_2.weight": "time_embedder.linear_2.weight", |
| 89 | + "condition_embedder.time_embedder.linear_2.bias": "time_embedder.linear_2.bias", |
| 90 | + "condition_embedder.time_proj.weight": "time_proj.weight", |
| 91 | + "condition_embedder.time_proj.bias": "time_proj.bias", |
| 92 | + "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight", |
| 93 | + "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias", |
| 94 | + "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight", |
| 95 | + "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias", |
| 96 | + "blocks.*.scale_shift_table": "decoder.layers.*.adaLN.modulation", |
| 97 | + "blocks.*.attn1.to_out.0.weight": "decoder.layers.*.full_self_attention.linear_proj.weight", |
| 98 | + "blocks.*.attn1.to_out.0.bias": "decoder.layers.*.full_self_attention.linear_proj.bias", |
| 99 | + "blocks.*.attn1.norm_q.weight": "decoder.layers.*.full_self_attention.q_layernorm.weight", |
| 100 | + "blocks.*.attn1.norm_k.weight": "decoder.layers.*.full_self_attention.k_layernorm.weight", |
| 101 | + "blocks.*.attn2.to_q.weight": "decoder.layers.*.cross_attention.linear_q.weight", |
| 102 | + "blocks.*.attn2.to_q.bias": "decoder.layers.*.cross_attention.linear_q.bias", |
| 103 | + "blocks.*.attn2.to_out.0.weight": "decoder.layers.*.cross_attention.linear_proj.weight", |
| 104 | + "blocks.*.attn2.to_out.0.bias": "decoder.layers.*.cross_attention.linear_proj.bias", |
| 105 | + "blocks.*.attn2.norm_q.weight": "decoder.layers.*.cross_attention.q_layernorm.weight", |
| 106 | + "blocks.*.attn2.norm_k.weight": "decoder.layers.*.cross_attention.k_layernorm.weight", |
| 107 | + "blocks.*.norm2.weight": "decoder.layers.*.norm3.weight", |
| 108 | + "blocks.*.norm2.bias": "decoder.layers.*.norm3.bias", |
| 109 | + "blocks.*.ffn.net.0.proj.weight": "decoder.layers.*.mlp.linear_fc1.weight", |
| 110 | + "blocks.*.ffn.net.0.proj.bias": "decoder.layers.*.mlp.linear_fc1.bias", |
| 111 | + "blocks.*.ffn.net.2.weight": "decoder.layers.*.mlp.linear_fc2.weight", |
| 112 | + "blocks.*.ffn.net.2.bias": "decoder.layers.*.mlp.linear_fc2.bias", |
| 113 | + "proj_out.weight": "head.head.weight", |
| 114 | + "proj_out.bias": "head.head.bias", |
| 115 | + } |
| 116 | + |
| 117 | + # Custom WAN mapping to safely handle replicated params whose owning module |
| 118 | + # does not expose a top-level `.weight` (e.g., Head.modulation) |
| 119 | + class _ReplicatedByParamNameMapping(ReplicatedMapping): |
| 120 | + def hf_to_megatron(self, hf_weights, megatron_module): |
| 121 | + normalized_param = self._normalize_expert_param_name(self.megatron_param) |
| 122 | + _, target_param = get_module_and_param_from_name(megatron_module, normalized_param) |
| 123 | + |
| 124 | + target_device = target_param.device |
| 125 | + target_dtype = target_param.dtype |
| 126 | + |
| 127 | + hf_weights = hf_weights.to(device=target_device, dtype=target_dtype) |
| 128 | + if self.tp_size == 1: |
| 129 | + return hf_weights |
| 130 | + |
| 131 | + if target_device.type == "cuda" and torch.cuda.is_available(): |
| 132 | + if target_device.index != torch.cuda.current_device(): |
| 133 | + hf_weights = hf_weights.to(torch.cuda.current_device()) |
| 134 | + |
| 135 | + if self.tp_rank > 0: |
| 136 | + hf_weights = torch.empty_like(hf_weights) |
| 137 | + |
| 138 | + return self.broadcast_tensor_to_tp_ranks(hf_weights, src_rank=0) |
| 139 | + |
| 140 | + mapping_list = [] |
| 141 | + # Convert each dictionary entry to AutoMapping(hf_param, megatron_param) |
| 142 | + for hf_param, megatron_param in param_mappings.items(): |
| 143 | + if hf_param in {"scale_shift_table", "blocks.*.scale_shift_table", "proj_out.weight", "proj_out.bias"}: |
| 144 | + # Use WAN-specific replicated mapping that resolves the exact param |
| 145 | + mapping_list.append(_ReplicatedByParamNameMapping(hf_param=hf_param, megatron_param=megatron_param)) |
| 146 | + else: |
| 147 | + mapping_list.append(AutoMapping(hf_param=hf_param, megatron_param=megatron_param)) |
| 148 | + |
| 149 | + # Adding custom module types for AutoMapping |
| 150 | + AutoMapping.register_module_type("Linear", "replicated") |
| 151 | + AutoMapping.register_module_type("Conv3d", "replicated") |
| 152 | + AutoMapping.register_module_type("WanAdaLN", "replicated") |
| 153 | + AutoMapping.register_module_type("Head", "replicated") |
| 154 | + |
| 155 | + # Add special mappings that require parameter concatenation/transformation |
| 156 | + mapping_list.extend( |
| 157 | + [ |
| 158 | + # QKV: Combine separate Q, K, V matrices into single QKV matrix |
| 159 | + QKVMapping( |
| 160 | + q="blocks.*.attn1.to_q.weight", |
| 161 | + k="blocks.*.attn1.to_k.weight", |
| 162 | + v="blocks.*.attn1.to_v.weight", |
| 163 | + megatron_param="decoder.layers.*.full_self_attention.linear_qkv.weight", |
| 164 | + ), |
| 165 | + # QKV bias: Combine separate Q, K, V bias into single QKV bias |
| 166 | + QKVMapping( |
| 167 | + q="blocks.*.attn1.to_q.bias", |
| 168 | + k="blocks.*.attn1.to_k.bias", |
| 169 | + v="blocks.*.attn1.to_v.bias", |
| 170 | + megatron_param="decoder.layers.*.full_self_attention.linear_qkv.bias", |
| 171 | + ), |
| 172 | + # K, V: Combine separate K, V matrices into single KV matrix |
| 173 | + KVMapping( |
| 174 | + k="blocks.*.attn2.to_k.weight", |
| 175 | + v="blocks.*.attn2.to_v.weight", |
| 176 | + megatron_param="decoder.layers.*.cross_attention.linear_kv.weight", |
| 177 | + ), |
| 178 | + # K, V bias: Combine separate K, V bias into single KV bias |
| 179 | + KVMapping( |
| 180 | + k="blocks.*.attn2.to_k.bias", |
| 181 | + v="blocks.*.attn2.to_v.bias", |
| 182 | + megatron_param="decoder.layers.*.cross_attention.linear_kv.bias", |
| 183 | + ), |
| 184 | + ] |
| 185 | + ) |
| 186 | + |
| 187 | + return MegatronMappingRegistry(*mapping_list) |
0 commit comments