
Commit 6b3295e

huvunvidia and Huy Vu2 authored
Wan HF <-> Megatron checkpoints conversion (#73)
* initial commit, workable code
* add example
* fix lint
* fix lint
* bring all wan related codes to DFM
* add tests
* lint

Co-authored-by: Huy Vu2 <[email protected]>
1 parent 2489a8e commit 6b3295e
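
For context, a minimal sketch of the HF (Diffusers) -> Megatron conversion flow this commit enables, following the AutoBridge usage shown in the WanBridge docstring below; the checkpoint path is a placeholder and not part of this commit:

from megatron.bridge import AutoBridge

# Placeholder Diffusers-style WAN checkpoint path (illustrative, not from this commit)
bridge = AutoBridge.from_hf_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers")
# Builds a WanModelProvider configured from the HF config via WanBridge.provider_bridge
provider = bridge.to_megatron_provider()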

File tree

9 files changed: +960 -2 lines changed


3rdparty/Megatron-Bridge

Submodule Megatron-Bridge updated 298 files
Lines changed: 187 additions & 0 deletions
@@ -0,0 +1,187 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from diffusers import WanTransformer3DModel
from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry
from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge
from megatron.bridge.models.conversion.param_mapping import (
    AutoMapping,
    KVMapping,
    QKVMapping,
    ReplicatedMapping,
)
from megatron.bridge.models.conversion.utils import get_module_and_param_from_name

from dfm.src.megatron.model.wan.conversion.wan_hf_pretrained import PreTrainedWAN
from dfm.src.megatron.model.wan.wan_model import WanModel
from dfm.src.megatron.model.wan.wan_provider import WanModelProvider


@MegatronModelBridge.register_bridge(source=WanTransformer3DModel, target=WanModel)
class WanBridge(MegatronModelBridge):
    """
    Megatron Bridge for WAN model.

    As a user you would not use this bridge directly, but through `AutoBridge`.

    Example:
        >>> from megatron.bridge import AutoBridge
        >>> bridge = AutoBridge.from_hf_pretrained("WAN-3D-1.3B-v1")
        >>> provider = bridge.to_megatron_provider()
    """

    def provider_bridge(self, hf_pretrained: PreTrainedWAN) -> WanModelProvider:
        hf_config = hf_pretrained.config

        cls = WanModelProvider

        provider = cls(
            num_layers=hf_config.num_layers,
            hidden_size=hf_config.num_attention_heads * hf_config.attention_head_dim,
            kv_channels=hf_config.attention_head_dim,
            num_query_groups=hf_config.num_attention_heads,
            crossattn_emb_size=hf_config.num_attention_heads * hf_config.attention_head_dim,
            ffn_hidden_size=hf_config.ffn_dim,
            num_attention_heads=hf_config.num_attention_heads,
            in_channels=hf_config.in_channels,
            out_channels=hf_config.out_channels,
            text_dim=hf_config.text_dim,
            patch_spatial=hf_config.patch_size[1],
            patch_temporal=hf_config.patch_size[0],
            layernorm_epsilon=hf_config.eps,
            hidden_dropout=0,
            attention_dropout=0,
            use_cpu_initialization=True,
            freq_dim=hf_config.freq_dim,
            bf16=False,
            params_dtype=torch.float32,
        )

        return provider

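    # Illustrative note (not part of this commit): several provider fields above are
    # derived from the HF config rather than copied one-to-one. With assumed values
    # num_attention_heads=12 and attention_head_dim=128 (placeholders, not read from
    # a real checkpoint), the derivation gives:
    #   hidden_size        = 12 * 128 = 1536
    #   kv_channels        = 128
    #   crossattn_emb_size = 12 * 128 = 1536
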
    def mapping_registry(self) -> MegatronMappingRegistry:
        """Return MegatronMappingRegistry containing parameter mappings from HF to Megatron format.

        Returns:
            MegatronMappingRegistry: Registry of parameter mappings
        """
        # Dictionary maps HF parameter names -> Megatron parameter names
        # Supports wildcard (*) patterns for layer-specific parameters
        param_mappings = {
            "scale_shift_table": "head.modulation",
            "patch_embedding.weight": "patch_embedding.weight",
            "patch_embedding.bias": "patch_embedding.bias",
            "condition_embedder.time_embedder.linear_1.weight": "time_embedder.linear_1.weight",
            "condition_embedder.time_embedder.linear_1.bias": "time_embedder.linear_1.bias",
            "condition_embedder.time_embedder.linear_2.weight": "time_embedder.linear_2.weight",
            "condition_embedder.time_embedder.linear_2.bias": "time_embedder.linear_2.bias",
            "condition_embedder.time_proj.weight": "time_proj.weight",
            "condition_embedder.time_proj.bias": "time_proj.bias",
            "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
            "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
            "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
            "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
            "blocks.*.scale_shift_table": "decoder.layers.*.adaLN.modulation",
            "blocks.*.attn1.to_out.0.weight": "decoder.layers.*.full_self_attention.linear_proj.weight",
            "blocks.*.attn1.to_out.0.bias": "decoder.layers.*.full_self_attention.linear_proj.bias",
            "blocks.*.attn1.norm_q.weight": "decoder.layers.*.full_self_attention.q_layernorm.weight",
            "blocks.*.attn1.norm_k.weight": "decoder.layers.*.full_self_attention.k_layernorm.weight",
            "blocks.*.attn2.to_q.weight": "decoder.layers.*.cross_attention.linear_q.weight",
            "blocks.*.attn2.to_q.bias": "decoder.layers.*.cross_attention.linear_q.bias",
            "blocks.*.attn2.to_out.0.weight": "decoder.layers.*.cross_attention.linear_proj.weight",
            "blocks.*.attn2.to_out.0.bias": "decoder.layers.*.cross_attention.linear_proj.bias",
            "blocks.*.attn2.norm_q.weight": "decoder.layers.*.cross_attention.q_layernorm.weight",
            "blocks.*.attn2.norm_k.weight": "decoder.layers.*.cross_attention.k_layernorm.weight",
            "blocks.*.norm2.weight": "decoder.layers.*.norm3.weight",
            "blocks.*.norm2.bias": "decoder.layers.*.norm3.bias",
            "blocks.*.ffn.net.0.proj.weight": "decoder.layers.*.mlp.linear_fc1.weight",
            "blocks.*.ffn.net.0.proj.bias": "decoder.layers.*.mlp.linear_fc1.bias",
            "blocks.*.ffn.net.2.weight": "decoder.layers.*.mlp.linear_fc2.weight",
            "blocks.*.ffn.net.2.bias": "decoder.layers.*.mlp.linear_fc2.bias",
            "proj_out.weight": "head.head.weight",
            "proj_out.bias": "head.head.bias",
        }

        # Custom WAN mapping to safely handle replicated params whose owning module
        # does not expose a top-level `.weight` (e.g., Head.modulation)
        class _ReplicatedByParamNameMapping(ReplicatedMapping):
            def hf_to_megatron(self, hf_weights, megatron_module):
                normalized_param = self._normalize_expert_param_name(self.megatron_param)
                _, target_param = get_module_and_param_from_name(megatron_module, normalized_param)

                target_device = target_param.device
                target_dtype = target_param.dtype

                hf_weights = hf_weights.to(device=target_device, dtype=target_dtype)
                if self.tp_size == 1:
                    return hf_weights

                if target_device.type == "cuda" and torch.cuda.is_available():
                    if target_device.index != torch.cuda.current_device():
                        hf_weights = hf_weights.to(torch.cuda.current_device())

                if self.tp_rank > 0:
                    hf_weights = torch.empty_like(hf_weights)

                return self.broadcast_tensor_to_tp_ranks(hf_weights, src_rank=0)

        mapping_list = []
        # Convert each dictionary entry to AutoMapping(hf_param, megatron_param)
        for hf_param, megatron_param in param_mappings.items():
            if hf_param in {"scale_shift_table", "blocks.*.scale_shift_table", "proj_out.weight", "proj_out.bias"}:
                # Use WAN-specific replicated mapping that resolves the exact param
                mapping_list.append(_ReplicatedByParamNameMapping(hf_param=hf_param, megatron_param=megatron_param))
            else:
                mapping_list.append(AutoMapping(hf_param=hf_param, megatron_param=megatron_param))

        # Adding custom module types for AutoMapping
        AutoMapping.register_module_type("Linear", "replicated")
        AutoMapping.register_module_type("Conv3d", "replicated")
        AutoMapping.register_module_type("WanAdaLN", "replicated")
        AutoMapping.register_module_type("Head", "replicated")

        # Add special mappings that require parameter concatenation/transformation
        mapping_list.extend(
            [
                # QKV: Combine separate Q, K, V matrices into single QKV matrix
                QKVMapping(
                    q="blocks.*.attn1.to_q.weight",
                    k="blocks.*.attn1.to_k.weight",
                    v="blocks.*.attn1.to_v.weight",
                    megatron_param="decoder.layers.*.full_self_attention.linear_qkv.weight",
                ),
                # QKV bias: Combine separate Q, K, V bias into single QKV bias
                QKVMapping(
                    q="blocks.*.attn1.to_q.bias",
                    k="blocks.*.attn1.to_k.bias",
                    v="blocks.*.attn1.to_v.bias",
                    megatron_param="decoder.layers.*.full_self_attention.linear_qkv.bias",
                ),
                # K, V: Combine separate K, V matrices into single KV matrix
                KVMapping(
                    k="blocks.*.attn2.to_k.weight",
                    v="blocks.*.attn2.to_v.weight",
                    megatron_param="decoder.layers.*.cross_attention.linear_kv.weight",
                ),
                # K, V bias: Combine separate K, V bias into single KV bias
                KVMapping(
                    k="blocks.*.attn2.to_k.bias",
                    v="blocks.*.attn2.to_v.bias",
                    megatron_param="decoder.layers.*.cross_attention.linear_kv.bias",
                ),
            ]
        )

        return MegatronMappingRegistry(*mapping_list)
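
To make the wildcard mappings above concrete, here is a small illustrative sketch of how a layer-indexed HF parameter name lines up with its Megatron counterpart. The translate helper is hypothetical and only mimics the "*" substitution; it is not the registry's actual resolution code. The QKVMapping/KVMapping entries additionally concatenate the separate HF Q/K/V (or K/V) tensors into Megatron's fused linear_qkv/linear_kv parameters rather than copying them one-to-one.

import re

# Hypothetical helper, for illustration only: substitute the layer index captured
# from the HF name into the wildcard of the Megatron pattern.
def translate(hf_name: str, hf_pattern: str, megatron_pattern: str) -> str:
    regex = "^" + re.escape(hf_pattern).replace(r"\*", r"(\d+)") + "$"
    layer_index = re.match(regex, hf_name).group(1)
    return megatron_pattern.replace("*", layer_index)

print(translate(
    "blocks.3.attn1.norm_q.weight",
    "blocks.*.attn1.norm_q.weight",
    "decoder.layers.*.full_self_attention.q_layernorm.weight",
))
# -> decoder.layers.3.full_self_attention.q_layernorm.weight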
Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import shutil
from pathlib import Path
from typing import Union

from diffusers import WanTransformer3DModel
from megatron.bridge.models.hf_pretrained.base import PreTrainedBase
from megatron.bridge.models.hf_pretrained.state import SafeTensorsStateSource, StateDict, StateSource
from transformers import AutoConfig


class WanSafeTensorsStateSource(SafeTensorsStateSource):
    """
    WAN-specific state source that writes exported HF shards under 'transformer/'.
    """

    def save_generator(self, generator, output_path, strict: bool = True):
        # Ensure shards are written under transformer/
        target_dir = Path(output_path) / "transformer"
        return super().save_generator(generator, target_dir, strict=strict)


class PreTrainedWAN(PreTrainedBase):
    """
    Lightweight pretrained wrapper for Diffusers WAN models.

    Provides access to WAN config and state through the common PreTrainedBase API
    so bridges can consume `.config` and `.state` uniformly.

    NOTE: Because WAN uses HF's Diffusers library, whose checkpoint directory structure
    differs from HF's Transformers library, we need a wrapper that loads the model weights
    and config from the correct directory (e.g., ./transformer). The Diffusers structure
    contains every component of the diffusion pipeline (VAE, text encoders, etc.), while
    the actual transformer weights live in the ./transformer directory. Hence, we adjust
    the input and output paths accordingly, and we override the save_artifacts method to
    write the relevant config files into the corresponding directory.
    """

    def __init__(self, model_name_or_path: Union[str, Path], **kwargs):
        self._model_name_or_path = str(model_name_or_path)
        super().__init__(**kwargs)

    @property
    def model_name_or_path(self) -> str:
        return self._model_name_or_path

    # Model loading is optional for conversion; implemented for completeness
    def _load_model(self) -> WanTransformer3DModel:
        return WanTransformer3DModel.from_pretrained(self.model_name_or_path)

    # Config is required by the WAN bridge
    def _load_config(self) -> AutoConfig:
        # WanTransformer3DModel returns a config-like object with the required fields
        print(f"Loading config from {self.model_name_or_path}")
        return WanTransformer3DModel.from_pretrained(self.model_name_or_path, subfolder="transformer").config

    @property
    def state(self) -> StateDict:
        """
        WAN-specific StateDict that reads safetensors from the fixed 'transformer/' subfolder.
        """
        if getattr(self, "_state_dict_accessor", None) is None:
            source: StateSource | None = None
            if hasattr(self, "_model") and self._model is not None:
                # If the model is loaded, use its in-memory state_dict
                source = self.model.state_dict()
            else:
                # Always load from the 'transformer/' subfolder for WAN
                source = WanSafeTensorsStateSource(Path(self.model_name_or_path) / "transformer")
            self._state_dict_accessor = StateDict(source)
        return self._state_dict_accessor

    def save_artifacts(self, save_directory: Union[str, Path]):
        """
        Save WAN artifacts (currently the config) alongside exported weights.
        Writes transformer/config.json into the destination.
        """
        save_path = Path(save_directory)
        save_path.mkdir(parents=True, exist_ok=True)

        # Ensure the transformer subdir exists at the destination
        dest_transformer = save_path / "transformer"
        dest_transformer.mkdir(parents=True, exist_ok=True)

        # 1) If the source has a config.json under transformer/, copy it
        src_config = Path(self.model_name_or_path) / "transformer" / "config.json"
        src_index = Path(self.model_name_or_path) / "transformer" / "diffusion_pytorch_model.safetensors.index.json"
        if src_config.exists():
            shutil.copyfile(src_config, dest_transformer / "config.json")
            if src_index.exists():
                shutil.copyfile(src_index, dest_transformer / "diffusion_pytorch_model.safetensors.index.json")
            return

        # 2) Otherwise, try to export the config from the HF model instance
        try:
            model = WanTransformer3DModel.from_pretrained(self.model_name_or_path, subfolder="transformer")
            cfg = getattr(model, "config", None)
            if cfg is not None:
                # Prefer to_dict if available
                cfg_dict = cfg.to_dict() if hasattr(cfg, "to_dict") else dict(cfg)
                with open(dest_transformer / "config.json", "w") as f:
                    json.dump(cfg_dict, f, indent=2)
        except Exception:
            # Best-effort: if the config cannot be produced, leave only weights
            pass
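
For reference, the Diffusers pipeline layout that PreTrainedWAN works around typically looks like the sketch below; the component folders other than transformer/ vary by pipeline, and the names here are illustrative assumptions rather than contents of this commit:

# Illustrative Diffusers-style WAN checkpoint layout (placeholder names):
#
#   <model_name_or_path>/
#     model_index.json
#     transformer/
#       config.json
#       diffusion_pytorch_model.safetensors.index.json
#       diffusion_pytorch_model-00001-of-0000N.safetensors
#     vae/ ...            # other pipeline components, not touched by this wrapper
#     text_encoder/ ...
#
# PreTrainedWAN reads and writes only the transformer/ subfolder:
from pathlib import Path

ckpt = Path("/path/to/wan-checkpoint")        # placeholder path
weights_dir = ckpt / "transformer"            # where .state reads safetensors from
config_file = weights_dir / "config.json"     # what save_artifacts copies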

dfm/src/megatron/model/wan/wan_layer_spec.py

Lines changed: 0 additions & 1 deletion
@@ -297,7 +297,6 @@ def get_wan_block_with_transformer_engine_spec() -> ModuleSpec:
                 module=MLP,
                 submodules=MLPSubmodules(
                     linear_fc1=TEColumnParallelLinear,
-                    # by default, activation_func is openai_gelu, which is equivalent to nn.GELU(approximate='tanh')
                     linear_fc2=TERowParallelLinear,
                 ),
             ),

dfm/src/megatron/model/wan/wan_provider.py

Lines changed: 3 additions & 0 deletions
@@ -15,8 +15,10 @@
 
 import logging
 from dataclasses import dataclass
+from typing import Callable
 
 import torch
+import torch.nn.functional as F
 from megatron.bridge.models.model_provider import ModelProviderMixin
 from megatron.bridge.models.transformer_config import TransformerConfig
 from megatron.core import parallel_state
@@ -44,6 +46,7 @@ class WanModelProvider(TransformerConfig, ModelProviderMixin[VisionModule]):
     layernorm_across_heads: bool = True
     add_qkv_bias: bool = True
     rotary_interleaved: bool = True
+    activation_func: Callable = F.gelu
     hidden_dropout: float = 0
     attention_dropout: float = 0
     fp16_lm_cross_entropy: bool = False
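
The wan_layer_spec.py and wan_provider.py changes above make the MLP activation explicit: the provider now sets activation_func to torch.nn.functional.gelu, whose default is the exact (erf-based) GELU, while the removed comment noted that the previous default ("openai_gelu") matched nn.GELU(approximate='tanh'). A minimal sketch of the difference, for illustration only (not code from this commit):

import torch
import torch.nn.functional as F

x = torch.linspace(-3, 3, steps=7)

exact = F.gelu(x)                       # erf-based GELU, as now set on WanModelProvider
approx = F.gelu(x, approximate="tanh")  # tanh approximation, i.e. nn.GELU(approximate='tanh')

# The two variants differ only slightly across typical activation ranges.
print((exact - approx).abs().max())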
