
Commit 63cf465

fix for gptoss
Signed-off-by: h-guo18 <[email protected]>
1 parent cb44c55

File tree

3 files changed: +24 −24 lines changed

modelopt/torch/export/plugins/hf_spec_export.py

Lines changed: 12 additions & 18 deletions
@@ -48,17 +48,18 @@ def _check_state_dict_keys_match(draft_model: nn.Module, required_items: dict):
         raise ValueError(f"State dict keys mismatch!\nMissing in draft model: {required_key}")
 
 
-def rename_and_prune_if_spec_decoding(model: nn.Module, post_state_dict: dict):
+def spec_opt_only(model: nn.Module):
+    """Check if the model has only the speculative decoding optimization."""
+    opt_modes = getattr(model, "_modelopt_state", None)
+    return (
+        isinstance(opt_modes, (list, tuple)) and len(opt_modes) == 1 and opt_modes[0][0] == "eagle"
+    )
+
+
+def export_spec_ckpt_state_dict(model: nn.Module):
     """Only return the state dict of the draft model in official format and ignore the base model."""
     # check the model has only speculative decoding
-    opt_modes = getattr(model, "_modelopt_state", None)
-    if (
-        not isinstance(opt_modes, (list, tuple))
-        or len(opt_modes) != 1
-        or opt_modes[0][0] != "eagle"
-    ):
-        # if there's other opts, return as is
-        return post_state_dict
+    assert spec_opt_only(model), "Not purely eagle model."
 
     # Check if the state dict keys match
     _check_state_dict_keys_match(model.eagle_module, EAGLE_MODELOPT_TO_OFFICIAL["required"])
@@ -80,16 +81,9 @@ def rename_and_prune_if_spec_decoding(model: nn.Module, post_state_dict: dict):
     return export_state_dict
 
 
-def set_config_if_spec_decoding(model: nn.Module, config_data: dict):
+def export_spec_ckpt_config(model: nn.Module):
     """Return the config of the draft model in official format."""
-    opt_modes = getattr(model, "_modelopt_state", None)
-    if (
-        not isinstance(opt_modes, (list, tuple))
-        or len(opt_modes) != 1
-        or opt_modes[0][0] != "eagle"
-    ):
-        # return as is
-        return config_data
+    assert spec_opt_only(model), "Not purely eagle model."
 
     # These are the config keys in the official checkpoint.
     template_config = {
modelopt/torch/export/unified_export_hf.py

Lines changed: 11 additions & 6 deletions
@@ -26,6 +26,7 @@
 
 import torch
 import torch.nn as nn
+from safetensors.torch import save_file
 
 from modelopt.torch.quantization import set_quantizer_by_cfg_context
 from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer
@@ -53,7 +54,7 @@
     QUANTIZATION_W4A8_AWQ,
     QUANTIZATION_W4A8_NVFP4_FP8,
 )
-from .plugins import rename_and_prune_if_spec_decoding, set_config_if_spec_decoding
+from .plugins import export_spec_ckpt_config, export_spec_ckpt_state_dict, spec_opt_only
 from .quant_utils import (
     fuse_prequant_layernorm,
     get_activation_scaling_factor,
@@ -507,18 +508,24 @@ def export_hf_checkpoint(
     """
     export_dir = Path(export_dir)
     export_dir.mkdir(parents=True, exist_ok=True)
+
+    # Early exit for speculative decoding models.
+    # We do this since some spec models hit errors in convert_hf_quant_config_format.
+    if spec_opt_only(model):
+        save_file(export_spec_ckpt_state_dict(model), f"{export_dir}/model.safetensors")
+        with open(f"{export_dir}/config.json", "w") as file:
+            json.dump(export_spec_ckpt_config(model), file, indent=4)
+        return
+
     try:
         post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype)
 
-        # NOTE: (hg) Should we save hf_quant_config when there's no quantization applied?
         # Save hf_quant_config.json for backward compatibility
         with open(f"{export_dir}/hf_quant_config.json", "w") as file:
             json.dump(hf_quant_config, file, indent=4)
 
         hf_quant_config = convert_hf_quant_config_format(hf_quant_config)
 
-        post_state_dict = rename_and_prune_if_spec_decoding(model, post_state_dict)
-
         # Save model
         model.save_pretrained(
             export_dir, state_dict=post_state_dict, save_modelopt_state=save_modelopt_state
@@ -532,8 +539,6 @@ def export_hf_checkpoint(
 
         config_data["quantization_config"] = hf_quant_config
 
-        config_data = set_config_if_spec_decoding(model, config_data)
-
         with open(original_config, "w") as file:
             json.dump(config_data, file, indent=4)
 
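
Taken together, these changes make export_hf_checkpoint short-circuit for eagle-only models instead of patching the state dict and config after the quantized-export path. A minimal usage sketch, assuming export_hf_checkpoint is re-exported from modelopt.torch.export and the caller already holds an eagle-converted model:

from pathlib import Path

from modelopt.torch.export import export_hf_checkpoint


def export_eagle_only(model, export_dir: Path) -> None:
    """Sketch: export a purely eagle-converted model via the new early exit.

    When spec_opt_only(model) is True, this writes exactly two files:
        <export_dir>/model.safetensors  -- draft-model weights only
        <export_dir>/config.json        -- draft-model config in official format
    Any other model continues through the original save_pretrained path.
    """
    export_hf_checkpoint(model, export_dir=export_dir)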

modelopt/torch/speculative/eagle/default_config.py

Lines changed: 1 addition & 0 deletions
@@ -47,4 +47,5 @@
     "use_mtp_layernorm": False,
     "parallel_draft_step": 1,
     "has_lm_head": False,
+    "head_dim": 64,
 }
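
Why the new default matters: gpt-oss sets head_dim explicitly in its config (64 for gpt-oss-20b, whose hidden_size of 2880 over 64 attention heads would otherwise derive 45), so a draft-model config missing the key could silently pick up a wrong derived value. A sketch of that resolution order; resolve_head_dim is a hypothetical helper for illustration, not ModelOpt API:

def resolve_head_dim(config: dict) -> int:
    """Prefer an explicit head_dim; otherwise fall back to the derived value."""
    if config.get("head_dim") is not None:
        return config["head_dim"]
    return config["hidden_size"] // config["num_attention_heads"]


# gpt-oss-style config: the derived value (2880 // 64 == 45) differs from the
# real head dimension, so an explicit default of 64 avoids a wrong fallback.
assert resolve_head_dim({"hidden_size": 2880, "num_attention_heads": 64, "head_dim": 64}) == 64
assert resolve_head_dim({"hidden_size": 4096, "num_attention_heads": 32}) == 128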
