diff --git a/examples/speculative_decoding/README.md b/examples/speculative_decoding/README.md
index aeea25adb..d25d54bbb 100644
--- a/examples/speculative_decoding/README.md
+++ b/examples/speculative_decoding/README.md
@@ -115,7 +115,7 @@ After training draft model, we can evaluate the saved modelopt checkpoint on MT-
 python ar_validate.py --model_path $OUTPUT_DIR
 ```

-Alternatively, we can export the checkpoint and run evaluation on serving frameworks. See sections below.
+**Note**: In-framework evaluation is supported only for online training. For offline training checkpoints, export the model and evaluate it with a serving framework. See sections below.

 ## Export

diff --git a/modelopt/torch/export/plugins/hf_spec_export.py b/modelopt/torch/export/plugins/hf_spec_export.py
index fe044828a..9f89cb269 100644
--- a/modelopt/torch/export/plugins/hf_spec_export.py
+++ b/modelopt/torch/export/plugins/hf_spec_export.py
@@ -48,17 +48,18 @@ def _check_state_dict_keys_match(draft_model: nn.Module, required_items: dict):
         raise ValueError(f"State dict keys mismatch!\nMissing in draft model: {required_key}")


-def rename_and_prune_if_spec_decoding(model: nn.Module, post_state_dict: dict):
+def spec_opt_only(model: nn.Module):
+    """Check if the model has only the speculative decoding optimization."""
+    opt_modes = getattr(model, "_modelopt_state", None)
+    return (
+        isinstance(opt_modes, (list, tuple)) and len(opt_modes) == 1 and opt_modes[0][0] == "eagle"
+    )
+
+
+def export_spec_ckpt_state_dict(model: nn.Module):
     """Only return the state dict of the draft model in official format and ignore the base model."""
     # check the model has only speculative decoding
-    opt_modes = getattr(model, "_modelopt_state", None)
-    if (
-        not isinstance(opt_modes, (list, tuple))
-        or len(opt_modes) != 1
-        or opt_modes[0][0] != "eagle"
-    ):
-        # if there's other opts, return as is
-        return post_state_dict
+    assert spec_opt_only(model), "Model is not EAGLE-only."

     # Check if the state dict keys match
     _check_state_dict_keys_match(model.eagle_module, EAGLE_MODELOPT_TO_OFFICIAL["required"])
@@ -80,16 +81,9 @@ def rename_and_prune_if_spec_decoding(model: nn.Module, post_state_dict: dict):
     return export_state_dict


-def set_config_if_spec_decoding(model: nn.Module, config_data: dict):
+def export_spec_ckpt_config(model: nn.Module):
     """Return the config of draft model in official format."""
-    opt_modes = getattr(model, "_modelopt_state", None)
-    if (
-        not isinstance(opt_modes, (list, tuple))
-        or len(opt_modes) != 1
-        or opt_modes[0][0] != "eagle"
-    ):
-        # return as is
-        return config_data
+    assert spec_opt_only(model), "Model is not EAGLE-only."

     # This is the config keys in official checkpoint.
     template_config = {
diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py
index f514e660d..21c2f0e24 100644
--- a/modelopt/torch/export/unified_export_hf.py
+++ b/modelopt/torch/export/unified_export_hf.py
@@ -26,6 +26,7 @@
 import torch
 import torch.nn as nn
+from safetensors.torch import save_file

 from modelopt.torch.quantization import set_quantizer_by_cfg_context
 from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer
@@ -53,7 +54,7 @@
     QUANTIZATION_W4A8_AWQ,
     QUANTIZATION_W4A8_NVFP4_FP8,
 )
-from .plugins import rename_and_prune_if_spec_decoding, set_config_if_spec_decoding
+from .plugins import export_spec_ckpt_config, export_spec_ckpt_state_dict, spec_opt_only
 from .quant_utils import (
     fuse_prequant_layernorm,
     get_activation_scaling_factor,
@@ -507,18 +508,24 @@ def export_hf_checkpoint(
     """
     export_dir = Path(export_dir)
     export_dir.mkdir(parents=True, exist_ok=True)
+
+    # Early exit for speculative decoding models.
+    # We do this because some spec models error out in convert_hf_quant_config_format.
+    if spec_opt_only(model):
+        save_file(export_spec_ckpt_state_dict(model), f"{export_dir}/model.safetensors")
+        with open(f"{export_dir}/config.json", "w") as file:
+            json.dump(export_spec_ckpt_config(model), file, indent=4)
+        return
+
     try:
         post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype)

-        # NOTE: (hg) Should we save hf_quant_config when there's no quantization applied?
         # Save hf_quant_config.json for backward compatibility
         with open(f"{export_dir}/hf_quant_config.json", "w") as file:
             json.dump(hf_quant_config, file, indent=4)

         hf_quant_config = convert_hf_quant_config_format(hf_quant_config)

-        post_state_dict = rename_and_prune_if_spec_decoding(model, post_state_dict)
-
         # Save model
         model.save_pretrained(
             export_dir, state_dict=post_state_dict, save_modelopt_state=save_modelopt_state
@@ -532,8 +539,6 @@

         config_data["quantization_config"] = hf_quant_config

-        config_data = set_config_if_spec_decoding(model, config_data)
-
         with open(original_config, "w") as file:
             json.dump(config_data, file, indent=4)

diff --git a/modelopt/torch/speculative/eagle/default_config.py b/modelopt/torch/speculative/eagle/default_config.py
index f8c69b2ff..1a7f1fddc 100644
--- a/modelopt/torch/speculative/eagle/default_config.py
+++ b/modelopt/torch/speculative/eagle/default_config.py
@@ -47,4 +47,5 @@
     "use_mtp_layernorm": False,
     "parallel_draft_step": 1,
     "has_lm_head": False,
+    "head_dim": 128,
 }
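
For reviewers, here is a minimal sketch of the gate that `export_hf_checkpoint` now relies on: `spec_opt_only` returns `True` only when `_modelopt_state` records exactly one mode and that mode is `"eagle"`. The toy module and hand-built `_modelopt_state` below are illustrative stand-ins; in practice the state is populated by modelopt's convert/restore machinery, not set by hand.

```python
# Illustrative only: fake the (mode_name, config) tuple layout that
# spec_opt_only() inspects on _modelopt_state.
import torch.nn as nn

from modelopt.torch.export.plugins.hf_spec_export import spec_opt_only

m = nn.Linear(8, 8)

m._modelopt_state = [("eagle", {})]  # exactly one mode, and it is "eagle"
assert spec_opt_only(m)

m._modelopt_state = [("quantize", {}), ("eagle", {})]  # mixed optimizations
assert not spec_opt_only(m)

del m._modelopt_state  # no modelopt state at all -> getattr returns None
assert not spec_opt_only(m)
```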
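And a hedged sanity check of the new early-exit path: per the diff, an EAGLE-only model is written out as `model.safetensors` plus `config.json` directly under `export_dir`, and `hf_quant_config.json` is never created because the function returns before the quantization branch. The export directory name and the `head_dim` probe below are assumptions for illustration, not part of the change.

```python
# Sanity-check sketch for an EAGLE-only export (paths are placeholders).
# Assumes export_hf_checkpoint(...) already ran on an EAGLE-only model and
# took the spec_opt_only() early-exit branch.
import json
import os

from safetensors.torch import load_file

export_dir = "export/eagle-draft"  # hypothetical path

# The early exit writes exactly these two files; hf_quant_config.json is skipped.
assert os.path.exists(f"{export_dir}/model.safetensors")
assert os.path.exists(f"{export_dir}/config.json")
assert not os.path.exists(f"{export_dir}/hf_quant_config.json")

state_dict = load_file(f"{export_dir}/model.safetensors")
print(sorted(state_dict))  # draft-model weights only, in official EAGLE naming

with open(f"{export_dir}/config.json") as f:
    config = json.load(f)
# "head_dim": 128 is the new eagle default; whether it surfaces in the exported
# config depends on the official template_config, so treat this as a probe.
print(config.get("head_dim"))
```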