|
| 1 | +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +""" |
| 16 | +Checkpoint conversion mappings for loading HuggingFace checkpoints. |
| 17 | +
|
| 18 | +This module provides conversion mappings for transforming checkpoint keys and tensors |
| 19 | +when loading models. It primarily uses the transformers library's conversion_mapping |
| 20 | +module which handles both key renaming and tensor operations (merging/splitting). |
| 21 | +
|
| 22 | +For MoE models, the conversion handles: |
| 23 | +- Key renaming from checkpoint format (e.g., block_sparse_moe.experts.X.w1) to |
| 24 | + model format (e.g., mlp.experts.gate_up_proj) |
| 25 | +- Tensor merging for grouped expert formats (individual experts -> single 3D tensor) |
| 26 | +
|
| 27 | +The primary entry points are: |
| 28 | +- `get_checkpoint_conversion_mapping(model_type)`: Get conversion rules for a model type |
| 29 | +- `get_model_conversion_mapping(model, ...)`: Get all conversion rules for a model instance |
| 30 | +- `requires_tensor_merging(model_type)`: Check if model needs tensor operations |
| 31 | +""" |
| 32 | + |
| 33 | +from typing import TYPE_CHECKING, Optional |
| 34 | + |
| 35 | +if TYPE_CHECKING: |
| 36 | + from torch import nn |
| 37 | + |
| 38 | + |
| 39 | +# Try to import from transformers - this is the preferred source |
| 40 | +_TRANSFORMERS_AVAILABLE = False |
| 41 | +try: |
| 42 | + from transformers.conversion_mapping import ( |
| 43 | + get_checkpoint_conversion_mapping as _transformers_get_checkpoint_conversion_mapping, |
| 44 | + get_model_conversion_mapping as _transformers_get_model_conversion_mapping, |
| 45 | + ) |
| 46 | + from transformers.core_model_loading import WeightConverter, WeightRenaming |
| 47 | + |
| 48 | + _TRANSFORMERS_AVAILABLE = True |
| 49 | +except ImportError: |
| 50 | + # Transformers not available or doesn't have conversion_mapping |
| 51 | + WeightConverter = None |
| 52 | + WeightRenaming = None |
| 53 | + |
| 54 | + |
| 55 | +# Model types that require tensor merging (individual experts -> grouped experts) |
| 56 | +# For these models, simple key renaming is not sufficient - they need WeightConverter |
| 57 | +# operations to merge individual expert weights into grouped format |
| 58 | +MODELS_REQUIRING_TENSOR_MERGING = { |
| 59 | + "mixtral", |
| 60 | + "minimax", |
| 61 | + "phimoe", |
| 62 | + "qwen2_moe", |
| 63 | + "qwen3_moe", |
| 64 | + "deepseek_v2", |
| 65 | + "deepseek_v3", |
| 66 | + "jamba", |
| 67 | + "olmoe", |
| 68 | + "lfm2_moe", |
| 69 | + "dots1", |
| 70 | + "ernie4_5_moe", |
| 71 | + "glm4_moe", |
| 72 | + "glm4v_moe", |
| 73 | + "longcat_flash", |
| 74 | + "qwen3_omni_moe", |
| 75 | + "qwen3_next", |
| 76 | + "qwen3_vl_moe", |
| 77 | + "hunyuan_v1_moe", |
| 78 | + "flex_olmo", |
| 79 | +} |
| 80 | + |
| 81 | + |
| 82 | +def requires_tensor_merging(model_type: str) -> bool: |
| 83 | + """ |
| 84 | + Check if a model type requires tensor merging during checkpoint loading. |
| 85 | +
|
| 86 | + Some MoE models store expert weights in grouped format (single 3D tensor for all experts) |
| 87 | + but checkpoints store individual expert weights. These models require tensor merging |
| 88 | + that cannot be done via simple key renaming. |
| 89 | +
|
| 90 | + Args: |
| 91 | + model_type: The model type string from config.model_type |
| 92 | +
|
| 93 | + Returns: |
| 94 | + True if the model type requires tensor merging during loading. |
| 95 | + """ |
| 96 | + return model_type in MODELS_REQUIRING_TENSOR_MERGING |
| 97 | + |
| 98 | + |
| 99 | +def get_checkpoint_conversion_mapping(model_type: str) -> Optional[list]: |
| 100 | + """ |
| 101 | + Get the checkpoint conversion mapping for a given model type. |
| 102 | +
|
| 103 | + This returns a list of WeightConverter and/or WeightRenaming objects from |
| 104 | + transformers that define how to convert checkpoint keys and tensors to |
| 105 | + model state dict format. |
| 106 | +
|
| 107 | + Args: |
| 108 | + model_type: The model type string (e.g., "mixtral", "qwen2_moe", "phimoe") |
| 109 | +
|
| 110 | + Returns: |
| 111 | + A list of WeightConverter/WeightRenaming objects defining the conversion, |
| 112 | + or None if no conversion mapping is defined for this model type. |
| 113 | +
|
| 114 | + Example: |
| 115 | + >>> mapping = get_checkpoint_conversion_mapping("mixtral") |
| 116 | + >>> # Returns list with WeightRenaming for gate and WeightConverter |
| 117 | + >>> # for merging individual expert weights into grouped format |
| 118 | + """ |
| 119 | + if not _TRANSFORMERS_AVAILABLE: |
| 120 | + return None |
| 121 | + return _transformers_get_checkpoint_conversion_mapping(model_type) |
| 122 | + |
| 123 | + |
| 124 | +def get_model_conversion_mapping( |
| 125 | + model: "nn.Module", |
| 126 | + key_mapping: Optional[dict[str, str]] = None, |
| 127 | + hf_quantizer: Optional[object] = None, |
| 128 | + add_legacy: bool = True, |
| 129 | +) -> list: |
| 130 | + """ |
| 131 | + Get all weight conversion mappings for a model instance. |
| 132 | +
|
| 133 | + This is the main entry point for getting conversion rules. It combines: |
| 134 | + 1. Custom key_mapping if provided |
| 135 | + 2. Model's _checkpoint_conversion_mapping attribute (for VLMs) |
| 136 | + 3. Model-type specific conversions (MoE merging, etc.) |
| 137 | + 4. Legacy conversions (LayerNorm.gamma -> LayerNorm.weight, etc.) |
| 138 | + 5. Quantizer-specific conversions if provided |
| 139 | +
|
| 140 | + Args: |
| 141 | + model: The model instance to get conversions for |
| 142 | + key_mapping: Optional custom key mapping (source -> target patterns) |
| 143 | + hf_quantizer: Optional HuggingFace quantizer with additional conversions |
| 144 | + add_legacy: Whether to include legacy LayerNorm conversions (default True) |
| 145 | +
|
| 146 | + Returns: |
| 147 | + List of WeightConverter/WeightRenaming objects defining all conversions. |
| 148 | + Returns empty list if transformers is not available. |
| 149 | +
|
| 150 | + Example: |
| 151 | + >>> from transformers import AutoModelForCausalLM |
| 152 | + >>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B") |
| 153 | + >>> conversions = get_model_conversion_mapping(model) |
| 154 | + >>> # Use conversions to transform checkpoint state dict |
| 155 | + """ |
| 156 | + if not _TRANSFORMERS_AVAILABLE: |
| 157 | + return [] |
| 158 | + return _transformers_get_model_conversion_mapping( |
| 159 | + model, |
| 160 | + key_mapping=key_mapping, |
| 161 | + hf_quantizer=hf_quantizer, |
| 162 | + add_legacy=add_legacy, |
| 163 | + ) |
| 164 | + |
| 165 | + |
| 166 | +def get_combined_key_mapping( |
| 167 | + model_type: str, |
| 168 | + model_key_mapping: Optional[dict[str, str]] = None, |
| 169 | +) -> Optional[dict[str, str]]: |
| 170 | + """ |
| 171 | + Get combined key mapping for simple regex-based key renaming. |
| 172 | +
|
| 173 | + This is a simpler alternative to get_model_conversion_mapping that only |
| 174 | + handles key renaming (not tensor operations). Useful when you just need |
| 175 | + to rename keys without merging tensors. |
| 176 | +
|
| 177 | + Note: For MoE models that require tensor merging, use get_model_conversion_mapping |
| 178 | + instead, which returns WeightConverter objects that handle both renaming and merging. |
| 179 | +
|
| 180 | + Args: |
| 181 | + model_type: The model type string from config.model_type |
| 182 | + model_key_mapping: Optional key mapping from the model's |
| 183 | + `_checkpoint_conversion_mapping` attribute |
| 184 | +
|
| 185 | + Returns: |
| 186 | + Combined key mapping dictionary (regex pattern -> replacement), |
| 187 | + or None if no mappings are defined. |
| 188 | + """ |
| 189 | + result = {} |
| 190 | + |
| 191 | + # First add model-specific key mapping (takes precedence) |
| 192 | + if model_key_mapping: |
| 193 | + result.update(model_key_mapping) |
| 194 | + |
| 195 | + # Try to get conversion mapping from transformers and extract simple renamings |
| 196 | + if _TRANSFORMERS_AVAILABLE: |
| 197 | + conversions = get_checkpoint_conversion_mapping(model_type) |
| 198 | + if conversions: |
| 199 | + for conv in conversions: |
| 200 | + # Only extract simple WeightRenaming, not WeightConverter |
| 201 | + if WeightRenaming is not None and isinstance(conv, WeightRenaming): |
| 202 | + # WeightRenaming stores patterns as source_patterns and target_patterns (as lists) |
| 203 | + sources = getattr(conv, "source_patterns", None) |
| 204 | + targets = getattr(conv, "target_patterns", None) |
| 205 | + if sources and targets: |
| 206 | + # Handle both list and string formats |
| 207 | + if isinstance(sources, str): |
| 208 | + sources = [sources] |
| 209 | + if isinstance(targets, str): |
| 210 | + targets = [targets] |
| 211 | + # Add each source->target pair |
| 212 | + for source, target in zip(sources, targets): |
| 213 | + if source not in result: |
| 214 | + result[source] = target |
| 215 | + |
| 216 | + return result if result else None |
| 217 | + |
| 218 | + |
| 219 | +def is_transformers_conversion_available() -> bool: |
| 220 | + """ |
| 221 | + Check if transformers conversion mapping is available. |
| 222 | +
|
| 223 | + Returns: |
| 224 | + True if transformers library with conversion_mapping module is available. |
| 225 | + """ |
| 226 | + return _TRANSFORMERS_AVAILABLE |
0 commit comments