From fdbde136fa94764b19aa54048478fdeec2b2814f Mon Sep 17 00:00:00 2001
From: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Date: Fri, 5 Sep 2025 18:21:52 +0000
Subject: [PATCH 1/7] feat: update eagle3 example; add export
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
---
examples/speculative_decoding/README.md | 276 +++++++++---------
examples/speculative_decoding/ar_validate.py | 8 +-
.../calibrate_draft_vocab.py | 15 +-
.../speculative_decoding/eagle_config.json | 9 +-
.../export_hf_checkpoint.py | 41 +++
examples/speculative_decoding/main.py | 11 +
.../speculative_decoding/server_generate.py | 2 +-
.../train_eagle3_and_export.sh | 72 +++++
modelopt/torch/export/plugins/__init__.py | 2 +
.../torch/export/plugins/hf_spec_export.py | 151 ++++++++++
modelopt/torch/export/unified_export_hf.py | 24 +-
.../torch/speculative/plugins/transformers.py | 99 ++++---
12 files changed, 519 insertions(+), 191 deletions(-)
create mode 100644 examples/speculative_decoding/export_hf_checkpoint.py
create mode 100644 examples/speculative_decoding/train_eagle3_and_export.sh
create mode 100644 modelopt/torch/export/plugins/hf_spec_export.py
mode change 100755 => 100644 modelopt/torch/export/unified_export_hf.py
diff --git a/examples/speculative_decoding/README.md b/examples/speculative_decoding/README.md
index 187bda560..910077a58 100644
--- a/examples/speculative_decoding/README.md
+++ b/examples/speculative_decoding/README.md
@@ -1,136 +1,142 @@
# Speculative Decoding
-Large Language Models (LLMs) have demonstrated remarkable capabilities and are increasingly applied in various domains. However, their text generation process is costly and slow. This inefficiency is attributed to the nature of auto-regressive decoding: each token generation necessitates a forward pass, requiring access to the entire parameter set of the LLM. This results in a memory-bound limitation for auto-regressive decoding. To accelerate auto-regressive decoding, speculative decoding methods use a draft model (either a smaller model or the LLM itself) to guess the next γ tokens through standard auto-regressive generation. Subsequently, the original LLM validates these guessed tokens, necessitating only a single forward pass for verification. If the draft model accurately predicts α tokens, a single forward pass of the original LLM can generate α+1 tokens.
+[](https://nvidia.github.io/TensorRT-Model-Optimizer/guides/5_speculative_decoding.html)
-This section focuses on the end-to-end workflow of training speculative decoding modules to deploy for your model.
+Speculative decoding accelerates auto-regressive generation in large language models (LLMs) by leveraging a lightweight draft model to predict the next γ tokens. The main LLM then verifies these candidate tokens in a single forward pass. If the draft model correctly predicts α tokens, the LLM can accept and generate α+1 tokens per verification step, significantly improving throughput.
-In this example, the end-to-end workflow of speculative decoding is demonstrated for a pretrained HF text generation model.
+This folder contains end-to-end runnable speculative decoding fine-tuning pipeline where Llama3.2-1B from huggingface is trained on Daring-Anteater dataset.
+
+This example focus on training with HF. To train with Megatron-LM, please refer to [this link](https://github.com/NVIDIA/Megatron-LM/tree/main/examples/post_training/modelopt) in Megatron-LM repo.
+
+## Contents
-| **Section** | **Description** | **Link** | **Docs** |
-| :------------: | :------------: | :------------: | :------------: |
-| Pre-Requisites | Required & optional packages to use this technique | \[[Link](#pre-requisites)\] | |
-| Getting Started | Learn how to optimize your models using PTQ to reduce precision and improve inference efficiency | \[[Link](#getting-started)\] | \[[docs](https://nvidia.github.io/TensorRT-Model-Optimizer/guides/5_speculative_decoding.html)\] |
-| Support Matrix | View the support matrix to see speculation technique support | \[[Link](#support-matrix)\] | |
-| End to End | Example scripts demonstrating how to train speculation modules using Hugging Face / NeMo / Megatron-LM models | \[[Link](#end-to-end-speculative-decoding-examples)\] | |
-| Deployment | Next steps after speculation module is trained | \[[Link](#deployment)\] | |
-| Speculation Module Checkpoints | View pre-trained speculation modules ready to deploy! | \[[Link](#speculation-module-checkpoints)\] | \[[docs](https://nvidia.github.io/TensorRT-Model-Optimizer/guides/1_quantization.html)\] |
-| Resources | Extra links to relevant resources | \[[Link](#resources)\] | |
+| **Section** | **Description** | **Jump To** |
+| :------------: | :------------: | :------------: |
+| Pre-Requisites | Required & optional dependencies | \[[Link](#pre-requisites)\] |
+| Simplified Workflow | Train, evaluate and export eagle model with one-line command | \[[Link](#getting-started-simplified-workflow)\] |
+| Complete Workflow | Full example with configurable traininig pipeline | \[[Link](#support-matrix)\] |
+| Support Matrix | Supported models for speculative decoding training | \[[Link](#deployment)\] |
+| Speculation Module Checkpoints | View pre-trained speculation modules ready to deploy! | \[[Link](#speculation-module-checkpoints)\] |
+| Resources | Extra links to relevant resources | \[[Link](#resources)\] |
## Pre-Requisites
-### HF
-
-Install Model Optimizer with `hf` dependencies using `pip` from [PyPI](https://pypi.org/project/nvidia-modelopt/) and install the requirements for the example:
+Install ModelOpt with `hf` dependencies and the other requirements for this example:
```bash
pip install nvidia-modelopt[hf]
pip install -r requirements.txt
```
-### NeMo / Megatron-LM
-
-Use the NeMo container `nvcr.io/nvidia/nemo:25.07` or later which has all the dependencies installed.
-
-## Getting Started
+We use the [Daring-Anteater](https://huggingface.co/datasets/nvidia/Daring-Anteater) dataset in this example. Download it with:
-### Prepare Data
-
-In speculative decoding fine-tuning, extra speculative decoding module, like Medusa heads or EAGLE module, are added to the base model to predict the next γ tokens. These tokens will then be validated by the original LLM. In order for these predicted tokens to be accepted by the original LLM, their prediction distributions should be similar to that of the base model. Therefore, we need to prepare fine-tuning data generated from the original LLM. Start by launching an inference server that will run the base model. Let us use TinyLlama/TinyLlama-1.1B-Chat-v1.0 as an example.
+```bash
+git clone https://huggingface.co/datasets/nvidia/Daring-Anteater
+```
-First, set up a vllm server with TinyLlama. Make sure to use a different docker container other than the one for training as installing vllm may cause version conflicts with modelopt. Note: for quantized models by ModelOpt, you need to add --quantization=modelopt flag.
+## Getting Started: Simplified Workflow
-```sh
-pip install vllm
-vllm serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 --api-key token-abc123 --port 8000 --tensor-parallel-size 1
+```bash
+bash train_eagle3_and_export.sh --base_model meta-llama/Llama-3.2-1B-Instruct --num_gpu 4
```
-Then, we adapt the fine-tuning data by calling this server. In this example, we use Daring-Anteater dataset.
+This one-line command runs a minimal example workflow of training and exporting an EAGLE draft model in Modelopt. Specifically, it
-```sh
-git clone https://huggingface.co/datasets/nvidia/Daring-Anteater
-python3 server_generate.py --data_path Daring-Anteater/train.jsonl --output_path finetune/data.jsonl --max_token 512 --chat
-```
+- Initializes the draft model with [default settings](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/modelopt/torch/speculative/eagle/default_config.py#L18)
+- Fine-tunes the model on the [Daring-Anteater](https://huggingface.co/datasets/nvidia/Daring-Anteater) dataset
+- Evaluates the acceptance rate on [MT-Bench](https://huggingface.co/datasets/HuggingFaceH4/mt_bench_prompts)
+- Exports a checkpoint ready for deployment
+- Run an interactive inference demo with vLLM using provided prompt
-To add a system prompt, use the `--system_prompt` argument:
+## Complete Workflow
-```sh
-python3 server_generate.py --data_path Daring-Anteater/train.jsonl --output_path finetune/data.jsonl --max_token 512 --chat --system_prompt
-```
+This section presents a more comprehensive example of customizing speculative decoding training with ModelOpt, including optional steps to improve training quality and efficiency.
-#### SLURM Prepare Data
+### (Optional) Data Synthesis
-For basic parallelization of synthetic data generation we provide some SLURM support.
-Assuming a `$SLURM_JOB_ID` is present and nodes, n1, n2, n3, n4 are selected the following is achievable.
+To achieve higher acceptance rates during speculative decoding, it is beneficial to use conversations generated by the base model as training data, ensuring that the draft model’s output distribution closely aligns with that of the base model.
-Example of allocating 4 nodes for 120 minutes
+To prepare such data, we launch an inference server with the base model:
-```sh
-salloc -N4 -A -p -J -synthetic:data-gen -t 120
+```bash
+pip install vllm
+vllm serve meta-llama/Llama-3.2-1B-Instruct --api-key token-abc123 --port 8000 --tensor-parallel-size 1
```
-Create shards of some given size
+Note: add the `--quantization=modelopt` flag for models quantized with ModelOpt.
-```sh
-python3 distributed_generate/sharding_utils.py --input_path /data/train.jsonl --output_dir /data/train/ --max_lines_per_shard 10000
+Then, we generate conversations with base model and prompts from Daring-Anteater:
+
+```bash
+python server_generate.py --data_path Daring-Anteater/train.jsonl --output_path synthetic/train.jsonl
```
-Run workers on SLURM
+### (Optional) Draft Vocabulary Compression
-```sh
-bash distributed_generate/launch.sh $SLURM_JOB_ID vllm TinyLlama/TinyLlama-1.1B-Chat-v1.0 /data/train/ /data/output /scripts/ 0 10 n1,n2,n3,n4 "\"You are a helpful assistant.\""
+We can optionally use a smaller vocab size for the draft model for faster training and inference. For example, Llama3.2-1B has a vocab size of 128256. In this example, we construct a draft vocab mapping of size 32k by finding the most frequent tokens in our training set:
+
+```bash
+python calibrate_draft_vocab.py --model meta-llama/Llama-3.2-1B-Instruct --data Daring-Anteater/train.jsonl ----draft_vocab_size 32000 --save_dir draft_vocab_cache
```
-`/scripts/` is the absolute path to `modelopt/examples/speculative_decoding` which contains `server_generate.py` and `distributed_generate`.
-This will launch a vllm server (sglang is also available) on each node. Each node will work through 10 shards of data (10\*max_lines_per_shard number of samples).
-In this case, the first 40 shards of data will be processed.
-To process the next 40 shards
+This will produce a `d2t.pt` file in `save_dir`, which stores the mapping from draft vocab tokens to full vocab tokens and will be read by our draft model later.
-```sh
-bash distributed_generate/launch.sh $SLURM_JOB_ID vllm TinyLlama/TinyLlama-1.1B-Chat-v1.0 /data/train/ /data/output /scripts/ 40 10 n1,n2,n3,n4
-```
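+
+As a reference for how this mapping can be used, here is a minimal sketch (the path is illustrative, and we assume `d2t.pt` is a tensor of per-token offsets, which is how `_map_logits_to_draft_vocab` in `modelopt/torch/speculative/plugins/transformers.py` consumes it):
+
+```python
+import torch
+
+# d2t[i] is the offset that maps draft-vocab token id i to its id in the full vocabulary.
+d2t = torch.load("draft_vocab_cache/Llama-3.2-1B-Instruct/d2t.pt")
+draft_ids = torch.arange(len(d2t))
+full_vocab_ids = draft_ids + d2t  # full-vocab ids corresponding to each draft-vocab id
+```
+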
+### (Optional) Configuring Draft Model
-To combine the shards back
+For eagle1 and eagle3 we provide an [default model architecture config](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/modelopt/torch/speculative/eagle/default_config.py#L18) in modelopt. User can overwrite default settings by providing additional json dict. In this example, we overwrite the `draft_vocab_size` by in `eagle_config.json`:
-```sh
-python3 distributed_generate/sharding_utils.py --input_dir /data/output/ --output_path /data/output.jsonl --combine
+```json
+{
+ "draft_vocab_size": 32000
+}
```
-### Speculative Decoding Example Training Workflow
+### Training Draft Model with Modelopt
-Here is the recommended end-to-end speculative decoding training workflow:
+`main.py` provides a example for converting a base HF model for speculative decoding and training it. It consists of a few simple steps:
+First, load base model and tokenzier from hugginface:
```python
-import os
-import torch
-import transformers
-import modelopt.torch.opt as mto
-import modelopt.torch.speculative as mtsp
-
-# Create a base model
model = transformers.AutoModelForCausalLM.from_pretrained(
- "",
+ ""
)
+```
-if mode == "medusa":
- config = {
- "medusa_num_heads": 2,
- "medusa_num_layers": 1,
- }
-elif mode == "eagle":
- config = {
- "eagle_num_layers": 1,
- "use_input_layernorm_in_first_layer": True,
- "use_last_layernorm": False
+Then, load the default EAGLE config and apply any necessary overrides:
+
+```python
+# Load default config
+config = {
+ "eagle1": EAGLE1_DEFAULT_CFG,
+ "eagle3": EAGLE3_DEFAULT_CFG,
+}[training_args.mode]["config"]
+
+# overwrite config with custom config
+config["eagle_architecture_config"].update({"": ""})
+
+# Mandatory: hidden size, vocab size and max position embeddings must match base model
+config["eagle_architecture_config"].update(
+ {
+ "hidden_size": model.config.hidden_size,
+ "vocab_size": model.config.vocab_size,
+ "max_position_embeddings": model.config.max_position_embeddings,
}
-mtsp.convert(model, [(mode, config)])
+)
+```
+
+Then, we convert model to a speculative deocoding model:
+
+```python
+mtsp.convert(model, [("eagle", config)])
+```
-tokenizer = transformers.AutoTokenizer.from_pretrained(ckpt_path)
-tokenizer.pad_token_id = tokenizer.eos_token_id
+This will modify the model in-place with the EAGLE training forward, making it compatible with the HF trainer:
+```python
# Create a trainer
trainer = transformers.Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
trainer._move_model_to_device(model, trainer.args.device)
@@ -143,81 +149,87 @@ trainer.save_state()
trainer.save_model("")
```
-## Support Matrix
+We omitted details like tokenizer initialization for simplicity. A complete training example is provided in `main.py`, along with a bash script to launch the training with huggingface accelrate in `launch_train.sh`, which can be runned by:
-### Supported Models/Techniques
+```bash
+./launch_train.sh --model $BASE_MODEL \
+ --output_dir $OUTPUT_DIR \
+ --data $DATA \
+ --num_gpu $NUM_GPU \
+ --num_epochs 10 \
+ --eagle_config eagle_config.json #This is where we overwrite default eagle configs
+```
-#### NeMo/Megatron-LM
+The saved modelopt checkpoint is similar in architecture to HF models. It can be further optimized through **ModelOpt**, e.g., PTQ and QAT.
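+
+For instance, a post-training quantization (PTQ) pass could be applied on top of the saved checkpoint. The following is a minimal sketch only (the checkpoint path, calibration loop, and quantization config are assumptions; see [llm_ptq](../llm_ptq/README.md) for the supported workflow):
+
+```python
+import modelopt.torch.opt as mto
+import modelopt.torch.quantization as mtq
+from transformers import AutoModelForCausalLM
+
+mto.enable_huggingface_checkpointing()
+model = AutoModelForCausalLM.from_pretrained("ckpts/Llama-3.2-1B-Instruct-eagle3", torch_dtype="auto")
+
+def forward_loop(model):
+    # Run a few calibration batches through the model here.
+    ...
+
+# FP8 PTQ as an example; other quantization configs follow the same pattern.
+model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)
+```
+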
-| Model | Medusa | EAGLE1/2 | EAGLE3 |
-| :---: | :---: | :---: | :---: |
-| LLAMA 2 | ✅ | ✅ | ✅ |
-| LLAMA 3, 3.1 | ✅ | ✅ | ✅ |
-| Mistral | ✅ | ✅ | ✅ |
-| Phi 3 | ✅ | ✅ | ✅ |
-| QWen 1.5,2,2.5 | ✅ | ✅ | ✅ |
+### Model Validation
-#### Hugging Face
+After training the draft model, we can evaluate the saved ModelOpt checkpoint on MT-Bench by:
-| Model | Medusa | EAGLE1/2 | EAGLE3 |
-| :---: | :---: | :---: | :---: |
-| LLAMA 2 | ✅ | ✅ | ✅ |
-| LLAMA 3, 3.1 | ✅ | ✅ | ✅ |
-| Mistral | ✅ | ✅ | ✅ |
-| Phi 3 | ✅ | ✅ | ✅ |
-| QWen 1.5,2,2.5 | ✅ | ✅ | ✅ |
+```python
+python ar_validate.py --model_path $OUTPUT_DIR
+```
-### End-to-end Speculative Decoding Examples
+Alternatively, we can export the checkpoint and run evaluation on serving frameworks. See sections below.
-### MLM Example
+### Export
-
+```python
+python export_hf_checkpoint.py --model_path $OUTPUT_DIR --export_path $EXPORT_PATH
+```
-
+### Deployment
-### HuggingFace
+The exported checkpoint can be deployed on TRT-LLM or vLLM.
-This folder contains end-to-end runnable speculative decoding fine-tuning pipeline where TinyLlama from huggingface is trained on Daring-Anteater dataset.
+#### TRT-LLM
-First, download the data:
+To serve the checkpoint with trtllm, we can run trtllm-serve with:
-```sh
-git clone https://huggingface.co/datasets/nvidia/Daring-Anteater
+```bash
+trtllm-serve --host 0.0.0.0 --port 8000 --backend pytorch --max_batch_size 32 --max_num_tokens 8192 --max_seq_len 8192 --extra_llm_api_options extra-llm-api-config.yml
```
-Then, prepare the synthesized data from the base model. Please refer to the **Prepare Data** section.
+where `extra-llm-api-config.yml` is:
+
+```yaml
+enable_attention_dp: false
+disable_overlap_scheduler: true
+enable_autotuner: false
-Next, we fine-tune the speculative decoding models with the base model frozen. Here is the command for Medusa and EAGLE:
+cuda_graph_config:
+ max_batch_size: 1
-```sh
-./launch.sh --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
- --data finetune/data.jsonl \
- --mode medusa \
- --num_epochs 1 --lr 1e-5 --save_steps 1000 \
- --output_dir medusa-tinyllama \
- --fsdp_transformer_layer_cls_to_wrap LlamaDecoderLayer \
- --num_gpu 1 \
- --medusa_num_heads 2 --medusa_num_layers 1
+speculative_config:
+ decoding_type: Eagle
+ max_draft_len: 3
+ speculative_model_dir:
-./launch.sh --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
- --data finetune/data.jsonl \
- --mode eagle \
- --num_epochs 1 --lr 1e-5 --save_steps 1000 \
- --output_dir eagle-tinyllama \
- --fsdp_transformer_layer_cls_to_wrap LlamaDecoderLayer \
- --num_gpu 1 \
- --eagle_num_layers 1
+kv_cache_config:
+ enable_block_reuse: false
```
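+
+Once the server is up, a quick sanity check can be done against the OpenAI-compatible endpoint (a minimal sketch; the model name and prompt here are illustrative):
+
+```python
+import requests
+
+resp = requests.post(
+    "http://localhost:8000/v1/chat/completions",
+    json={
+        "model": "meta-llama/Llama-3.2-1B-Instruct",
+        "messages": [{"role": "user", "content": "Tell me a short fact about GPUs."}],
+        "max_tokens": 64,
+    },
+)
+print(resp.json())
+```
+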
-This will generate fine-tuned checkpoints in `output_dir` specified above.
+Please refer to [TRT-LLM Doc: Speculative Decoding](https://nvidia.github.io/TensorRT-LLM/examples/llm_speculative_decoding.html) for detailed usage.
-Alternatively, you can refer to this [notebook](example.ipynb).
+#### vLLM
-### Deployment
+Please refer to [vLLM Doc: Speculative Decoding](https://docs.vllm.ai/en/v0.9.0/features/spec_decode.html) for detailed usage.
+
+#### Deploying Quantized model
-The final model after end-to-end speculative decoding fine-tuning is similar in architecture to HF models. It can be further optimized through **ModelOpt**, e.g., PTQ and QAT. It can be deployed to TensorRT-LLM (TRTLLM) or to TensorRT just like a regular **ModelOpt** model. See more details on deployment of quantized model to TRTLLM [here](../llm_ptq/README.md).
+See more details on deployment of quantized model to TRTLLM [here](../llm_ptq/README.md).
+
+## Support Matrix
+
+| Model | Medusa | EAGLE1/2 | EAGLE3 |
+| :---: | :---: | :---: | :---: |
+| LLAMA 2 | ✅ | ✅ | ✅ |
+| LLAMA 3, 3.1 | ✅ | ✅ | ✅ |
+| Mistral | ✅ | ✅ | ✅ |
+| Phi 3 | ✅ | ✅ | ✅ |
+| QWen 1.5,2,2.5 | ✅ | ✅ | ✅ |
## Speculation Module Checkpoints
diff --git a/examples/speculative_decoding/ar_validate.py b/examples/speculative_decoding/ar_validate.py
index bfbcb2239..38b886693 100644
--- a/examples/speculative_decoding/ar_validate.py
+++ b/examples/speculative_decoding/ar_validate.py
@@ -26,7 +26,7 @@
mto.enable_huggingface_checkpointing()
-def validate_ar(model, tokenizer, ds, steps=3, osl=20, num_samples=20, device=None):
+def validate_ar(model, tokenizer, ds, steps=3, osl=20, num_samples=80, device=None):
validator = HFARValidation(model, tokenizer)
num_samples = min(num_samples, len(ds))
ars = []
@@ -54,12 +54,12 @@ def validate_ar(model, tokenizer, ds, steps=3, osl=20, num_samples=20, device=No
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, required=True, help="Path to model directory")
- parser.add_argument("--steps", type=int, default=1, help="Steps for AR validation")
+ parser.add_argument("--steps", type=int, default=3, help="Steps for AR validation")
parser.add_argument(
- "--osl", type=int, default=100, help="Output sequence length for AR validation"
+ "--osl", type=int, default=32, help="Output sequence length for AR validation"
)
parser.add_argument(
- "--num_samples", type=int, default=20, help="Number of MT-Bench samples to use"
+ "--num_samples", type=int, default=80, help="Number of MT-Bench samples to use"
)
parser.add_argument(
"--ar_lower_bound",
diff --git a/examples/speculative_decoding/calibrate_draft_vocab.py b/examples/speculative_decoding/calibrate_draft_vocab.py
index 1211d4eb6..37a798cf8 100644
--- a/examples/speculative_decoding/calibrate_draft_vocab.py
+++ b/examples/speculative_decoding/calibrate_draft_vocab.py
@@ -28,11 +28,10 @@ def main():
parser.add_argument("--model", type=str, required=True, help="Model name or path for tokenizer")
parser.add_argument("--data", type=str, required=True, help="Path to training data (jsonl)")
parser.add_argument(
- "--eagle_config",
- type=str,
+ "--draft_vocab_size",
+ type=int,
required=True,
- default="eagle_config.json",
- help="Path to eagle_config.json",
+ help="Draft vocab size",
)
parser.add_argument(
"--calibrate_size",
@@ -45,12 +44,6 @@ def main():
)
args = parser.parse_args()
- with open(args.eagle_config) as f:
- eagle_config = json.load(f)
- if "draft_vocab_size" not in eagle_config:
- print("No draft vocab size specified in eagle_config.json, no need to calibrate for d2t.")
- return
-
print("Calibrating vocab...")
tokenizer = AutoTokenizer.from_pretrained(args.model)
with open(args.data) as f:
@@ -59,7 +52,7 @@ def main():
conversations = conversations[: args.calibrate_size]
conversations = [item for sublist in conversations for item in sublist]
- d2t = calibrate_frequent_vocab(tokenizer, conversations, eagle_config["draft_vocab_size"])
+ d2t = calibrate_frequent_vocab(tokenizer, conversations, args.draft_vocab_size)
model_name = os.path.basename(os.path.normpath(args.model))
vocab_path = os.path.join(args.save_dir, model_name, "d2t.pt")
os.makedirs(os.path.dirname(vocab_path), exist_ok=True)
diff --git a/examples/speculative_decoding/eagle_config.json b/examples/speculative_decoding/eagle_config.json
index 55ff948ed..b4b218fdf 100644
--- a/examples/speculative_decoding/eagle_config.json
+++ b/examples/speculative_decoding/eagle_config.json
@@ -1,3 +1,10 @@
{
- "draft_vocab_size": 32000
+ "rope_scaling": {
+ "factor": 32.0,
+ "low_freq_factor": 1.0,
+ "high_freq_factor": 4.0,
+ "original_max_position_embeddings": 8192,
+ "rope_type": "llama3"
+ },
+ "initializer_range": 0.02
}
diff --git a/examples/speculative_decoding/export_hf_checkpoint.py b/examples/speculative_decoding/export_hf_checkpoint.py
new file mode 100644
index 000000000..2195b6e93
--- /dev/null
+++ b/examples/speculative_decoding/export_hf_checkpoint.py
@@ -0,0 +1,41 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+
+import torch
+from transformers import AutoModelForCausalLM
+
+import modelopt.torch.opt as mto
+from modelopt.torch.export import export_hf_checkpoint
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model_path", type=str, default="")
+ parser.add_argument("--export_path", type=str, default="")
+ return parser.parse_args()
+
+
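+# Enable ModelOpt-aware HF checkpointing so that from_pretrained below also restores
+# the ModelOpt (speculative decoding) state saved with the checkpoint.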
+mto.enable_huggingface_checkpointing()
+
+args = parse_args()
+model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype="auto")
+with torch.inference_mode():
+ export_hf_checkpoint(
+ model, # The ModelOpt-trained model to export.
+ export_dir=args.export_path, # The directory where the exported files will be stored.
+ )
+print(f"Exported checkpoint to {args.export_path}")
diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py
index 28e388b4d..aae4177c6 100644
--- a/examples/speculative_decoding/main.py
+++ b/examples/speculative_decoding/main.py
@@ -47,6 +47,13 @@
import modelopt.torch.speculative as mtsp
from modelopt.torch.utils import print_rank_0
+try:
+ import wandb
+
+ wandb.init()
+except ImportError:
+ wandb = None
+
torch.manual_seed(0)
mto.enable_huggingface_checkpointing()
@@ -170,6 +177,8 @@ def train():
{
"hidden_size": model.config.hidden_size,
"vocab_size": model.config.vocab_size,
+ # we also overwrite max_pos_embedding for deployment compatibility
+ "max_position_embeddings": model.config.max_position_embeddings,
"draft_vocab_size": custom_config["draft_vocab_size"]
if eagle_args.eagle_config and "draft_vocab_size" in custom_config
else model.config.vocab_size,
@@ -213,6 +222,8 @@ def on_step_end(self, args, state, control, **kwargs):
device=kwargs["model"].device,
)
print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}")
+ if wandb:
+ wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
return control
trainer = Trainer(
diff --git a/examples/speculative_decoding/server_generate.py b/examples/speculative_decoding/server_generate.py
index 4541ecf42..0fb71a0a0 100644
--- a/examples/speculative_decoding/server_generate.py
+++ b/examples/speculative_decoding/server_generate.py
@@ -46,7 +46,7 @@
parser.add_argument(
"--max_tokens", type=int, default=2048, help="Maximum number of tokens to generate"
)
-parser.add_argument("--chat", action="store_true", help="Use chat mode")
+parser.add_argument("--chat", default=True, type=bool, help="Use chat mode")
parser.add_argument("--model", type=str, default="model", help="Model name")
parser.add_argument("--url", type=str, default="http://localhost:8000/v1", help="URL of the API")
parser.add_argument("--api_key", type=str, default="token-abc123", help="API key (if any)")
diff --git a/examples/speculative_decoding/train_eagle3_and_export.sh b/examples/speculative_decoding/train_eagle3_and_export.sh
new file mode 100644
index 000000000..5ead19202
--- /dev/null
+++ b/examples/speculative_decoding/train_eagle3_and_export.sh
@@ -0,0 +1,72 @@
+#!/bin/bash
+
+# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -eo pipefail
+
+# Set default values for BASE_MODEL, NUM_GPU, and DATA
+BASE_MODEL=meta-llama/Llama-3.2-1B-Instruct
+NUM_GPU=1
+DATA=Daring-Anteater/train.jsonl
+
+# Parse input arguments --base_model, --num_gpu, and --data
+while [[ $# -gt 0 ]]; do
+ key="$1"
+ case $key in
+ --base_model)
+ BASE_MODEL="$2"
+ shift; shift
+ ;;
+ --num_gpu)
+ NUM_GPU="$2"
+ shift; shift
+ ;;
+ --data)
+ DATA="$2"
+ shift; shift
+ ;;
+ *)
+ echo "Unknown argument: $1"
+ exit 1
+ ;;
+ esac
+done
+
+
+if [[ "$NUM_GPU" == 1 ]]; then
+ export CUDA_VISIBLE_DEVICES=0
+else
+ # Export as 0,1,...,N-1 for NUM_GPU GPUs
+ export CUDA_VISIBLE_DEVICES=$(seq -s, 0 $((NUM_GPU-1)))
+fi
+
+MODEL_BASENAME=$(basename "$BASE_MODEL")
+
+echo "==== [1/3] Training draft model ===="
+OUTPUT_DIR=ckpts/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M)
+./launch_train.sh --model $BASE_MODEL \
+ --output_dir $OUTPUT_DIR \
+ --data $DATA \
+ --num_gpu $NUM_GPU \
+ --num_epochs 2 \
+ --eagle_config eagle_config.json
+
+echo "==== [2/3] Evaluating ModelOpt checkpoint on MT-Bench ===="
+python ar_validate.py --model_path $OUTPUT_DIR
+
+echo "==== [3/3] Exporting checkpoint to deployment format ===="
+EXPORT_PATH=export/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M)
+python export_hf_checkpoint.py --model_path $OUTPUT_DIR --export_path $EXPORT_PATH
diff --git a/modelopt/torch/export/plugins/__init__.py b/modelopt/torch/export/plugins/__init__.py
index a7bbc3fb3..8f90fc528 100644
--- a/modelopt/torch/export/plugins/__init__.py
+++ b/modelopt/torch/export/plugins/__init__.py
@@ -19,3 +19,5 @@
with import_plugin("megatron_importer"):
from .megatron_importer import *
+with import_plugin("transformers"):
+ from .hf_spec_export import *
diff --git a/modelopt/torch/export/plugins/hf_spec_export.py b/modelopt/torch/export/plugins/hf_spec_export.py
new file mode 100644
index 000000000..d04d9b1a0
--- /dev/null
+++ b/modelopt/torch/export/plugins/hf_spec_export.py
@@ -0,0 +1,151 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Modifiy stated_dict and config for exporting speculative decoding in official format."""
+
+import torch
+import torch.nn as nn
+
+from modelopt.torch.speculative.plugins.transformers import HFEagleModel
+
+SPECULATIVE_DECODING_MODES = ["eagle", "medusa"]
+
+EALGE_MODELOPT_TO_OFFICIAL = {
+ "required": {
+ "layers.0.self_attn.q_proj.weight": "midlayer.self_attn.q_proj.weight",
+ "layers.0.self_attn.k_proj.weight": "midlayer.self_attn.k_proj.weight",
+ "layers.0.self_attn.v_proj.weight": "midlayer.self_attn.v_proj.weight",
+ "layers.0.self_attn.o_proj.weight": "midlayer.self_attn.o_proj.weight",
+ "layers.0.mlp.gate_proj.weight": "midlayer.mlp.gate_proj.weight",
+ "layers.0.mlp.up_proj.weight": "midlayer.mlp.up_proj.weight",
+ "layers.0.mlp.down_proj.weight": "midlayer.mlp.down_proj.weight",
+ "hidden_norm.weight": "midlayer.hidden_norm.weight",
+ "input_embeds_norm.weight": "midlayer.input_layernorm.weight",
+ "layers.0.post_attention_layernorm.weight": "midlayer.post_attention_layernorm.weight",
+ "norm.weight": "norm.weight",
+ "fc.weight": "fc.weight",
+ },
+ "optional": {
+ "d2t": "d2t",
+ "eagle_lm_head.weight": "lm_head.weight",
+ },
+}
+
+
+def _check_state_dict_keys_match(draft_model: nn.Module, required_items: dict):
+ """Check if the state dict keys match."""
+ draft_keys = set(draft_model.state_dict().keys())
+ for required_key in required_items:
+ if required_key not in draft_keys:
+ raise ValueError(f"State dict keys mismatch!\nMissing in draft model: {required_key}")
+
+
+def rename_and_prune_if_spec_decoding(model: nn.Module, post_state_dict: dict):
+ """Only return the state dict of the draft model in official format and ignore the base model."""
+ # check the model has only speculative decoding
+ opt_modes = model._modelopt_state
+ if len(opt_modes) != 1 or opt_modes[0][0] != "eagle":
+ # if there's other opts, return as is
+ return post_state_dict
+
+ assert isinstance(model, HFEagleModel)
+ # Check if the state dict keys match
+ _check_state_dict_keys_match(model.eagle_module, EALGE_MODELOPT_TO_OFFICIAL["required"])
+
+ # Convert key names and save the state dict
+ export_state_dict = {}
+ for ours_key, export_key in {
+ **EALGE_MODELOPT_TO_OFFICIAL["required"],
+ **EALGE_MODELOPT_TO_OFFICIAL["optional"],
+ }.items():
+ if ours_key in model.eagle_module.state_dict():
+ export_state_dict[export_key] = model.eagle_module.state_dict()[ours_key]
+
+ # TODO: (hg) this is a temp fix. Find cleaner way to do this.
+ if "eagle_lm_head.weight" not in model.eagle_module.state_dict():
+ export_state_dict["lm_head.weight"] = model.state_dict()["lm_head.weight"]
+
+ return export_state_dict
+
+
+def set_config_if_spec_decoding(model: nn.Module, config_data: dict):
+ """Return the config of draft model in official format."""
+ if len(model._modelopt_state) != 1 or model._modelopt_state[0][0] != "eagle":
+ # return as is
+ return config_data
+
+ assert isinstance(model, HFEagleModel)
+
+ # This is the config keys in official checkpoint.
+ template_config = {
+ "architectures": ["LlamaForCausalLM"],
+ "bos_token_id": None,
+ "eos_token_id": None,
+ "hidden_act": None,
+ "hidden_size": None,
+ "initializer_range": None,
+ "intermediate_size": None,
+ "max_position_embeddings": None,
+ "model_type": "llama",
+ "num_attention_heads": None,
+ "num_key_value_heads": None,
+ "num_hidden_layers": None,
+ "pad_token_id": None,
+ "rms_norm_eps": None,
+ "tie_word_embeddings": False,
+ "torch_dtype": None,
+ "transformers_version": None,
+ "use_cache": None,
+ "vocab_size": None,
+ "draft_vocab_size": None,
+ "rope_scaling": None,
+ "attention_bias": None,
+ "attention_dropout": None,
+ "head_dim": None,
+ "mlp_bias": None,
+ "pretraining_tp": None,
+ "rope_theta": None,
+ "eagle_config": {
+ "eagle_aux_hidden_state_layer_ids": None,
+ "use_aux_hidden_state": None,
+ "use_input_layernorm_in_first_layer": None,
+ "use_last_layernorm": None,
+ "use_mtp_layernorm": None,
+ },
+ }
+
+ def _get_config_from_eagle_config_or_base_config(key: str, model: nn.Module):
+ if getattr(model.eagle_config, key, None) is not None:
+ return getattr(model.eagle_config, key)
+ elif getattr(model.config, key, None) is not None:
+ return getattr(model.config, key)
+ else:
+ return None
+
+ for key in template_config:
+ value = template_config[key]
+ if isinstance(value, dict):
+ # for eagle config, we find it in model.eagle_config
+ for sub_key in value:
+ value[sub_key] = _get_config_from_eagle_config_or_base_config(sub_key, model)
+ elif value is None:
+ # First, we try to load from the eagle config.
+ new_value = _get_config_from_eagle_config_or_base_config(key, model)
+ # If the value is a torch.dtype, we convert to string for serialization.
+ if isinstance(new_value, torch.dtype):
+ new_value = str(new_value).replace("torch.", "")
+ template_config[key] = new_value
+
+ return template_config
diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py
old mode 100755
new mode 100644
index b18ae2619..e50092e56
--- a/modelopt/torch/export/unified_export_hf.py
+++ b/modelopt/torch/export/unified_export_hf.py
@@ -27,6 +27,10 @@
import torch
import torch.nn as nn
+from modelopt.torch.export.plugins import (
+ rename_and_prune_if_spec_decoding,
+ set_config_if_spec_decoding,
+)
from modelopt.torch.quantization import set_quantizer_by_cfg_context
from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer
from modelopt.torch.quantization.qtensor import NVFP4QTensor
@@ -490,6 +494,14 @@ def _export_hf_checkpoint(
return quantized_state_dict, quant_config
+def _quant_applied(hf_quant_config: dict) -> bool:
+ """Check if any quantization is applied."""
+ return not (
+ hf_quant_config["quantization"]["quant_algo"] == QUANTIZATION_NONE
+ and not hf_quant_config["quantization"]["quantized_layers"]
+ )
+
+
def export_hf_checkpoint(
model: nn.Module,
dtype: torch.dtype | None = None,
@@ -509,12 +521,16 @@ def export_hf_checkpoint(
try:
post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype)
- # Save hf_quant_config.json for backward compatibility
- with open(f"{export_dir}/hf_quant_config.json", "w") as file:
- json.dump(hf_quant_config, file, indent=4)
+ # When there's no quantization applied, we avoid saving hf_quant_config.json for compatibility
+ if _quant_applied(hf_quant_config):
+ # Save hf_quant_config.json for backward compatibility
+ with open(f"{export_dir}/hf_quant_config.json", "w") as file:
+ json.dump(hf_quant_config, file, indent=4)
hf_quant_config = convert_hf_quant_config_format(hf_quant_config)
+ post_state_dict = rename_and_prune_if_spec_decoding(model, post_state_dict)
+
# Save model
model.save_pretrained(
export_dir, state_dict=post_state_dict, save_modelopt_state=save_modelopt_state
@@ -528,6 +544,8 @@ def export_hf_checkpoint(
config_data["quantization_config"] = hf_quant_config
+ config_data = set_config_if_spec_decoding(model, config_data)
+
with open(original_config, "w") as file:
json.dump(config_data, file, indent=4)
diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py
index ff425d63f..0ef40c529 100644
--- a/modelopt/torch/speculative/plugins/transformers.py
+++ b/modelopt/torch/speculative/plugins/transformers.py
@@ -182,17 +182,6 @@ def __init__(self, config, decoder_layer_cls, bias=False):
"""Init function for EagleModule."""
super().__init__()
self.config = config
-
- # NOTE:This is a temporary fix to support Qwen and Mixtral in current release.
- # This is refactored in following MR.
- config_overwrite = {
- "mlp_bias": False,
- "attention_bias": False,
- "head_dim": self.config.hidden_size // self.config.num_attention_heads,
- }
- for key, value in config_overwrite.items():
- setattr(self.config, key, value)
-
self.layers = nn.ModuleList(
[decoder_layer_cls(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
@@ -696,8 +685,10 @@ def _base_model_forward(
past_key_values,
freeze_base_model,
labels,
- kwargs,
+ **kwargs,
):
+ # TODO: This function still uses eagle_module. Ideally we should remove it,
+ # so we can del model.eagle_module on the base model ranks to save memory.
with torch.no_grad() if freeze_base_model else contextlib.nullcontext():
outputs = super().forward(
input_ids=input_ids,
@@ -722,15 +713,18 @@ def _base_model_forward(
base_model_loss = loss_fct(loss_logits, labels)
# Map the base model logits to the draft vocab
- if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size and self.training:
- reverse_mapping = (
- torch.arange(len(self.eagle_module.d2t)).to(self.eagle_module.d2t.device)
- + self.eagle_module.d2t
- )
- base_model_logits = base_model_logits[:, :, reverse_mapping]
+ if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
+ base_model_logits = self._map_logits_to_draft_vocab(base_model_logits)
return base_model_hidden_states, base_model_logits, base_model_loss, past_key_values
+ def _map_logits_to_draft_vocab(self, full_logits):
+ reverse_mapping = (
+ torch.arange(len(self.eagle_module.d2t)).to(self.eagle_module.d2t.device)
+ + self.eagle_module.d2t
+ )
+ return full_logits[:, :, reverse_mapping]
+
def _eagle_forward(
self,
eagle_input_hidden_states,
@@ -795,17 +789,31 @@ def forward(
loss_mask = torch.ones_like(input_ids, dtype=torch.bool, device=input_ids.device)
# ====First, we run base model forward====
- base_model_hidden_states, base_model_logits, base_model_loss, past_key_values = (
- self._base_model_forward(
- input_ids,
- attention_mask,
- position_ids,
- past_key_values,
- self.eagle_freeze_base_model,
- labels,
- kwargs,
+ if "base_model_outputs" in kwargs:
+ # Parse base model outputs forwarded from teacher
+ base_outputs = kwargs["base_model_outputs"]
+ base_model_hidden_states = base_outputs["base_model_hidden_states"]
+ if "base_model_logits" in base_outputs:
+ base_model_logits = base_outputs["base_model_logits"]
+ else:
+ base_model_logits = self.lm_head(base_model_hidden_states)
+ if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
+ base_model_logits = self._map_logits_to_draft_vocab(base_model_logits)
+ base_model_loss = None
+ past_key_values = None
+
+ else:
+ base_model_hidden_states, base_model_logits, base_model_loss, past_key_values = (
+ self._base_model_forward(
+ input_ids,
+ attention_mask,
+ position_ids,
+ past_key_values,
+ self.eagle_freeze_base_model,
+ labels,
+ **kwargs,
+ )
)
- )
# ====Run eagle forward====
eagle_loss = None
@@ -813,9 +821,11 @@ def forward(
# In EAGLE-3, we have an additional FC layer to concentrate hidden states from multiple base model layers
batch_size, seq_length, _ = base_model_hidden_states.shape
if self.eagle_config.use_aux_hidden_state:
- eagle_input_hidden_states = self.eagle_module.fc(
- torch.cat(self.pop_aux_hidden_states(), dim=-1)
- )
+ if "base_model_outputs" in kwargs:
+ aux_hidden_states = kwargs["base_model_outputs"]["aux_hidden_states"]
+ else:
+ aux_hidden_states = torch.cat(self.pop_aux_hidden_states(), dim=-1)
+ eagle_input_hidden_states = self.eagle_module.fc(aux_hidden_states)
else:
eagle_input_hidden_states = base_model_hidden_states
@@ -842,10 +852,12 @@ def forward(
if not isinstance(eagle_cache, Cache):
eagle_cache = DynamicCache.from_legacy_cache(eagle_cache)
- past_key_values.eagle_cache = eagle_cache
+
+ # NOTE: disabled for now.
+ # past_key_values.eagle_cache = eagle_cache
# Compute loss on the eagle modules
- regression_loss, classification_loss = self._eagle_loss(
+ regression_loss, classification_loss, accuracy_0 = self._eagle_loss(
base_model_hidden_states[:, 1:],
base_model_logits[:, 1:],
eagle_postnorm_h[:, :-1],
@@ -879,7 +891,7 @@ def forward(
position_embeddings,
)
- regression_loss, classification_loss = self._eagle_loss(
+ regression_loss, classification_loss, accuracy_1 = self._eagle_loss(
# base model predict +1 tok, while eagle predict +2
# so we shift base model outputs compared to eagle outputs
base_model_hidden_states[:, 1:],
@@ -927,7 +939,7 @@ def forward(
position_embeddings,
)
- regression_loss, classification_loss = self._eagle_loss(
+ regression_loss, classification_loss, accuracy_2 = self._eagle_loss(
base_model_hidden_states[:, 1:],
base_model_logits[:, 1:],
eagle_postnorm_h[:, -seq_length:-1, :],
@@ -969,7 +981,7 @@ def forward(
position_embeddings,
)
- regression_loss, classification_loss = self._eagle_loss(
+ regression_loss, classification_loss, accuracy_3 = self._eagle_loss(
base_model_hidden_states[:, 1:],
base_model_logits[:, 1:],
eagle_postnorm_h[
@@ -1006,11 +1018,14 @@ def forward(
"Both base_model_loss and eagle_loss are skipped. At least one loss must be computed."
)
+ train_acc = (accuracy_0, accuracy_1, accuracy_2, accuracy_3) if self.training else None
+
return ModelOutput(
loss=loss,
logits=base_model_logits,
past_key_values=past_key_values,
hidden_states=base_model_hidden_states,
+ train_acc=train_acc,
)
def _eagle_loss(
@@ -1034,7 +1049,13 @@ def _eagle_loss(
regression_loss = torch.sum(torch.mean(loss_mask * regression_loss, 2)) / (
loss_mask.sum() + 1e-5
)
- return regression_loss, classification_loss
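+ # Top-1 agreement between draft and base model predictions (masked by loss_mask), used as a training accuracy signal.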
+ base_predict_tok = base_model_logits.argmax(dim=-1)
+ eagle_predict_tok = eagle_logits.argmax(dim=-1)
+ accuracy = (
+ (loss_mask[:, :, 0] * (base_predict_tok == eagle_predict_tok)).float().mean().item()
+ )
+ accuracy = round(accuracy, 3)
+ return regression_loss, classification_loss, accuracy
@torch.no_grad()
def pseudo_speculative_generate(
@@ -1055,7 +1076,7 @@ def pseudo_speculative_generate(
base_model_hidden_states = base_model_outputs.hidden_states[-1]
base_model_logits = base_model_outputs.logits
- base_token = base_model_logits[:, -1:, :].argmax(dim=-1)
+ base_token = base_model_logits[:, -1:, :].argmax(dim=-1).to(input_ids.device)
# Early return
if steps < 1:
@@ -1132,7 +1153,7 @@ def get_ground_truth(self, input_ids, osl):
input_ids = copy.deepcopy(input_ids).to(torch.cuda.current_device())
for _ in range(osl):
input_id, _ = self.model.pseudo_speculative_generate(input_ids, steps=0)
- input_ids = torch.cat((input_ids, input_id), dim=-1)
+ input_ids = torch.cat((input_ids, input_id.to(input_ids.device)), dim=-1)
if input_id[0, 0] == self.end_token:
break
return input_ids
From 4ffe0047d8f226f36c32756864ad9ec17c147974 Mon Sep 17 00:00:00 2001
From: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Date: Fri, 5 Sep 2025 19:08:39 +0000
Subject: [PATCH 2/7] add file; only d2t when training
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
---
examples/speculative_decoding/launch_train.sh | 157 ++++++++++++++++++
.../torch/speculative/plugins/transformers.py | 2 +-
2 files changed, 158 insertions(+), 1 deletion(-)
create mode 100755 examples/speculative_decoding/launch_train.sh
diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh
new file mode 100755
index 000000000..cf54f9446
--- /dev/null
+++ b/examples/speculative_decoding/launch_train.sh
@@ -0,0 +1,157 @@
+#!/bin/bash
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -eo pipefail
+
+while [ $# -gt 0 ]; do
+ case "$1" in
+ --training_seq_len*)
+ if [[ "$1" != *=* ]]; then shift; fi
+ TRAINING_SEQ_LEN="${1#*=}"
+ ;;
+ --model*)
+ if [[ "$1" != *=* ]]; then shift; fi
+ MODEL="${1#*=}"
+ ;;
+ --data*)
+ if [[ "$1" != *=* ]]; then shift; fi
+ DATA="${1#*=}"
+ ;;
+ --mode*)
+ if [[ "$1" != *=* ]]; then shift; fi
+ MODE="${1#*=}"
+ ;;
+ --output_dir*)
+ if [[ "$1" != *=* ]]; then shift; fi
+ OUTPUT_DIR="${1#*=}"
+ ;;
+ --num_epochs*)
+ if [[ "$1" != *=* ]]; then shift; fi
+ NUM_EPOCHS="${1#*=}"
+ ;;
+ --save_steps*)
+ if [[ "$1" != *=* ]]; then shift; fi
+ SAVE_STEPS="${1#*=}"
+ ;;
+ --lr*)
+ if [[ "$1" != *=* ]]; then shift; fi
+ LR="${1#*=}"
+ ;;
+ --train_bs*)
+ if [[ "$1" != *=* ]]; then shift; fi
+ TRAIN_BS="${1#*=}"
+ ;;
+ --medusa_num_heads*)
+ if [[ "$1" != *=* ]]; then shift; fi
+ MEDUSA_NUM_HEADS="${1#*=}"
+ ;;
+ --medusa_num_layers*)
+ if [[ "$1" != *=* ]]; then shift; fi
+ MEDUSA_NUM_LAYERS="${1#*=}"
+ ;;
+ --eagle_config*)
+ if [[ "$1" != *=* ]]; then shift; fi
+ EAGLE_CONFIG="${1#*=}"
+ ;;
+ --fsdp_transformer_layer_cls_to_wrap*)
+ if [[ "$1" != *=* ]]; then shift; fi
+ FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP="${1#*=}"
+ ;;
+ --num_gpu*)
+ if [[ "$1" != *=* ]]; then shift; fi
+ NUM_GPU="${1#*=}"
+ ;;
+ *)
+ >&2 printf "Error: Invalid argument ${1#*=}\n"
+ exit 1
+ ;;
+ esac
+ shift
+done
+
+set -x
+
+# Get the default value for save_steps based on the available number of GPUs
+GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
+# Calculate save_steps
+DEFAULT_SAVE_STEPS=$((8192 / GPU_COUNT))
+
+MODEL=${MODEL:-"TinyLlama/TinyLlama-1.1B-Chat-v1.0"}
+MODE=${MODE:-"eagle3"}
+# Set default OUTPUT_DIR to ckpts/{modelname}, where {modelname} is the last part of the model path
+MODEL_BASENAME=$(basename "$MODEL")
+OUTPUT_DIR=${OUTPUT_DIR:-"ckpts/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M)"}
+NUM_EPOCHS=${NUM_EPOCHS:-1}
+SAVE_STEPS=${SAVE_STEPS:-$DEFAULT_SAVE_STEPS}
+LR=${LR:-"1e-4"}
+TRAIN_BS=${TRAIN_BS:-4}
+MEDUSA_NUM_HEADS=${MEDUSA_NUM_HEADS:-1}
+MEDUSA_NUM_LAYERS=${MEDUSA_NUM_LAYERS:-1}
+REDRAFTER_TOKENS=${REDRAFTER_TOKENS:-1}
+REDRAFTER_NUM_LAYERS=${REDRAFTER_NUM_LAYERS:-1}
+FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP:-"LlamaDecoderLayer"}
+NUM_GPU=${NUM_GPU:-1}
+TRAINING_SEQ_LEN=${TRAINING_SEQ_LEN:-512}
+
+if [[ "$MODE" == "medusa" ]]; then
+ SPECULATIVE_ARGS="--medusa_num_heads $MEDUSA_NUM_HEADS --medusa_num_layers $MEDUSA_NUM_LAYERS"
+elif [[ "$MODE" == "eagle1" || "$MODE" == "eagle3" ]]; then
+ if [[ -n "$EAGLE_CONFIG" ]]; then
+ SPECULATIVE_ARGS="--eagle_config $EAGLE_CONFIG"
+ else
+ SPECULATIVE_ARGS=""
+ fi
+else
+ echo "Only medusa, eagle1, eagle3 supported for now!"
+ exit 1
+fi
+
+if [[ "$NUM_GPU" == 1 ]]; then
+ MULTI_GPU=""
+else
+ MULTI_GPU="--multi_gpu"
+fi
+
+# Disable tokenizers parallelism to avoid warning
+export TOKENIZERS_PARALLELISM=False
+CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \
+ --mode $MODE \
+ --model_name_or_path $MODEL \
+ --training_seq_len $TRAINING_SEQ_LEN \
+ --dataloader_drop_last True \
+ --bf16 True \
+ --output_dir $OUTPUT_DIR \
+ --num_train_epochs $NUM_EPOCHS \
+ --per_device_train_batch_size $TRAIN_BS \
+ --per_device_eval_batch_size $TRAIN_BS \
+ --gradient_accumulation_steps 1 \
+ --do_eval False \
+ --eval_accumulation_steps 1 \
+ --save_strategy steps \
+ --save_steps $SAVE_STEPS \
+ --learning_rate $LR \
+ --weight_decay 0.0 \
+ --warmup_steps 100 \
+ --lr_scheduler_type linear \
+ --logging_steps 100 \
+ --tf32 True \
+ --data_path $DATA \
+ $SPECULATIVE_ARGS
+"
+
+start_time=$(date +%s)
+sh -c "$CMD"
+echo "Total time taken: $(( $(date +%s) - $start_time )) seconds"
diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py
index 0ef40c529..1e4c3a73d 100644
--- a/modelopt/torch/speculative/plugins/transformers.py
+++ b/modelopt/torch/speculative/plugins/transformers.py
@@ -713,7 +713,7 @@ def _base_model_forward(
base_model_loss = loss_fct(loss_logits, labels)
# Map the base model logits to the draft vocab
- if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
+ if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size and self.training:
base_model_logits = self._map_logits_to_draft_vocab(base_model_logits)
return base_model_hidden_states, base_model_logits, base_model_loss, past_key_values
From bfc8879c87d72dec8319c795279e6147ddc5e697 Mon Sep 17 00:00:00 2001
From: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Date: Sat, 6 Sep 2025 01:05:32 +0000
Subject: [PATCH 3/7] address review comments
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
---
examples/speculative_decoding/README.md | 51 ++++++++++---------
.../SLURM_prepare_data.md | 37 ++++++++++++++
.../export_hf_checkpoint.py | 13 +++--
.../train_eagle3_and_export.sh | 7 ++-
.../torch/export/plugins/hf_spec_export.py | 32 +++++++-----
modelopt/torch/export/unified_export_hf.py | 15 +++---
.../torch/speculative/plugins/transformers.py | 23 +++++----
7 files changed, 116 insertions(+), 62 deletions(-)
create mode 100644 examples/speculative_decoding/SLURM_prepare_data.md
diff --git a/examples/speculative_decoding/README.md b/examples/speculative_decoding/README.md
index 910077a58..9a3d11f3a 100644
--- a/examples/speculative_decoding/README.md
+++ b/examples/speculative_decoding/README.md
@@ -2,11 +2,11 @@
[](https://nvidia.github.io/TensorRT-Model-Optimizer/guides/5_speculative_decoding.html)
-Speculative decoding accelerates auto-regressive generation in large language models (LLMs) by leveraging a lightweight draft model to predict the next γ tokens. The main LLM then verifies these candidate tokens in a single forward pass. If the draft model correctly predicts α tokens, the LLM can accept and generate α+1 tokens per verification step, significantly improving throughput.
+Speculative decoding accelerates auto-regressive generation in large language models (LLMs) by leveraging a lightweight draft model to predict the next γ tokens. The main LLM then verifies these candidate tokens in a single forward pass. If the draft model correctly predicts α tokens, the LLM can accept and generate α+1 tokens per verification step, significantly improving generation speed.
-This folder contains end-to-end runnable speculative decoding fine-tuning pipeline where Llama3.2-1B from huggingface is trained on Daring-Anteater dataset.
+This folder contains an end-to-end runnable speculative decoding fine‑tuning pipeline in which Llama‑3.2‑1B (Hugging Face) is trained on the Daring‑Anteater dataset.
-This example focus on training with HF. To train with Megatron-LM, please refer to [this link](https://github.com/NVIDIA/Megatron-LM/tree/main/examples/post_training/modelopt) in Megatron-LM repo.
+This example focuses on training with Hugging Face. To train with Megatron‑LM, see the [Megatron‑LM example](https://github.com/NVIDIA/Megatron-LM/tree/main/examples/post_training/modelopt).
## Contents
@@ -15,9 +15,9 @@ This example focus on training with HF. To train with Megatron-LM, please refer
| **Section** | **Description** | **Jump To** |
| :------------: | :------------: | :------------: |
| Pre-Requisites | Required & optional dependencies | \[[Link](#pre-requisites)\] |
-| Simplified Workflow | Train, evaluate and export eagle model with one-line command | \[[Link](#getting-started-simplified-workflow)\] |
-| Complete Workflow | Full example with configurable traininig pipeline | \[[Link](#support-matrix)\] |
-| Support Matrix | Supported models for speculative decoding training | \[[Link](#deployment)\] |
+| Simplified Workflow | Train, evaluate, and export eagle model with one-line command | \[[Link](#getting-started-simplified-workflow)\] |
+| Complete Workflow | Full example with configurable training pipeline | \[[Link](#complete-workflow)\] |
+| Support Matrix | Supported models for speculative decoding training | \[[Link](#support-matrix)\] |
| Speculation Module Checkpoints | View pre-trained speculation modules ready to deploy! | \[[Link](#speculation-module-checkpoints)\] |
| Resources | Extra links to relevant resources | \[[Link](#resources)\] |
@@ -50,7 +50,6 @@ This one-line command runs a minimal example workflow of training and exporting
- Fine-tunes the model on the [Daring-Anteater](https://huggingface.co/datasets/nvidia/Daring-Anteater) dataset
- Evaluates the acceptance rate on [MT-Bench](https://huggingface.co/datasets/HuggingFaceH4/mt_bench_prompts)
- Exports a checkpoint ready for deployment
-- Run an interactive inference demo with vLLM using provided prompt
## Complete Workflow
@@ -75,19 +74,23 @@ Then, we generate conversations with base model and prompts from Daring-Anteater
python server_generate.py --data_path Daring-Anteater/train.jsonl --output_path synthetic/train.jsonl
```
+To add a system prompt, use the `--system_prompt ` argument.
+
+For large scale data generation, please see [SLURM prepare data](SLURM_prepare_data.md) for SLURM support.
+
### (Optional) Draft Vocabulary Compression
We can optionally use a smaller vocab size for the draft model for faster training and inference. For example, Llama3.2-1B has a vocab size of 128256. In this example, we construct a draft vocab mapping of size 32k by finding the most frequent tokens in our training set:
```bash
-python calibrate_draft_vocab.py --model meta-llama/Llama-3.2-1B-Instruct --data Daring-Anteater/train.jsonl ----draft_vocab_size 32000 --save_dir draft_vocab_cache
+python calibrate_draft_vocab.py --model meta-llama/Llama-3.2-1B-Instruct --data Daring-Anteater/train.jsonl --draft_vocab_size 32000 --save_dir draft_vocab_cache
```
This will produce a `d2t.pt` file in `save_dir`, which stores the mapping from draft vocab tokens to full vocab tokens and will be read by our draft model later.
### (Optional) Configuring Draft Model
-For eagle1 and eagle3 we provide an [default model architecture config](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/modelopt/torch/speculative/eagle/default_config.py#L18) in modelopt. User can overwrite default settings by providing additional json dict. In this example, we overwrite the `draft_vocab_size` by in `eagle_config.json`:
+For EAGLE‑1 and EAGLE‑3 we provide a [default model architecture config](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/modelopt/torch/speculative/config.py#L37) in ModelOpt. You can override default settings by providing an additional JSON dict. In this example, we override `draft_vocab_size` in `eagle_config.json`:
```json
{
@@ -97,8 +100,8 @@ For eagle1 and eagle3 we provide an [default model architecture config](https://
### Training Draft Model with Modelopt
-`main.py` provides a example for converting a base HF model for speculative decoding and training it. It consists of a few simple steps:
-First, load base model and tokenzier from hugginface:
+`main.py` provides an example for converting a HF base model for speculative decoding and training it. It consists of a few simple steps:
+First, load the base model and tokenizer from Hugging Face:
```python
model = transformers.AutoModelForCausalLM.from_pretrained(
@@ -116,7 +119,7 @@ config = {
}[training_args.mode]["config"]
# overwrite config with custom config
-config["eagle_architecture_config"].update({"": ""})
+config["eagle_architecture_config"].update({"": ""})
# Mandatory: hidden size, vocab size and max position embeddings must match base model
config["eagle_architecture_config"].update(
@@ -128,7 +131,7 @@ config["eagle_architecture_config"].update(
)
```
-Then, we convert model to a speculative deocoding model:
+Then, we convert model to a speculative decoding model:
```python
mtsp.convert(model, [("eagle", config)])
@@ -149,7 +152,7 @@ trainer.save_state()
trainer.save_model("")
```
-We omitted details like tokenizer initialization for simplicity. A complete training example is provided in `main.py`, along with a bash script to launch the training with huggingface accelrate in `launch_train.sh`, which can be runned by:
+We omitted details like tokenizer initialization for simplicity. A complete training example is provided in `main.py`, along with a bash script to launch training with Hugging Face Accelerate in `launch_train.sh`, which can be run by:
```bash
./launch_train.sh --model $BASE_MODEL \
@@ -157,7 +160,7 @@ We omitted details like tokenizer initialization for simplicity. A complete trai
--data $DATA \
--num_gpu $NUM_GPU \
--num_epochs 10 \
- --eagle_config eagle_config.json #This is where we overwrite default eagle configs
+    --eagle_config eagle_config.json # This is where we optionally overwrite default EAGLE configs
```
The saved ModelOpt checkpoint is similar in architecture to HF models. It can be further optimized through **ModelOpt**, e.g., PTQ and QAT.
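
As an illustration, a minimal PTQ sketch using ModelOpt's quantization API, assuming the in-memory `model` and `tokenizer` from the walkthrough above; `FP8_DEFAULT_CFG` and the toy calibration texts are placeholder choices (see the llm_ptq example for a complete workflow):

```python
# Sketch only: post-training quantization of the trained model with ModelOpt.
# Replace the toy calibration texts with a representative calibration set in practice.
import modelopt.torch.quantization as mtq

def forward_loop(m):
    # Run a few forward passes so quantizers can collect calibration statistics.
    for text in ["Hello, how are you?", "Speculative decoding speeds up LLM inference."]:
        m(**tokenizer(text, return_tensors="pt").to(m.device))

model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)
```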
@@ -166,7 +169,7 @@ The saved modelopt checkpoint is similar in architecture to HF models. It can be
After training the draft model, we can evaluate the saved ModelOpt checkpoint on MT-Bench by:
-```python
+```bash
python ar_validate.py --model_path $OUTPUT_DIR
```
@@ -174,19 +177,19 @@ Alternatively, we can export the checkpoint and run evaluation on serving framew
### Export
-```python
+```bash
python export_hf_checkpoint.py --model_path $OUTPUT_DIR --export_path $EXPORT_PATH
```
-This will export the model from a modelopt checkpoint to a deployment-compatible formart.
+This exports the model from a ModelOpt checkpoint to a deployment‑compatible format.
### Deployment
-The exported checkpoint can be deployed on TRT-LLM or vLLM.
+The exported checkpoint can be deployed on TRT-LLM or SGLang.
#### TRT-LLM
-To serve the checkpoint with trtllm, we can run trtllm-serve with:
+To serve the checkpoint with TRT-LLM, run `trtllm-serve` with:
```bash
trtllm-serve --host 0.0.0.0 --port 8000 --backend pytorch --max_batch_size 32 --max_num_tokens 8192 --max_seq_len 8192 --extra_llm_api_options extra-llm-api-config.yml
@@ -213,9 +216,9 @@ kv_cache_config:
Please refer to [TRT-LLM Doc: Speculative Decoding](https://nvidia.github.io/TensorRT-LLM/examples/llm_speculative_decoding.html) for detailed usage.
-#### vLLM
+#### SGLang
-Please refer to [vLLM Doc: Speculative Decoding](https://docs.vllm.ai/en/v0.9.0/features/spec_decode.html) for detailed usage.
+Please refer to [SGLang Doc: Speculative Decoding](https://docs.sglang.ai/advanced_features/speculative_decoding.html#EAGLE-3-Decoding) for detailed usage.
#### Deploying Quantized model
@@ -233,8 +236,8 @@ See more details on deployment of quantized model to TRTLLM [here](../llm_ptq/RE
## Speculation Module Checkpoints
-Ready-to-deploy speculation module checkpoints \[[🤗 Hugging Face - Nvidia TensorRT Model Optimizer Collection](https://huggingface.co/collections/nvidia/model-optimizer-66aa84f7966b3150262481a4)\]
-Deployable on [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [vLLM](https://github.com/vllm-project/vllm) and [SGLang](https://github.com/sgl-project/sglang)!\
+Ready-to-deploy speculation module checkpoints \[[🤗 Hugging Face - NVIDIA TensorRT Model Optimizer Collection](https://huggingface.co/collections/nvidia/model-optimizer-66aa84f7966b3150262481a4)\]
+Deployable on [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) and [SGLang](https://github.com/sgl-project/sglang)!\
More models coming soon!
## Resources
diff --git a/examples/speculative_decoding/SLURM_prepare_data.md b/examples/speculative_decoding/SLURM_prepare_data.md
new file mode 100644
index 000000000..bc3c0abb6
--- /dev/null
+++ b/examples/speculative_decoding/SLURM_prepare_data.md
@@ -0,0 +1,37 @@
+# SLURM Prepare Data
+
+For basic parallelization of synthetic data generation, we provide some SLURM support.
+Assuming a `$SLURM_JOB_ID` is present and nodes n1, n2, n3, n4 have been allocated, the following workflow can be used.
+
+Example of allocating 4 nodes for 120 minutes
+
+```sh
+salloc -N4 -A <account> -p <partition> -J synthetic:data-gen -t 120
+```
+
+Create shards of a given size
+
+```sh
+python3 distributed_generate/sharding_utils.py --input_path /data/train.jsonl --output_dir /data/train/ --max_lines_per_shard 10000
+```
+
+Run workers on SLURM
+
+```sh
+bash distributed_generate/launch.sh $SLURM_JOB_ID vllm TinyLlama/TinyLlama-1.1B-Chat-v1.0 /data/train/ /data/output /scripts/ 0 10 n1,n2,n3,n4 "\"You are a helpful assistant.\""
+```
+
+`/scripts/` is the absolute path to `modelopt/examples/speculative_decoding`, which contains `server_generate.py` and `distributed_generate`.
+This will launch a vLLM server (SGLang is also available) on each node. Each node will work through 10 shards of data (10 \* `max_lines_per_shard` samples).
+In this case, the first 40 shards of data will be processed.
+To process the next 40 shards
+
+```sh
+bash distributed_generate/launch.sh $SLURM_JOB_ID vllm TinyLlama/TinyLlama-1.1B-Chat-v1.0 /data/train/ /data/output /scripts/ 40 10 n1,n2,n3,n4
+```
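+
+For reference, the positional arguments as inferred from the two invocations above (check `distributed_generate/launch.sh` for the authoritative order and defaults):
+
+```sh
+# bash distributed_generate/launch.sh <args...>
+#   1: SLURM job id                     ($SLURM_JOB_ID)
+#   2: serving backend                  (vllm or sglang)
+#   3: model name or path               (e.g. TinyLlama/TinyLlama-1.1B-Chat-v1.0)
+#   4: directory of input shards        (/data/train/)
+#   5: output directory                 (/data/output)
+#   6: path to this example             (/scripts/)
+#   7: index of the first shard to run  (0, then 40 for the next batch)
+#   8: number of shards per node        (10)
+#   9: comma-separated node list        (n1,n2,n3,n4)
+#  10: optional system prompt           (quoted)
+```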
+
+To combine the shards back
+
+```sh
+python3 distributed_generate/sharding_utils.py --input_dir /data/output/ --output_path /data/output.jsonl --combine
+```
diff --git a/examples/speculative_decoding/export_hf_checkpoint.py b/examples/speculative_decoding/export_hf_checkpoint.py
index 2195b6e93..dfc293ee9 100644
--- a/examples/speculative_decoding/export_hf_checkpoint.py
+++ b/examples/speculative_decoding/export_hf_checkpoint.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+"""Export a HF checkpoint (with ModelOpt state) for deployment."""
+
import argparse
import torch
@@ -23,9 +25,13 @@
def parse_args():
- parser = argparse.ArgumentParser()
- parser.add_argument("--model_path", type=str, default="")
- parser.add_argument("--export_path", type=str, default="")
+ parser = argparse.ArgumentParser(
+ description="Export a HF checkpoint (with ModelOpt state) for deployment."
+ )
+    parser.add_argument(
+        "--model_path", type=str, required=True, help="Path of the trained checkpoint."
+    )
+    parser.add_argument(
+        "--export_path", type=str, required=True, help="Destination directory for exported files."
+    )
return parser.parse_args()
@@ -33,6 +39,7 @@ def parse_args():
args = parse_args()
model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype="auto")
+model.eval()
with torch.inference_mode():
export_hf_checkpoint(
model, # The quantized model.
diff --git a/examples/speculative_decoding/train_eagle3_and_export.sh b/examples/speculative_decoding/train_eagle3_and_export.sh
index 5ead19202..042a27652 100644
--- a/examples/speculative_decoding/train_eagle3_and_export.sh
+++ b/examples/speculative_decoding/train_eagle3_and_export.sh
@@ -22,7 +22,7 @@ BASE_MODEL=meta-llama/Llama-3.2-1B-Instruct
NUM_GPU=1
DATA=Daring-Anteater/train.jsonl
-# Parse input arguments --base-model, --num_gpu, and --data
+# Parse input arguments --base_model, --num_gpu, and --data
while [[ $# -gt 0 ]]; do
key="$1"
case $key in
@@ -50,13 +50,15 @@ if [[ "$NUM_GPU" == 1 ]]; then
export CUDA_VISIBLE_DEVICES=0
else
# Export as 0,1,...,N-1 for NUM_GPU GPUs
- export CUDA_VISIBLE_DEVICES=$(seq -s, 0 $((NUM_GPU-1)))
+ devs="$(seq -s, 0 $((NUM_GPU-1)))"
+ export CUDA_VISIBLE_DEVICES="$devs"
fi
MODEL_BASENAME=$(basename "$BASE_MODEL")
echo "==== [1/3] Training draft model ===="
OUTPUT_DIR=ckpts/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M)
+mkdir -p "$(dirname "$OUTPUT_DIR")"
./launch_train.sh --model $BASE_MODEL \
--output_dir $OUTPUT_DIR \
--data $DATA \
@@ -69,4 +71,5 @@ python ar_validate.py --model_path $OUTPUT_DIR
echo "==== [3/3] Exporting checkpoint to deployment format ===="
EXPORT_PATH=export/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M)
+mkdir -p "$(dirname "$EXPORT_PATH")"
python export_hf_checkpoint.py --model_path $OUTPUT_DIR --export_path $EXPORT_PATH
diff --git a/modelopt/torch/export/plugins/hf_spec_export.py b/modelopt/torch/export/plugins/hf_spec_export.py
index d04d9b1a0..dae1f6637 100644
--- a/modelopt/torch/export/plugins/hf_spec_export.py
+++ b/modelopt/torch/export/plugins/hf_spec_export.py
@@ -13,16 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Modifiy stated_dict and config for exporting speculative decoding in official format."""
+"""Modify state_dict and config for exporting speculative decoding in official format."""
import torch
import torch.nn as nn
+import transformers
from modelopt.torch.speculative.plugins.transformers import HFEagleModel
-SPECULATIVE_DECODING_MODES = ["eagle", "medusa"]
-
-EALGE_MODELOPT_TO_OFFICIAL = {
+EAGLE_MODELOPT_TO_OFFICIAL = {
"required": {
"layers.0.self_attn.q_proj.weight": "midlayer.self_attn.q_proj.weight",
"layers.0.self_attn.k_proj.weight": "midlayer.self_attn.k_proj.weight",
@@ -55,26 +54,31 @@ def _check_state_dict_keys_match(draft_model: nn.Module, required_items: dict):
def rename_and_prune_if_spec_decoding(model: nn.Module, post_state_dict: dict):
"""Only return the state dict of the draft model in official format and ignore the base model."""
# check the model has only speculative decoding
- opt_modes = model._modelopt_state
- if len(opt_modes) != 1 or opt_modes[0][0] != "eagle":
+ opt_modes = getattr(model, "_modelopt_state", None)
+ if (
+ not isinstance(opt_modes, (list, tuple))
+ or len(opt_modes) != 1
+ or opt_modes[0][0] != "eagle"
+ ):
# if there's other opts, return as is
return post_state_dict
assert isinstance(model, HFEagleModel)
# Check if the state dict keys match
- _check_state_dict_keys_match(model.eagle_module, EALGE_MODELOPT_TO_OFFICIAL["required"])
+ _check_state_dict_keys_match(model.eagle_module, EAGLE_MODELOPT_TO_OFFICIAL["required"])
# Convert key names and save the state dict
+ eagle_state = model.eagle_module.state_dict()
export_state_dict = {}
for ours_key, export_key in {
- **EALGE_MODELOPT_TO_OFFICIAL["required"],
- **EALGE_MODELOPT_TO_OFFICIAL["optional"],
+ **EAGLE_MODELOPT_TO_OFFICIAL["required"],
+ **EAGLE_MODELOPT_TO_OFFICIAL["optional"],
}.items():
- if ours_key in model.eagle_module.state_dict():
- export_state_dict[export_key] = model.eagle_module.state_dict()[ours_key]
+ if ours_key in eagle_state:
+ export_state_dict[export_key] = eagle_state[ours_key]
# TODO: (hg) this is a temp fix. Find cleaner way to do this.
- if "eagle_lm_head.weight" not in model.eagle_module.state_dict():
+ if "eagle_lm_head.weight" not in eagle_state:
export_state_dict["lm_head.weight"] = model.state_dict()["lm_head.weight"]
return export_state_dict
@@ -90,7 +94,7 @@ def set_config_if_spec_decoding(model: nn.Module, config_data: dict):
# This is the config keys in official checkpoint.
template_config = {
- "architectures": ["LlamaForCausalLM"],
+ "architectures": ["LlamaForCausalLMEagle3"],
"bos_token_id": None,
"eos_token_id": None,
"hidden_act": None,
@@ -106,7 +110,7 @@ def set_config_if_spec_decoding(model: nn.Module, config_data: dict):
"rms_norm_eps": None,
"tie_word_embeddings": False,
"torch_dtype": None,
- "transformers_version": None,
+ "transformers_version": transformers.__version__,
"use_cache": None,
"vocab_size": None,
"draft_vocab_size": None,
diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py
index e50092e56..b1b23a12c 100644
--- a/modelopt/torch/export/unified_export_hf.py
+++ b/modelopt/torch/export/unified_export_hf.py
@@ -496,10 +496,8 @@ def _export_hf_checkpoint(
def _quant_applied(hf_quant_config: dict) -> bool:
"""Check if any quantization is applied."""
- return not (
- hf_quant_config["quantization"]["quant_algo"] == QUANTIZATION_NONE
- and not hf_quant_config["quantization"]["quantized_layers"]
- )
+ q = hf_quant_config.get("quantization", {})
+ return not (q.get("quant_algo") == QUANTIZATION_NONE and not q.get("quantized_layers"))
def export_hf_checkpoint(
@@ -521,11 +519,10 @@ def export_hf_checkpoint(
try:
post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype)
- # When there's no quantization applied, we avoid saving hf_quant_config.json for compatibility
- if _quant_applied(hf_quant_config):
- # Save hf_quant_config.json for backward compatibility
- with open(f"{export_dir}/hf_quant_config.json", "w") as file:
- json.dump(hf_quant_config, file, indent=4)
+ # NOTE: (hg) Should we save hf_quant_config when there's no quantization applied?
+ # Save hf_quant_config.json for backward compatibility
+ with open(f"{export_dir}/hf_quant_config.json", "w") as file:
+ json.dump(hf_quant_config, file, indent=4)
hf_quant_config = convert_hf_quant_config_format(hf_quant_config)
diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py
index 1e4c3a73d..e1da326d1 100644
--- a/modelopt/torch/speculative/plugins/transformers.py
+++ b/modelopt/torch/speculative/plugins/transformers.py
@@ -699,8 +699,6 @@ def _base_model_forward(
**kwargs,
)
past_key_values = outputs.past_key_values
- if not isinstance(past_key_values, Cache):
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
base_model_hidden_states = outputs.hidden_states[-1]
base_model_logits = outputs.logits
@@ -714,6 +712,7 @@ def _base_model_forward(
# Map the base model logits to the draft vocab
if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size and self.training:
+ assert hasattr(self.eagle_module, "d2t"), "d2t buffer not initialized"
base_model_logits = self._map_logits_to_draft_vocab(base_model_logits)
return base_model_hidden_states, base_model_logits, base_model_loss, past_key_values
@@ -815,6 +814,9 @@ def forward(
)
)
+ if not isinstance(past_key_values, Cache):
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+
# ====Run eagle forward====
eagle_loss = None
if self.training:
@@ -853,8 +855,7 @@ def forward(
if not isinstance(eagle_cache, Cache):
eagle_cache = DynamicCache.from_legacy_cache(eagle_cache)
- # NOTE: diabled for now.
- # past_key_values.eagle_cache = eagle_cache
+ past_key_values.eagle_cache = eagle_cache
# Compute loss on the eagle modules
regression_loss, classification_loss, accuracy_0 = self._eagle_loss(
@@ -1049,12 +1050,14 @@ def _eagle_loss(
regression_loss = torch.sum(torch.mean(loss_mask * regression_loss, 2)) / (
loss_mask.sum() + 1e-5
)
- base_predict_tok = base_model_logits.argmax(dim=-1)
- eagle_predict_tok = eagle_logits.argmax(dim=-1)
- accuracy = (
- (loss_mask[:, :, 0] * (base_predict_tok == eagle_predict_tok)).float().mean().item()
- )
- accuracy = round(accuracy, 3)
+        # Compute draft accuracy over valid (non-masked) positions only
+ base_predict_tok = base_model_logits.clone().detach().argmax(dim=-1)
+ eagle_predict_tok = eagle_logits.clone().detach().argmax(dim=-1)
+ valid = loss_mask[:, :, 0].bool()
+ correct = (base_predict_tok == eagle_predict_tok) & valid
+ denom = valid.sum().clamp_min(1).float()
+ accuracy = round(correct.sum().float().div(denom).item(), 3)
+
return regression_loss, classification_loss, accuracy
@torch.no_grad()
From ef431df30758d11682f6d5e3ca10b8e411f01891 Mon Sep 17 00:00:00 2001
From: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Date: Sat, 6 Sep 2025 01:50:13 +0000
Subject: [PATCH 4/7] polish: install instruction in example; import paths;
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
---
examples/speculative_decoding/README.md | 2 +-
modelopt/torch/export/unified_export_hf.py | 5 +----
2 files changed, 2 insertions(+), 5 deletions(-)
diff --git a/examples/speculative_decoding/README.md b/examples/speculative_decoding/README.md
index 9a3d11f3a..698a69123 100644
--- a/examples/speculative_decoding/README.md
+++ b/examples/speculative_decoding/README.md
@@ -28,7 +28,7 @@ This example focuses on training with Hugging Face. To train with Megatron‑LM,
Install Modelopt with `hf` dependencies and other requirements for this example:
```bash
-pip install nvidia-modelopt[hf]
+pip install -e ...
pip install -r requirements.txt
```
diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py
index b1b23a12c..379ade24d 100644
--- a/modelopt/torch/export/unified_export_hf.py
+++ b/modelopt/torch/export/unified_export_hf.py
@@ -27,10 +27,6 @@
import torch
import torch.nn as nn
-from modelopt.torch.export.plugins import (
- rename_and_prune_if_spec_decoding,
- set_config_if_spec_decoding,
-)
from modelopt.torch.quantization import set_quantizer_by_cfg_context
from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer
from modelopt.torch.quantization.qtensor import NVFP4QTensor
@@ -57,6 +53,7 @@
QUANTIZATION_W4A8_AWQ,
QUANTIZATION_W4A8_NVFP4_FP8,
)
+from .plugins import rename_and_prune_if_spec_decoding, set_config_if_spec_decoding
from .quant_utils import (
fuse_prequant_layernorm,
get_activation_scaling_factor,
From 04bed24a5f404154b32df80f5c9af1632dfe325f Mon Sep 17 00:00:00 2001
From: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Date: Sat, 6 Sep 2025 03:23:25 +0000
Subject: [PATCH 5/7] remove export dependency on transformers
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
---
modelopt/torch/export/plugins/__init__.py | 4 ++--
modelopt/torch/export/plugins/hf_spec_export.py | 8 +-------
2 files changed, 3 insertions(+), 9 deletions(-)
diff --git a/modelopt/torch/export/plugins/__init__.py b/modelopt/torch/export/plugins/__init__.py
index 8f90fc528..0c5766f49 100644
--- a/modelopt/torch/export/plugins/__init__.py
+++ b/modelopt/torch/export/plugins/__init__.py
@@ -19,5 +19,5 @@
with import_plugin("megatron_importer"):
from .megatron_importer import *
-with import_plugin("transformers"):
- from .hf_spec_export import *
+
+from .hf_spec_export import *
diff --git a/modelopt/torch/export/plugins/hf_spec_export.py b/modelopt/torch/export/plugins/hf_spec_export.py
index dae1f6637..0a5045f06 100644
--- a/modelopt/torch/export/plugins/hf_spec_export.py
+++ b/modelopt/torch/export/plugins/hf_spec_export.py
@@ -17,9 +17,6 @@
import torch
import torch.nn as nn
-import transformers
-
-from modelopt.torch.speculative.plugins.transformers import HFEagleModel
EAGLE_MODELOPT_TO_OFFICIAL = {
"required": {
@@ -63,7 +60,6 @@ def rename_and_prune_if_spec_decoding(model: nn.Module, post_state_dict: dict):
# if there's other opts, return as is
return post_state_dict
- assert isinstance(model, HFEagleModel)
# Check if the state dict keys match
_check_state_dict_keys_match(model.eagle_module, EAGLE_MODELOPT_TO_OFFICIAL["required"])
@@ -90,8 +86,6 @@ def set_config_if_spec_decoding(model: nn.Module, config_data: dict):
# return as is
return config_data
- assert isinstance(model, HFEagleModel)
-
# This is the config keys in official checkpoint.
template_config = {
"architectures": ["LlamaForCausalLMEagle3"],
@@ -110,7 +104,7 @@ def set_config_if_spec_decoding(model: nn.Module, config_data: dict):
"rms_norm_eps": None,
"tie_word_embeddings": False,
"torch_dtype": None,
- "transformers_version": transformers.__version__,
+ "transformers_version": None,
"use_cache": None,
"vocab_size": None,
"draft_vocab_size": None,
From 6d897ec3403a201151fc0dddeb68dc029f30fa31 Mon Sep 17 00:00:00 2001
From: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Date: Mon, 8 Sep 2025 17:22:51 +0000
Subject: [PATCH 6/7] address review comments
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
---
examples/speculative_decoding/README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/examples/speculative_decoding/README.md b/examples/speculative_decoding/README.md
index 698a69123..2e8966d29 100644
--- a/examples/speculative_decoding/README.md
+++ b/examples/speculative_decoding/README.md
@@ -86,7 +86,7 @@ We can optionally use smaller vocab size for the draft model for faster training
python calibrate_draft_vocab.py --model meta-llama/Llama-3.2-1B-Instruct --data Daring-Anteater/train.jsonl --draft_vocab_size 32000 --save_dir draft_vocab_cache
```
-This will produce a `d2t.pt` file in `save_dir`, which is the mapping from draft vocabs to full vocab that will be read by our draft model later.
+This will produce a `d2t.pt` file in `save_dir`, which is the mapping from draft token to target token. During inference, draft tokens can be mapped back to target tokens by `target_token = draft_token + d2t[draft_token]`.
### (Optional) Configuring Draft Model
From 854564eab3927099e7d3503b5ef024a9127b7a12 Mon Sep 17 00:00:00 2001
From: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Date: Mon, 8 Sep 2025 21:28:22 +0000
Subject: [PATCH 7/7] address comments: remove irrelevant code
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
---
modelopt/torch/export/unified_export_hf.py | 6 ------
1 file changed, 6 deletions(-)
diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py
index 379ade24d..f514e660d 100644
--- a/modelopt/torch/export/unified_export_hf.py
+++ b/modelopt/torch/export/unified_export_hf.py
@@ -491,12 +491,6 @@ def _export_hf_checkpoint(
return quantized_state_dict, quant_config
-def _quant_applied(hf_quant_config: dict) -> bool:
- """Check if any quantization is applied."""
- q = hf_quant_config.get("quantization", {})
- return not (q.get("quant_algo") == QUANTIZATION_NONE and not q.get("quantized_layers"))
-
-
def export_hf_checkpoint(
model: nn.Module,
dtype: torch.dtype | None = None,