2 changes: 1 addition & 1 deletion examples/speculative_decoding/README.md
@@ -115,7 +115,7 @@ After training draft model, we can evaluate the saved modelopt checkpoint on MT-
 python ar_validate.py --model_path $OUTPUT_DIR
 ```
 
-Alternatively, we can export the checkpoint and run evaluation on serving frameworks. See sections below.
+**Note**: In-framework evaluation is supported only for online training. For offline training checkpoints, export the model and evaluate it with a serving framework.
 
 ## Export
 
30 changes: 12 additions & 18 deletions modelopt/torch/export/plugins/hf_spec_export.py
@@ -48,17 +48,18 @@ def _check_state_dict_keys_match(draft_model: nn.Module, required_items: dict):
     raise ValueError(f"State dict keys mismatch!\nMissing in draft model: {required_key}")
 
 
-def rename_and_prune_if_spec_decoding(model: nn.Module, post_state_dict: dict):
+def spec_opt_only(model: nn.Module):
+    """Check if the model has only the speculative decoding optimization."""
+    opt_modes = getattr(model, "_modelopt_state", None)
+    return (
+        isinstance(opt_modes, (list, tuple)) and len(opt_modes) == 1 and opt_modes[0][0] == "eagle"
+    )
+
+
+def export_spec_ckpt_state_dict(model: nn.Module):
     """Only return the state dict of the draft model in official format and ignore the base model."""
-    # check the model has only speculative decoding
-    opt_modes = getattr(model, "_modelopt_state", None)
-    if (
-        not isinstance(opt_modes, (list, tuple))
-        or len(opt_modes) != 1
-        or opt_modes[0][0] != "eagle"
-    ):
-        # if there's other opts, return as is
-        return post_state_dict
+    assert spec_opt_only(model), "Not purely eagle model."
 
     # Check if the state dict keys match
     _check_state_dict_keys_match(model.eagle_module, EAGLE_MODELOPT_TO_OFFICIAL["required"])
@@ -80,16 +81,9 @@ def rename_and_prune_if_spec_decoding(model: nn.Module, post_state_dict: dict):
     return export_state_dict
 
 
-def set_config_if_spec_decoding(model: nn.Module, config_data: dict):
+def export_spec_ckpt_config(model: nn.Module):
     """Return the config of draft model in official format."""
-    opt_modes = getattr(model, "_modelopt_state", None)
-    if (
-        not isinstance(opt_modes, (list, tuple))
-        or len(opt_modes) != 1
-        or opt_modes[0][0] != "eagle"
-    ):
-        # return as is
-        return config_data
+    assert spec_opt_only(model), "Not purely eagle model."
 
     # This is the config keys in official checkpoint.
     template_config = {
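For context, `spec_opt_only` keys off the `_modelopt_state` metadata that modelopt attaches to a model during conversion. A minimal sketch of the gate it implements, using a toy module (on real models `_modelopt_state` is set by modelopt's conversion APIs, not by hand):

```python
import torch.nn as nn

from modelopt.torch.export.plugins import spec_opt_only

# Toy stand-in; a real model receives _modelopt_state from modelopt's convert().
model = nn.Linear(8, 8)

model._modelopt_state = [("eagle", {})]
assert spec_opt_only(model)  # a single "eagle" entry -> pure EAGLE model

model._modelopt_state = [("quantize", {}), ("eagle", {})]
assert not spec_opt_only(model)  # mixed optimizations -> gate is False
```

With this helper, the two export functions can assert on pure-EAGLE models instead of silently returning their inputs unchanged, which is what the removed early-return branches did.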
17 changes: 11 additions & 6 deletions modelopt/torch/export/unified_export_hf.py
@@ -26,6 +26,7 @@

 import torch
 import torch.nn as nn
+from safetensors.torch import save_file
 
 from modelopt.torch.quantization import set_quantizer_by_cfg_context
 from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer
@@ -53,7 +54,7 @@
     QUANTIZATION_W4A8_AWQ,
     QUANTIZATION_W4A8_NVFP4_FP8,
 )
-from .plugins import rename_and_prune_if_spec_decoding, set_config_if_spec_decoding
+from .plugins import export_spec_ckpt_config, export_spec_ckpt_state_dict, spec_opt_only
 from .quant_utils import (
     fuse_prequant_layernorm,
     get_activation_scaling_factor,
@@ -507,18 +508,24 @@ def export_hf_checkpoint(
     """
     export_dir = Path(export_dir)
     export_dir.mkdir(parents=True, exist_ok=True)
 
+    # Early exit for speculative decoding models.
+    # We do this since some spec models hit errors in convert_hf_quant_config_format.
+    if spec_opt_only(model):
+        save_file(export_spec_ckpt_state_dict(model), f"{export_dir}/model.safetensors")
+        with open(f"{export_dir}/config.json", "w") as file:
+            json.dump(export_spec_ckpt_config(model), file, indent=4)
+        return
+
     try:
         post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype)
 
-        # NOTE: (hg) Should we save hf_quant_config when there's no quantization applied?
+        # Save hf_quant_config.json for backward compatibility
         with open(f"{export_dir}/hf_quant_config.json", "w") as file:
             json.dump(hf_quant_config, file, indent=4)
 
         hf_quant_config = convert_hf_quant_config_format(hf_quant_config)
 
-        post_state_dict = rename_and_prune_if_spec_decoding(model, post_state_dict)
-
         # Save model
         model.save_pretrained(
             export_dir, state_dict=post_state_dict, save_modelopt_state=save_modelopt_state
@@ -532,8 +539,6 @@

         config_data["quantization_config"] = hf_quant_config
 
-        config_data = set_config_if_spec_decoding(model, config_data)
-
         with open(original_config, "w") as file:
             json.dump(config_data, file, indent=4)

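With the early exit in place, a pure-EAGLE checkpoint is written directly as `model.safetensors` plus `config.json`, bypassing `convert_hf_quant_config_format` (the source of the errors noted in the comment). A hedged usage sketch, assuming `model` is an EAGLE-converted model such as the one trained in the speculative_decoding example:

```python
from pathlib import Path

from modelopt.torch.export import export_hf_checkpoint

export_dir = Path("exported_eagle")
# A model whose only modelopt optimization is "eagle" takes the early-exit
# path above; everything else goes through the usual quantized-export flow.
export_hf_checkpoint(model, export_dir=export_dir)

assert (export_dir / "model.safetensors").exists()  # draft weights, official naming
assert (export_dir / "config.json").exists()  # draft config, official format
```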
1 change: 1 addition & 0 deletions modelopt/torch/speculative/eagle/default_config.py
@@ -47,4 +47,5 @@
"use_mtp_layernorm": False,
"parallel_draft_step": 1,
"has_lm_head": False,
"head_dim": 128,
}
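The new `head_dim` entry makes the draft model's per-head dimension an explicit config key rather than something derived downstream. For most Llama-style bases it equals `hidden_size / num_attention_heads`, but the two can diverge, which is why an explicit default is useful; the numbers below are illustrative assumptions, not modelopt defaults:

```python
# Illustrative values only; not taken from modelopt's shipped configs.
hidden_size = 4096
num_attention_heads = 32

# Conventional derivation. An explicit "head_dim" key lets a config override
# it for models where the identity does not hold (e.g. Gemma-7B uses
# head_dim=256 with hidden_size=3072 and 16 heads).
derived_head_dim = hidden_size // num_attention_heads
assert derived_head_dim == 128  # matches the new default above
```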