Skip to content

Commit f9066c1

Browse files
committed
fix router type
Signed-off-by: jenchen13 <[email protected]>
1 parent 5f7158f commit f9066c1

File tree

2 files changed

+17
-6
lines changed

2 files changed

+17
-6
lines changed

modelopt/torch/export/plugins/megatron_importer.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,11 @@ def __init__(
8383
self._hf_config = transformers.AutoConfig.from_pretrained(
8484
pretrained_model_name_or_path, trust_remote_code=trust_remote_code
8585
)
86-
self.moe_router_dtype = moe_router_dtype
86+
self.moe_router_dtype = None
87+
if moe_router_dtype == "fp32":
88+
self.moe_router_dtype = torch.float32
89+
elif moe_router_dtype == "fp64":
90+
self.moe_router_dtype = torch.float64
8791

8892
pretrained_model_path = Path(pretrained_model_name_or_path)
8993
if not pretrained_model_path.is_dir():
@@ -145,6 +149,8 @@ def _name_remapping(
145149
parallel_config: ParallelConfig | None = None,
146150
dtype: torch.dtype | None = None,
147151
):
152+
if dtype is None:
153+
dtype = self.dtype
148154
if isinstance(module, torch.Tensor):
149155
tensor = self._get_safetensor(prefix, parallel_config=parallel_config)
150156
module.data.copy_(tensor)
@@ -197,7 +203,7 @@ def _name_remapping(
197203
tensor = self._get_safetensor(
198204
prefix + source_key, parallel_config=parallel_config
199205
)
200-
state_dict[key] = tensor.to(dtype=self.dtype).to(device=val.device)
206+
state_dict[key] = tensor.to(dtype=dtype).to(device=val.device)
201207

202208
module.load_state_dict(state_dict)
203209

modelopt/torch/export/unified_export_megatron.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def get_kv_cache_scaling_factor(kv_module: nn.Module) -> torch.Tensor:
109109

110110
def get_quantized_state(
111111
module: torch.nn.Module,
112-
dtype: torch.dtype = torch.float16,
112+
dtype: torch.dtype = torch.bfloat16,
113113
) -> tuple[dict[str, torch.Tensor], str, int]:
114114
"""Return a state_dict, quantization format, and block_size of the module.
115115
@@ -197,7 +197,12 @@ def __init__(
197197
self._hf_config = transformers.AutoConfig.from_pretrained(
198198
pretrained_model_name_or_path, trust_remote_code=trust_remote_code
199199
)
200-
self.moe_router_dtype = moe_router_dtype
200+
self.moe_router_dtype = None
201+
if moe_router_dtype == "fp32":
202+
self.moe_router_dtype = torch.float32
203+
elif moe_router_dtype == "fp64":
204+
self.moe_router_dtype = torch.float64
205+
201206
# If multimodal, extract the text_config
202207
self._hf_text_config = getattr(self._hf_config, "text_config", self._hf_config)
203208

@@ -1142,7 +1147,7 @@ def export_mcore_gpt_to_hf(
11421147
model: torch.nn.Module,
11431148
pretrained_model_name_or_path: str | os.PathLike | None = None,
11441149
export_extra_modules: bool = False,
1145-
dtype: torch.dtype = torch.float16,
1150+
dtype: torch.dtype = torch.bfloat16,
11461151
export_dir: Path | str = tempfile.gettempdir(),
11471152
moe_router_dtype: torch.dtype | None = None,
11481153
):
@@ -1169,7 +1174,7 @@ def import_mcore_gpt_from_hf(
11691174
model: torch.nn.Module,
11701175
pretrained_model_path: str,
11711176
workspace_dir: str | None = None,
1172-
dtype: torch.dtype = torch.float16,
1177+
dtype: torch.dtype = torch.bfloat16,
11731178
moe_router_dtype: torch.dtype | None = None,
11741179
):
11751180
"""Import GPTModel state_dict from supported HuggingFace pretrained model path.

0 commit comments

Comments
 (0)