
Commit 027a206

add file

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

1 parent 008a7a9 commit 027a206

1 file changed: +226 -0 lines changed
@@ -0,0 +1,226 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Checkpoint conversion mappings for loading HuggingFace checkpoints.

This module provides conversion mappings for transforming checkpoint keys and tensors
when loading models. It primarily uses the transformers library's conversion_mapping
module which handles both key renaming and tensor operations (merging/splitting).

For MoE models, the conversion handles:
- Key renaming from checkpoint format (e.g., block_sparse_moe.experts.X.w1) to
  model format (e.g., mlp.experts.gate_up_proj)
- Tensor merging for grouped expert formats (individual experts -> single 3D tensor)

The primary entry points are:
- `get_checkpoint_conversion_mapping(model_type)`: Get conversion rules for a model type
- `get_model_conversion_mapping(model, ...)`: Get all conversion rules for a model instance
- `requires_tensor_merging(model_type)`: Check if model needs tensor operations
"""

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    from torch import nn


# Try to import from transformers - this is the preferred source
_TRANSFORMERS_AVAILABLE = False
try:
    from transformers.conversion_mapping import (
        get_checkpoint_conversion_mapping as _transformers_get_checkpoint_conversion_mapping,
        get_model_conversion_mapping as _transformers_get_model_conversion_mapping,
    )
    from transformers.core_model_loading import WeightConverter, WeightRenaming

    _TRANSFORMERS_AVAILABLE = True
except ImportError:
    # Transformers not available or doesn't have conversion_mapping
    WeightConverter = None
    WeightRenaming = None


# Model types that require tensor merging (individual experts -> grouped experts)
# For these models, simple key renaming is not sufficient - they need WeightConverter
# operations to merge individual expert weights into grouped format
MODELS_REQUIRING_TENSOR_MERGING = {
    "mixtral",
    "minimax",
    "phimoe",
    "qwen2_moe",
    "qwen3_moe",
    "deepseek_v2",
    "deepseek_v3",
    "jamba",
    "olmoe",
    "lfm2_moe",
    "dots1",
    "ernie4_5_moe",
    "glm4_moe",
    "glm4v_moe",
    "longcat_flash",
    "qwen3_omni_moe",
    "qwen3_next",
    "qwen3_vl_moe",
    "hunyuan_v1_moe",
    "flex_olmo",
}


def requires_tensor_merging(model_type: str) -> bool:
    """
    Check if a model type requires tensor merging during checkpoint loading.

    Some MoE models store expert weights in grouped format (single 3D tensor for all experts)
    but checkpoints store individual expert weights. These models require tensor merging
    that cannot be done via simple key renaming.

    Args:
        model_type: The model type string from config.model_type

    Returns:
        True if the model type requires tensor merging during loading.
    """
    return model_type in MODELS_REQUIRING_TENSOR_MERGING
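
# Usage sketch (hypothetical caller code, not part of this module): loaders can branch
# on this check to decide how a checkpoint should be applied, e.g.
#
#     if requires_tensor_merging(config.model_type):
#         conversions = get_model_conversion_mapping(model)          # renaming + expert merging
#     else:
#         key_mapping = get_combined_key_mapping(config.model_type)  # renaming only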


def get_checkpoint_conversion_mapping(model_type: str) -> Optional[list]:
    """
    Get the checkpoint conversion mapping for a given model type.

    This returns a list of WeightConverter and/or WeightRenaming objects from
    transformers that define how to convert checkpoint keys and tensors to
    model state dict format.

    Args:
        model_type: The model type string (e.g., "mixtral", "qwen2_moe", "phimoe")

    Returns:
        A list of WeightConverter/WeightRenaming objects defining the conversion,
        or None if no conversion mapping is defined for this model type.

    Example:
        >>> mapping = get_checkpoint_conversion_mapping("mixtral")
        >>> # Returns list with WeightRenaming for gate and WeightConverter
        >>> # for merging individual expert weights into grouped format
    """
    if not _TRANSFORMERS_AVAILABLE:
        return None
    return _transformers_get_checkpoint_conversion_mapping(model_type)


def get_model_conversion_mapping(
    model: "nn.Module",
    key_mapping: Optional[dict[str, str]] = None,
    hf_quantizer: Optional[object] = None,
    add_legacy: bool = True,
) -> list:
    """
    Get all weight conversion mappings for a model instance.

    This is the main entry point for getting conversion rules. It combines:
    1. Custom key_mapping if provided
    2. Model's _checkpoint_conversion_mapping attribute (for VLMs)
    3. Model-type specific conversions (MoE merging, etc.)
    4. Legacy conversions (LayerNorm.gamma -> LayerNorm.weight, etc.)
    5. Quantizer-specific conversions if provided

    Args:
        model: The model instance to get conversions for
        key_mapping: Optional custom key mapping (source -> target patterns)
        hf_quantizer: Optional HuggingFace quantizer with additional conversions
        add_legacy: Whether to include legacy LayerNorm conversions (default True)

    Returns:
        List of WeightConverter/WeightRenaming objects defining all conversions.
        Returns empty list if transformers is not available.

    Example:
        >>> from transformers import AutoModelForCausalLM
        >>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B")
        >>> conversions = get_model_conversion_mapping(model)
        >>> # Use conversions to transform checkpoint state dict
    """
    if not _TRANSFORMERS_AVAILABLE:
        return []
    return _transformers_get_model_conversion_mapping(
        model,
        key_mapping=key_mapping,
        hf_quantizer=hf_quantizer,
        add_legacy=add_legacy,
    )


def get_combined_key_mapping(
    model_type: str,
    model_key_mapping: Optional[dict[str, str]] = None,
) -> Optional[dict[str, str]]:
    """
    Get combined key mapping for simple regex-based key renaming.

    This is a simpler alternative to get_model_conversion_mapping that only
    handles key renaming (not tensor operations). Useful when you just need
    to rename keys without merging tensors.

    Note: For MoE models that require tensor merging, use get_model_conversion_mapping
    instead, which returns WeightConverter objects that handle both renaming and merging.

    Args:
        model_type: The model type string from config.model_type
        model_key_mapping: Optional key mapping from the model's
            `_checkpoint_conversion_mapping` attribute

    Returns:
        Combined key mapping dictionary (regex pattern -> replacement),
        or None if no mappings are defined.
    """
    result = {}

    # First add model-specific key mapping (takes precedence)
    if model_key_mapping:
        result.update(model_key_mapping)

    # Try to get conversion mapping from transformers and extract simple renamings
    if _TRANSFORMERS_AVAILABLE:
        conversions = get_checkpoint_conversion_mapping(model_type)
        if conversions:
            for conv in conversions:
                # Only extract simple WeightRenaming, not WeightConverter
                if WeightRenaming is not None and isinstance(conv, WeightRenaming):
                    # WeightRenaming stores patterns as source_patterns and target_patterns (as lists)
                    sources = getattr(conv, "source_patterns", None)
                    targets = getattr(conv, "target_patterns", None)
                    if sources and targets:
                        # Handle both list and string formats
                        if isinstance(sources, str):
                            sources = [sources]
                        if isinstance(targets, str):
                            targets = [targets]
                        # Add each source->target pair
                        for source, target in zip(sources, targets):
                            if source not in result:
                                result[source] = target

    return result if result else None
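
# Usage sketch (hypothetical caller code; assumes the returned patterns are regular
# expressions, as documented in the Returns section above):
#
#     import re
#
#     mapping = get_combined_key_mapping(
#         config.model_type, getattr(model, "_checkpoint_conversion_mapping", None)
#     )
#     renamed_state_dict = {}
#     for key, tensor in state_dict.items():
#         new_key = key
#         if mapping:
#             for pattern, replacement in mapping.items():
#                 new_key = re.sub(pattern, replacement, new_key)
#         renamed_state_dict[new_key] = tensor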


def is_transformers_conversion_available() -> bool:
    """
    Check if transformers conversion mapping is available.

    Returns:
        True if transformers library with conversion_mapping module is available.
    """
    return _TRANSFORMERS_AVAILABLE
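
# Usage sketch (hypothetical caller code): this flag can guard the optional conversion
# path so that loading still works when transformers lacks conversion_mapping, e.g.
#
#     conversions = get_model_conversion_mapping(model) if is_transformers_conversion_available() else []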
