2 changes: 1 addition & 1 deletion examples/speculative_decoding/README.md
@@ -115,7 +115,7 @@ After training draft model, we can evaluate the saved modelopt checkpoint on MT-
python ar_validate.py --model_path $OUTPUT_DIR
```

Alternatively, we can export the checkpoint and run evaluation on serving frameworks. See sections below.
**Note**: In-framework evaluation is supported only for online training. For checkpoints produced by offline training, export the model and evaluate it with a serving framework instead.

## Export

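To make the new note concrete: for an offline-trained checkpoint, the intended path is the export flow this PR reworks. A minimal sketch, assuming the public `export_hf_checkpoint` API and a model restored with its modelopt speculative-decoding state (the export directory is illustrative):

```python
# Hedged sketch: export an offline-trained EAGLE checkpoint for serving-side evaluation.
from modelopt.torch.export import export_hf_checkpoint

# `model` is assumed to be a modelopt-converted model with its eagle state restored.
export_hf_checkpoint(model, export_dir="export/eagle_draft")
# The exported model.safetensors + config.json can then be loaded by a serving
# framework for evaluation, since ar_validate.py only covers online training.
```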
30 changes: 12 additions & 18 deletions modelopt/torch/export/plugins/hf_spec_export.py
@@ -48,17 +48,18 @@ def _check_state_dict_keys_match(draft_model: nn.Module, required_items: dict):
raise ValueError(f"State dict keys mismatch!\nMissing in draft model: {required_key}")


def rename_and_prune_if_spec_decoding(model: nn.Module, post_state_dict: dict):
def spec_opt_only(model: nn.Module):
"""Check if the model have only speculative decoding optimization."""
opt_modes = getattr(model, "_modelopt_state", None)
return (
isinstance(opt_modes, (list, tuple)) and len(opt_modes) == 1 and opt_modes[0][0] == "eagle"
)


def export_spec_ckpt_state_dict(model: nn.Module):
"""Only return the state dict of the draft model in official format and ignore the base model."""
# check the model has only speculative decoding
opt_modes = getattr(model, "_modelopt_state", None)
if (
not isinstance(opt_modes, (list, tuple))
or len(opt_modes) != 1
or opt_modes[0][0] != "eagle"
):
# if there's other opts, return as is
return post_state_dict
assert spec_opt_only(model), "Model is not a pure EAGLE model."

# Check if the state dict keys match
_check_state_dict_keys_match(model.eagle_module, EAGLE_MODELOPT_TO_OFFICIAL["required"])
@@ -80,16 +80,9 @@ def rename_and_prune_if_spec_decoding(model: nn.Module, post_state_dict: dict):
return export_state_dict


def set_config_if_spec_decoding(model: nn.Module, config_data: dict):
def export_spec_ckpt_config(model: nn.Module):
"""Return the config of draft model in official format."""
opt_modes = getattr(model, "_modelopt_state", None)
if (
not isinstance(opt_modes, (list, tuple))
or len(opt_modes) != 1
or opt_modes[0][0] != "eagle"
):
# return as is
return config_data
assert spec_opt_only(model), "Model is not a pure EAGLE model."

# This is the config keys in official checkpoint.
template_config = {
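For reviewers: the refactor keys everything off `_modelopt_state`, a sequence of `(mode_name, config)` pairs that modelopt attaches to converted models. A minimal sketch of the contract `spec_opt_only` enforces (the import path follows the plugins re-export in this PR; the config payloads are illustrative):

```python
import torch.nn as nn

# Assumed import path, based on the re-export added in unified_export_hf.py.
from modelopt.torch.export.plugins import spec_opt_only

m = nn.Module()
m._modelopt_state = [("eagle", {"eagle_num_layers": 1})]  # illustrative payload
assert spec_opt_only(m)  # exactly one applied mode, and it is "eagle"

m._modelopt_state.append(("quantize", {}))
assert not spec_opt_only(m)  # mixed optimizations are no longer "purely eagle"
```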
17 changes: 11 additions & 6 deletions modelopt/torch/export/unified_export_hf.py
@@ -26,6 +26,7 @@

import torch
import torch.nn as nn
from safetensors.torch import save_file

from modelopt.torch.quantization import set_quantizer_by_cfg_context
from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer
@@ -53,7 +54,7 @@
QUANTIZATION_W4A8_AWQ,
QUANTIZATION_W4A8_NVFP4_FP8,
)
from .plugins import rename_and_prune_if_spec_decoding, set_config_if_spec_decoding
from .plugins import export_spec_ckpt_config, export_spec_ckpt_state_dict, spec_opt_only
from .quant_utils import (
fuse_prequant_layernorm,
get_activation_scaling_factor,
@@ -507,18 +508,24 @@ def export_hf_checkpoint(
"""
export_dir = Path(export_dir)
export_dir.mkdir(parents=True, exist_ok=True)

# Early exit for speculative decoding models
# We do this because some spec models hit errors in convert_hf_quant_config_format
if spec_opt_only(model):
save_file(export_spec_ckpt_state_dict(model), f"{export_dir}/model.safetensors")
with open(f"{export_dir}/config.json", "w") as file:
json.dump(export_spec_ckpt_config(model), file, indent=4)
return

try:
post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype)

# NOTE: (hg) Should we save hf_quant_config when there's no quantization applied?
# Save hf_quant_config.json for backward compatibility
with open(f"{export_dir}/hf_quant_config.json", "w") as file:
json.dump(hf_quant_config, file, indent=4)

hf_quant_config = convert_hf_quant_config_format(hf_quant_config)

post_state_dict = rename_and_prune_if_spec_decoding(model, post_state_dict)

# Save model
model.save_pretrained(
export_dir, state_dict=post_state_dict, save_modelopt_state=save_modelopt_state
@@ -532,8 +539,6 @@

config_data["quantization_config"] = hf_quant_config

config_data = set_config_if_spec_decoding(model, config_data)

with open(original_config, "w") as file:
json.dump(config_data, file, indent=4)

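The early exit above means a pure-EAGLE export writes exactly two files, `model.safetensors` and `config.json`, and skips the quantization-config path entirely. A minimal sketch of consuming such an export (file names come from the diff; the directory is illustrative):

```python
import json
from safetensors.torch import load_file

# Load the draft-model weights written by export_spec_ckpt_state_dict.
state_dict = load_file("export/eagle_draft/model.safetensors")

# Load the official-format draft config written by export_spec_ckpt_config.
with open("export/eagle_draft/config.json") as f:
    config = json.load(f)

print(f"{len(state_dict)} draft tensors, config keys: {sorted(config)[:5]}")
```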
1 change: 1 addition & 0 deletions modelopt/torch/speculative/eagle/default_config.py
@@ -47,4 +47,5 @@
"use_mtp_layernorm": False,
"parallel_draft_step": 1,
"has_lm_head": False,
"head_dim": 64,
}
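The one-line change above gives `head_dim` an explicit default of 64 rather than leaving it unset. A hedged sketch of how such a default is typically consumed, assuming the usual dict-merge pattern for user overrides (an assumption, not code from this PR):

```python
# Subset of the eagle defaults, including the new key.
default_config = {
    "use_mtp_layernorm": False,
    "parallel_draft_step": 1,
    "has_lm_head": False,
    "head_dim": 64,  # commonly hidden_size // num_attention_heads in the base model
}

# Assumption: user-supplied values take precedence over the defaults.
user_config = {"head_dim": 128}
eagle_config = {**default_config, **user_config}
assert eagle_config["head_dim"] == 128
```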