diff --git a/examples/speculative_decoding/README.md b/examples/speculative_decoding/README.md
index 187bda560..2e8966d29 100644
--- a/examples/speculative_decoding/README.md
+++ b/examples/speculative_decoding/README.md
@@ -1,136 +1,145 @@
# Speculative Decoding
-Large Language Models (LLMs) have demonstrated remarkable capabilities and are increasingly applied in various domains. However, their text generation process is costly and slow. This inefficiency is attributed to the nature of auto-regressive decoding: each token generation necessitates a forward pass, requiring access to the entire parameter set of the LLM. This results in a memory-bound limitation for auto-regressive decoding. To accelerate auto-regressive decoding, speculative decoding methods use a draft model (either a smaller model or the LLM itself) to guess the next γ tokens through standard auto-regressive generation. Subsequently, the original LLM validates these guessed tokens, necessitating only a single forward pass for verification. If the draft model accurately predicts α tokens, a single forward pass of the original LLM can generate α+1 tokens.
+[View the Speculative Decoding documentation](https://nvidia.github.io/TensorRT-Model-Optimizer/guides/5_speculative_decoding.html)
-This section focuses on the end-to-end workflow of training speculative decoding modules to deploy for your model.
+Speculative decoding accelerates auto-regressive generation in large language models (LLMs) by leveraging a lightweight draft model to predict the next γ tokens. The main LLM then verifies these candidate tokens in a single forward pass. If the draft model correctly predicts α tokens, the LLM can accept and generate α+1 tokens per verification step, significantly improving generation speed.
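+
+As a concrete illustration of this draft-then-verify loop, below is a minimal sketch of one speculative decoding step with greedy acceptance. It is illustrative only, not ModelOpt's implementation; `draft_next_token_fn` and `target_logits_fn` are hypothetical stand-ins for the draft and base models.
+
+```python
+import torch
+
+
+def speculative_step(target_logits_fn, draft_next_token_fn, tokens, gamma=4):
+    """One speculative decoding step with greedy acceptance (illustrative sketch)."""
+    # 1) The draft model proposes gamma tokens auto-regressively (cheap forward passes).
+    draft_tokens, cur = [], tokens  # tokens: LongTensor of shape (1, seq_len)
+    for _ in range(gamma):
+        nxt = draft_next_token_fn(cur)  # next-token id, shape (1, 1)
+        draft_tokens.append(nxt)
+        cur = torch.cat([cur, nxt], dim=-1)
+
+    # 2) The base model verifies all drafted positions in a single forward pass.
+    logits = target_logits_fn(cur)  # shape (1, seq_len + gamma, vocab)
+    target_preds = logits[0, -(gamma + 1):].argmax(dim=-1)  # greedy choice after each prefix
+
+    # 3) Accept drafted tokens while they match the base model's greedy choice.
+    accepted = []
+    for i, tok in enumerate(draft_tokens):
+        if tok.item() == target_preds[i].item():
+            accepted.append(tok.item())
+        else:
+            break
+    # The base model always contributes one extra token: alpha accepted + 1 generated.
+    accepted.append(target_preds[len(accepted)].item())
+    return accepted
+```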
-In this example, the end-to-end workflow of speculative decoding is demonstrated for a pretrained HF text generation model.
+This folder contains an end-to-end runnable speculative decoding fine-tuning pipeline in which Llama-3.2-1B (Hugging Face) is trained on the Daring-Anteater dataset.
+
+This example focuses on training with Hugging Face. To train with Megatron-LM, see the [Megatron-LM example](https://github.com/NVIDIA/Megatron-LM/tree/main/examples/post_training/modelopt).
+
+## Contents
-| **Section** | **Description** | **Link** | **Docs** |
-| :------------: | :------------: | :------------: | :------------: |
-| Pre-Requisites | Required & optional packages to use this technique | \[[Link](#pre-requisites)\] | |
-| Getting Started | Learn how to optimize your models using PTQ to reduce precision and improve inference efficiency | \[[Link](#getting-started)\] | \[[docs](https://nvidia.github.io/TensorRT-Model-Optimizer/guides/5_speculative_decoding.html)\] |
-| Support Matrix | View the support matrix to see speculation technique support | \[[Link](#support-matrix)\] | |
-| End to End | Example scripts demonstrating how to train speculation modules using Hugging Face / NeMo / Megatron-LM models | \[[Link](#end-to-end-speculative-decoding-examples)\] | |
-| Deployment | Next steps after speculation module is trained | \[[Link](#deployment)\] | |
-| Speculation Module Checkpoints | View pre-trained speculation modules ready to deploy! | \[[Link](#speculation-module-checkpoints)\] | \[[docs](https://nvidia.github.io/TensorRT-Model-Optimizer/guides/1_quantization.html)\] |
-| Resources | Extra links to relevant resources | \[[Link](#resources)\] | |
+| **Section** | **Description** | **Jump To** |
+| :------------: | :------------: | :------------: |
+| Pre-Requisites | Required & optional dependencies | \[[Link](#pre-requisites)\] |
+| Simplified Workflow | Train, evaluate, and export an EAGLE model with a one-line command | \[[Link](#getting-started-simplified-workflow)\] |
+| Complete Workflow | Full example with configurable training pipeline | \[[Link](#complete-workflow)\] |
+| Support Matrix | Supported models for speculative decoding training | \[[Link](#support-matrix)\] |
+| Speculation Module Checkpoints | View pre-trained speculation modules ready to deploy! | \[[Link](#speculation-module-checkpoints)\] |
+| Resources | Extra links to relevant resources | \[[Link](#resources)\] |
## Pre-Requisites
-### HF
-
-Install Model Optimizer with `hf` dependencies using `pip` from [PyPI](https://pypi.org/project/nvidia-modelopt/) and install the requirements for the example:
+Install ModelOpt with `hf` dependencies and the other requirements for this example:
```bash
-pip install nvidia-modelopt[hf]
+pip install -e ...
pip install -r requirements.txt
```
-### NeMo / Megatron-LM
+We use the [Daring-Anteater](https://huggingface.co/datasets/nvidia/Daring-Anteater) dataset in this example. Download it with:
-Use the NeMo container `nvcr.io/nvidia/nemo:25.07` or later which has all the dependencies installed.
+```bash
+git clone https://huggingface.co/datasets/nvidia/Daring-Anteater
+```
-## Getting Started
+## Getting Started: Simplified Workflow
-### Prepare Data
+```bash
+bash train_eagle3_and_export.sh --base_model meta-llama/Llama-3.2-1B-Instruct --num_gpu 4
+```
-In speculative decoding fine-tuning, extra speculative decoding module, like Medusa heads or EAGLE module, are added to the base model to predict the next γ tokens. These tokens will then be validated by the original LLM. In order for these predicted tokens to be accepted by the original LLM, their prediction distributions should be similar to that of the base model. Therefore, we need to prepare fine-tuning data generated from the original LLM. Start by launching an inference server that will run the base model. Let us use TinyLlama/TinyLlama-1.1B-Chat-v1.0 as an example.
+This one-line command runs a minimal example workflow that trains and exports an EAGLE draft model with ModelOpt. Specifically, it:
-First, set up a vllm server with TinyLlama. Make sure to use a different docker container other than the one for training as installing vllm may cause version conflicts with modelopt. Note: for quantized models by ModelOpt, you need to add --quantization=modelopt flag.
+- Initializes the draft model with [default settings](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/modelopt/torch/speculative/eagle/default_config.py#L18)
+- Fine-tunes the model on the [Daring-Anteater](https://huggingface.co/datasets/nvidia/Daring-Anteater) dataset
+- Evaluates the acceptance rate on [MT-Bench](https://huggingface.co/datasets/HuggingFaceH4/mt_bench_prompts)
+- Exports a checkpoint ready for deployment
-```sh
-pip install vllm
-vllm serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 --api-key token-abc123 --port 8000 --tensor-parallel-size 1
-```
+## Complete Workflow
-Then, we adapt the fine-tuning data by calling this server. In this example, we use Daring-Anteater dataset.
+This section presents a more comprehensive example of customizing speculative decoding training with ModelOpt, including optional steps to improve training quality and efficiency.
-```sh
-git clone https://huggingface.co/datasets/nvidia/Daring-Anteater
-python3 server_generate.py --data_path Daring-Anteater/train.jsonl --output_path finetune/data.jsonl --max_token 512 --chat
-```
+### (Optional) Data Synthesis
-To add a system prompt, use the `--system_prompt` argument:
+To achieve higher acceptance rates during speculative decoding, it is beneficial to use conversations generated by the base model as training data, ensuring that the draft model’s output distribution closely aligns with that of the base model.
-```sh
-python3 server_generate.py --data_path Daring-Anteater/train.jsonl --output_path finetune/data.jsonl --max_token 512 --chat --system_prompt
-```
+To prepare such data, we launch an inference server with the base model:
-#### SLURM Prepare Data
+```bash
+pip install vllm
+vllm serve meta-llama/Llama-3.2-1B-Instruct --api-key token-abc123 --port 8000 --tensor-parallel-size 1
+```
-For basic parallelization of synthetic data generation we provide some SLURM support.
-Assuming a `$SLURM_JOB_ID` is present and nodes, n1, n2, n3, n4 are selected the following is achievable.
+Note: add the `--quantization=modelopt` flag when serving ModelOpt-quantized models.
-Example of allocating 4 nodes for 120 minutes
+Then, we generate conversations with the base model, using prompts from Daring-Anteater:
-```sh
-salloc -N4 -A -p -J -synthetic:data-gen -t 120
+```bash
+python server_generate.py --data_path Daring-Anteater/train.jsonl --output_path synthetic/train.jsonl
```
-Create shards of some given size
+To add a system prompt, use the `--system_prompt` argument.
-```sh
-python3 distributed_generate/sharding_utils.py --input_path /data/train.jsonl --output_dir /data/train/ --max_lines_per_shard 10000
-```
+For large-scale data generation, see [SLURM Prepare Data](SLURM_prepare_data.md).
-Run workers on SLURM
+### (Optional) Draft Vocabulary Compression
-```sh
-bash distributed_generate/launch.sh $SLURM_JOB_ID vllm TinyLlama/TinyLlama-1.1B-Chat-v1.0 /data/train/ /data/output /scripts/ 0 10 n1,n2,n3,n4 "\"You are a helpful assistant.\""
+We can optionally use a smaller vocabulary for the draft model for faster training and inference. For example, Llama-3.2-1B has a vocab size of 128256. In this example, we construct a draft vocab mapping of size 32k by selecting the most frequent tokens in our training set:
+
+```bash
+python calibrate_draft_vocab.py --model meta-llama/Llama-3.2-1B-Instruct --data Daring-Anteater/train.jsonl --draft_vocab_size 32000 --save_dir draft_vocab_cache
```
-`/scripts/` is the absolute path to `modelopt/examples/speculative_decoding` which contains `server_generate.py` and `distributed_generate`.
-This will launch a vllm server (sglang is also available) on each node. Each node will work through 10 shards of data (10\*max_lines_per_shard number of samples).
-In this case, the first 40 shards of data will be processed.
-To process the next 40 shards
+This produces a `d2t.pt` file under `save_dir`, containing the mapping from draft tokens to target tokens. During inference, a draft token can be mapped back to a target token via `target_token = draft_token + d2t[draft_token]`.
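+
+For illustration, the calibrated table can be applied as follows. This is a minimal sketch; the path follows the `save_dir/<model_basename>/d2t.pt` layout produced by `calibrate_draft_vocab.py` above, and the token id is an arbitrary example.
+
+```python
+import torch
+
+# Load the calibrated draft-to-target vocab offsets produced above.
+d2t = torch.load("draft_vocab_cache/Llama-3.2-1B-Instruct/d2t.pt")
+
+draft_token = torch.tensor([123])              # id in the 32k draft vocab (arbitrary example)
+target_token = draft_token + d2t[draft_token]  # corresponding id in the full base-model vocab
+print(target_token)
+```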
-```sh
-bash distributed_generate/launch.sh $SLURM_JOB_ID vllm TinyLlama/TinyLlama-1.1B-Chat-v1.0 /data/train/ /data/output /scripts/ 40 10 n1,n2,n3,n4
-```
+### (Optional) Configuring Draft Model
-To combine the shards back
+For EAGLE-1 and EAGLE-3, ModelOpt provides a [default model architecture config](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/modelopt/torch/speculative/config.py#L37). You can override the default settings by providing an additional JSON dict. In this example, we override `draft_vocab_size` in `eagle_config.json`:
-```sh
-python3 distributed_generate/sharding_utils.py --input_dir /data/output/ --output_path /data/output.jsonl --combine
+```json
+{
+ "draft_vocab_size": 32000
+}
```
-### Speculative Decoding Example Training Workflow
+### Training Draft Model with ModelOpt
-Here is the recommended end-to-end speculative decoding training workflow:
+`main.py` provides an example of converting an HF base model for speculative decoding and training it. It consists of a few simple steps:
+First, load the base model and tokenizer from Hugging Face:
```python
-import os
-import torch
-import transformers
-import modelopt.torch.opt as mto
-import modelopt.torch.speculative as mtsp
-
-# Create a base model
model = transformers.AutoModelForCausalLM.from_pretrained(
- "",
+ ""
)
+```
-if mode == "medusa":
- config = {
- "medusa_num_heads": 2,
- "medusa_num_layers": 1,
- }
-elif mode == "eagle":
- config = {
- "eagle_num_layers": 1,
- "use_input_layernorm_in_first_layer": True,
- "use_last_layernorm": False
+Then, load the default EAGLE config and apply any necessary overrides:
+
+```python
+# Load default config
+config = {
+ "eagle1": EAGLE1_DEFAULT_CFG,
+ "eagle3": EAGLE3_DEFAULT_CFG,
+}[training_args.mode]["config"]
+
+# overwrite config with custom config
+config["eagle_architecture_config"].update({"": ""})
+
+# Mandatory: hidden size, vocab size and max position embeddings must match base model
+config["eagle_architecture_config"].update(
+ {
+ "hidden_size": model.config.hidden_size,
+ "vocab_size": model.config.vocab_size,
+ "max_position_embeddings": model.config.max_position_embeddings,
}
-mtsp.convert(model, [(mode, config)])
+)
+```
+
+Then, we convert the model into a speculative decoding model:
+
+```python
+mtsp.convert(model, [("eagle", config)])
+```
-tokenizer = transformers.AutoTokenizer.from_pretrained(ckpt_path)
-tokenizer.pad_token_id = tokenizer.eos_token_id
+This modifies the model in place with the EAGLE training forward pass, keeping it compatible with the HF `Trainer`:
+```python
# Create a trainer
trainer = transformers.Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
trainer._move_model_to_device(model, trainer.args.device)
@@ -143,86 +152,92 @@ trainer.save_state()
trainer.save_model("")
```
-## Support Matrix
+Details such as tokenizer initialization are omitted here for simplicity. A complete training example is provided in `main.py`, along with `launch_train.sh`, a bash script that launches training with Hugging Face Accelerate:
-### Supported Models/Techniques
+```bash
+./launch_train.sh --model $BASE_MODEL \
+ --output_dir $OUTPUT_DIR \
+ --data $DATA \
+ --num_gpu $NUM_GPU \
+ --num_epochs 10 \
+ --eagle_config eagle_config.json # Optional: override default EAGLE configs
+```
-#### NeMo/Megatron-LM
+The saved ModelOpt checkpoint is similar in architecture to HF models and can be further optimized through **ModelOpt**, e.g., with PTQ or QAT.
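+
+For instance, a post-training quantization pass over the trained model can be sketched as follows. This assumes a small `calib_loader` that you provide; see the [LLM PTQ example](../llm_ptq/README.md) for the recommended end-to-end workflow.
+
+```python
+import modelopt.torch.quantization as mtq
+
+
+# Calibration loop: run a few batches through the model (calib_loader is user-provided).
+def forward_loop(model):
+    for batch in calib_loader:
+        model(**batch)
+
+
+# Quantize the trained model in place; the FP8 default config is shown as one example.
+model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)
+```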
-| Model | Medusa | EAGLE1/2 | EAGLE3 |
-| :---: | :---: | :---: | :---: |
-| LLAMA 2 | ✅ | ✅ | ✅ |
-| LLAMA 3, 3.1 | ✅ | ✅ | ✅ |
-| Mistral | ✅ | ✅ | ✅ |
-| Phi 3 | ✅ | ✅ | ✅ |
-| QWen 1.5,2,2.5 | ✅ | ✅ | ✅ |
+### Model Validation
-#### Hugging Face
+After training the draft model, we can evaluate the saved ModelOpt checkpoint on MT-Bench:
-| Model | Medusa | EAGLE1/2 | EAGLE3 |
-| :---: | :---: | :---: | :---: |
-| LLAMA 2 | ✅ | ✅ | ✅ |
-| LLAMA 3, 3.1 | ✅ | ✅ | ✅ |
-| Mistral | ✅ | ✅ | ✅ |
-| Phi 3 | ✅ | ✅ | ✅ |
-| QWen 1.5,2,2.5 | ✅ | ✅ | ✅ |
+```bash
+python ar_validate.py --model_path $OUTPUT_DIR
+```
-### End-to-end Speculative Decoding Examples
+Alternatively, we can export the checkpoint and run evaluation in a serving framework; see the sections below.
-### MLM Example
+### Export
+
+Export the trained ModelOpt checkpoint to a deployment-ready format:
-
+```bash
+python export_hf_checkpoint.py --model_path $OUTPUT_DIR --export_path $EXPORT_PATH
+```
-
+### Deployment
-### HuggingFace
+The exported checkpoint can be deployed on TRT-LLM or SGLang.
-This folder contains end-to-end runnable speculative decoding fine-tuning pipeline where TinyLlama from huggingface is trained on Daring-Anteater dataset.
+#### TRT-LLM
-First, download the data:
+To serve the checkpoint with TRT-LLM, run `trtllm-serve`:
-```sh
-git clone https://huggingface.co/datasets/nvidia/Daring-Anteater
+```bash
+trtllm-serve --host 0.0.0.0 --port 8000 --backend pytorch --max_batch_size 32 --max_num_tokens 8192 --max_seq_len 8192 --extra_llm_api_options extra-llm-api-config.yml
```
-Then, prepare the synthesized data from the base model. Please refer to the **Prepare Data** section.
+where `extra-llm-api-config.yml` contains:
+
+```yaml
+enable_attention_dp: false
+disable_overlap_scheduler: true
+enable_autotuner: false
-Next, we fine-tune the speculative decoding models with the base model frozen. Here is the command for Medusa and EAGLE:
+cuda_graph_config:
+ max_batch_size: 1
-```sh
-./launch.sh --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
- --data finetune/data.jsonl \
- --mode medusa \
- --num_epochs 1 --lr 1e-5 --save_steps 1000 \
- --output_dir medusa-tinyllama \
- --fsdp_transformer_layer_cls_to_wrap LlamaDecoderLayer \
- --num_gpu 1 \
- --medusa_num_heads 2 --medusa_num_layers 1
+speculative_config:
+ decoding_type: Eagle
+ max_draft_len: 3
+ speculative_model_dir:
-./launch.sh --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
- --data finetune/data.jsonl \
- --mode eagle \
- --num_epochs 1 --lr 1e-5 --save_steps 1000 \
- --output_dir eagle-tinyllama \
- --fsdp_transformer_layer_cls_to_wrap LlamaDecoderLayer \
- --num_gpu 1 \
- --eagle_num_layers 1
+kv_cache_config:
+ enable_block_reuse: false
```
-This will generate fine-tuned checkpoints in `output_dir` specified above.
+Please refer to [TRT-LLM Doc: Speculative Decoding](https://nvidia.github.io/TensorRT-LLM/examples/llm_speculative_decoding.html) for detailed usage.
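+
+Once the server is running, you can sanity-check it through its OpenAI-compatible endpoint, for example (the served model name below is a placeholder):
+
+```python
+from openai import OpenAI
+
+# Assumes the trtllm-serve instance above is listening on localhost:8000.
+client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")
+response = client.chat.completions.create(
+    model="<served-model-name>",  # placeholder: use the model name reported by the server
+    messages=[{"role": "user", "content": "Hello!"}],
+)
+print(response.choices[0].message.content)
+```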
-Alternatively, you can refer to this [notebook](example.ipynb).
+#### SGLang
-### Deployment
+Please refer to [SGLang Doc: Speculative Decoding](https://docs.sglang.ai/advanced_features/speculative_decoding.html#EAGLE-3-Decoding) for detailed usage.
+
+#### Deploying Quantized Models
-The final model after end-to-end speculative decoding fine-tuning is similar in architecture to HF models. It can be further optimized through **ModelOpt**, e.g., PTQ and QAT. It can be deployed to TensorRT-LLM (TRTLLM) or to TensorRT just like a regular **ModelOpt** model. See more details on deployment of quantized model to TRTLLM [here](../llm_ptq/README.md).
+See the [LLM PTQ example](../llm_ptq/README.md) for more details on deploying quantized models to TRT-LLM.
+
+## Support Matrix
+
+| Model | Medusa | EAGLE1/2 | EAGLE3 |
+| :---: | :---: | :---: | :---: |
+| LLAMA 2 | ✅ | ✅ | ✅ |
+| LLAMA 3, 3.1 | ✅ | ✅ | ✅ |
+| Mistral | ✅ | ✅ | ✅ |
+| Phi 3 | ✅ | ✅ | ✅ |
+| QWen 1.5,2,2.5 | ✅ | ✅ | ✅ |
## Speculation Module Checkpoints
-Ready-to-deploy speculation module checkpoints \[[🤗 Hugging Face - Nvidia TensorRT Model Optimizer Collection](https://huggingface.co/collections/nvidia/model-optimizer-66aa84f7966b3150262481a4)\]
-Deployable on [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [vLLM](https://github.com/vllm-project/vllm) and [SGLang](https://github.com/sgl-project/sglang)!\
+Ready-to-deploy speculation module checkpoints \[[🤗 Hugging Face - NVIDIA TensorRT Model Optimizer Collection](https://huggingface.co/collections/nvidia/model-optimizer-66aa84f7966b3150262481a4)\]
+Deployable on [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) and [SGLang](https://github.com/sgl-project/sglang)!\
More models coming soon!
## Resources
diff --git a/examples/speculative_decoding/SLURM_prepare_data.md b/examples/speculative_decoding/SLURM_prepare_data.md
new file mode 100644
index 000000000..bc3c0abb6
--- /dev/null
+++ b/examples/speculative_decoding/SLURM_prepare_data.md
@@ -0,0 +1,37 @@
+# SLURM Prepare Data
+
+For basic parallelization of synthetic data generation, we provide some SLURM support.
+Assuming a `$SLURM_JOB_ID` is present and nodes n1, n2, n3, n4 are selected, the following workflow is possible.
+
+Example of allocating 4 nodes for 120 minutes:
+
+```sh
+salloc -N4 -A -p -J -synthetic:data-gen -t 120
+```
+
+Create shards of a given size:
+
+```sh
+python3 distributed_generate/sharding_utils.py --input_path /data/train.jsonl --output_dir /data/train/ --max_lines_per_shard 10000
+```
+
+Run workers on SLURM:
+
+```sh
+bash distributed_generate/launch.sh $SLURM_JOB_ID vllm TinyLlama/TinyLlama-1.1B-Chat-v1.0 /data/train/ /data/output /scripts/ 0 10 n1,n2,n3,n4 "\"You are a helpful assistant.\""
+```
+
+`/scripts/` is the absolute path to `modelopt/examples/speculative_decoding`, which contains `server_generate.py` and `distributed_generate`.
+This launches a vllm server (sglang is also available) on each node. Each node works through 10 shards of data (10\*max_lines_per_shard samples).
+In this case, the first 40 shards of data are processed.
+To process the next 40 shards:
+
+```sh
+bash distributed_generate/launch.sh $SLURM_JOB_ID vllm TinyLlama/TinyLlama-1.1B-Chat-v1.0 /data/train/ /data/output /scripts/ 40 10 n1,n2,n3,n4
+```
+
+To combine the shards back:
+
+```sh
+python3 distributed_generate/sharding_utils.py --input_dir /data/output/ --output_path /data/output.jsonl --combine
+```
diff --git a/examples/speculative_decoding/ar_validate.py b/examples/speculative_decoding/ar_validate.py
index bfbcb2239..38b886693 100644
--- a/examples/speculative_decoding/ar_validate.py
+++ b/examples/speculative_decoding/ar_validate.py
@@ -26,7 +26,7 @@
mto.enable_huggingface_checkpointing()
-def validate_ar(model, tokenizer, ds, steps=3, osl=20, num_samples=20, device=None):
+def validate_ar(model, tokenizer, ds, steps=3, osl=20, num_samples=80, device=None):
validator = HFARValidation(model, tokenizer)
num_samples = min(num_samples, len(ds))
ars = []
@@ -54,12 +54,12 @@ def validate_ar(model, tokenizer, ds, steps=3, osl=20, num_samples=20, device=No
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, required=True, help="Path to model directory")
- parser.add_argument("--steps", type=int, default=1, help="Steps for AR validation")
+ parser.add_argument("--steps", type=int, default=3, help="Steps for AR validation")
parser.add_argument(
- "--osl", type=int, default=100, help="Output sequence length for AR validation"
+ "--osl", type=int, default=32, help="Output sequence length for AR validation"
)
parser.add_argument(
- "--num_samples", type=int, default=20, help="Number of MT-Bench samples to use"
+ "--num_samples", type=int, default=80, help="Number of MT-Bench samples to use"
)
parser.add_argument(
"--ar_lower_bound",
diff --git a/examples/speculative_decoding/calibrate_draft_vocab.py b/examples/speculative_decoding/calibrate_draft_vocab.py
index 1211d4eb6..37a798cf8 100644
--- a/examples/speculative_decoding/calibrate_draft_vocab.py
+++ b/examples/speculative_decoding/calibrate_draft_vocab.py
@@ -28,11 +28,10 @@ def main():
parser.add_argument("--model", type=str, required=True, help="Model name or path for tokenizer")
parser.add_argument("--data", type=str, required=True, help="Path to training data (jsonl)")
parser.add_argument(
- "--eagle_config",
- type=str,
+ "--draft_vocab_size",
+ type=int,
required=True,
- default="eagle_config.json",
- help="Path to eagle_config.json",
+ help="Draft vocab size",
)
parser.add_argument(
"--calibrate_size",
@@ -45,12 +44,6 @@ def main():
)
args = parser.parse_args()
- with open(args.eagle_config) as f:
- eagle_config = json.load(f)
- if "draft_vocab_size" not in eagle_config:
- print("No draft vocab size specified in eagle_config.json, no need to calibrate for d2t.")
- return
-
print("Calibrating vocab...")
tokenizer = AutoTokenizer.from_pretrained(args.model)
with open(args.data) as f:
@@ -59,7 +52,7 @@ def main():
conversations = conversations[: args.calibrate_size]
conversations = [item for sublist in conversations for item in sublist]
- d2t = calibrate_frequent_vocab(tokenizer, conversations, eagle_config["draft_vocab_size"])
+ d2t = calibrate_frequent_vocab(tokenizer, conversations, args.draft_vocab_size)
model_name = os.path.basename(os.path.normpath(args.model))
vocab_path = os.path.join(args.save_dir, model_name, "d2t.pt")
os.makedirs(os.path.dirname(vocab_path), exist_ok=True)
diff --git a/examples/speculative_decoding/eagle_config.json b/examples/speculative_decoding/eagle_config.json
index 55ff948ed..b4b218fdf 100644
--- a/examples/speculative_decoding/eagle_config.json
+++ b/examples/speculative_decoding/eagle_config.json
@@ -1,3 +1,10 @@
{
- "draft_vocab_size": 32000
+ "rope_scaling": {
+ "factor": 32.0,
+ "low_freq_factor": 1.0,
+ "high_freq_factor": 4.0,
+ "original_max_position_embeddings": 8192,
+ "rope_type": "llama3"
+ },
+ "initializer_range": 0.02
}
diff --git a/examples/speculative_decoding/export_hf_checkpoint.py b/examples/speculative_decoding/export_hf_checkpoint.py
new file mode 100644
index 000000000..dfc293ee9
--- /dev/null
+++ b/examples/speculative_decoding/export_hf_checkpoint.py
@@ -0,0 +1,48 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Export a HF checkpoint (with ModelOpt state) for deployment."""
+
+import argparse
+
+import torch
+from transformers import AutoModelForCausalLM
+
+import modelopt.torch.opt as mto
+from modelopt.torch.export import export_hf_checkpoint
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description="Export a HF checkpoint (with ModelOpt state) for deployment."
+ )
+ parser.add_argument("--model_path", type=str, default="Path of the trained checkpoint.")
+ parser.add_argument(
+ "--export_path", type=str, default="Destination directory for exported files."
+ )
+ return parser.parse_args()
+
+
+mto.enable_huggingface_checkpointing()
+
+args = parse_args()
+model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype="auto")
+model.eval()
+with torch.inference_mode():
+ export_hf_checkpoint(
+        model,  # The ModelOpt-trained model.
+ export_dir=args.export_path, # The directory where the exported files will be stored.
+ )
+print(f"Exported checkpoint to {args.export_path}")
diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh
new file mode 100755
index 000000000..cf54f9446
--- /dev/null
+++ b/examples/speculative_decoding/launch_train.sh
@@ -0,0 +1,157 @@
+#!/bin/bash
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -eo pipefail
+
+while [ $# -gt 0 ]; do
+ case "$1" in
+ --training_seq_len*)
+ if [[ "$1" != *=* ]]; then shift; fi
+ TRAINING_SEQ_LEN="${1#*=}"
+ ;;
+ --model*)
+ if [[ "$1" != *=* ]]; then shift; fi
+ MODEL="${1#*=}"
+ ;;
+ --data*)
+ if [[ "$1" != *=* ]]; then shift; fi
+ DATA="${1#*=}"
+ ;;
+ --mode*)
+ if [[ "$1" != *=* ]]; then shift; fi
+ MODE="${1#*=}"
+ ;;
+ --output_dir*)
+ if [[ "$1" != *=* ]]; then shift; fi
+ OUTPUT_DIR="${1#*=}"
+ ;;
+ --num_epochs*)
+ if [[ "$1" != *=* ]]; then shift; fi
+ NUM_EPOCHS="${1#*=}"
+ ;;
+ --save_steps*)
+ if [[ "$1" != *=* ]]; then shift; fi
+ SAVE_STEPS="${1#*=}"
+ ;;
+ --lr*)
+ if [[ "$1" != *=* ]]; then shift; fi
+ LR="${1#*=}"
+ ;;
+ --train_bs*)
+ if [[ "$1" != *=* ]]; then shift; fi
+ TRAIN_BS="${1#*=}"
+ ;;
+ --medusa_num_heads*)
+ if [[ "$1" != *=* ]]; then shift; fi
+ MEDUSA_NUM_HEADS="${1#*=}"
+ ;;
+ --medusa_num_layers*)
+ if [[ "$1" != *=* ]]; then shift; fi
+ MEDUSA_NUM_LAYERS="${1#*=}"
+ ;;
+ --eagle_config*)
+ if [[ "$1" != *=* ]]; then shift; fi
+ EAGLE_CONFIG="${1#*=}"
+ ;;
+ --fsdp_transformer_layer_cls_to_wrap*)
+ if [[ "$1" != *=* ]]; then shift; fi
+ FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP="${1#*=}"
+ ;;
+ --num_gpu*)
+ if [[ "$1" != *=* ]]; then shift; fi
+ NUM_GPU="${1#*=}"
+ ;;
+ *)
+ >&2 printf "Error: Invalid argument ${1#*=}\n"
+ exit 1
+ ;;
+ esac
+ shift
+done
+
+set -x
+
+# Get the default value for save_steps based on the available number of GPUs
+GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
+# Calculate save_steps
+DEFAULT_SAVE_STEPS=$((8192 / GPU_COUNT))
+
+MODEL=${MODEL:-"TinyLlama/TinyLlama-1.1B-Chat-v1.0"}
+MODE=${MODE:-"eagle3"}
+# Set default OUTPUT_DIR to ckpts/{modelname}, where {modelname} is the last part of the model path
+MODEL_BASENAME=$(basename "$MODEL")
+OUTPUT_DIR=${OUTPUT_DIR:-"ckpts/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M)"}
+NUM_EPOCHS=${NUM_EPOCHS:-1}
+SAVE_STEPS=${SAVE_STEPS:-$DEFAULT_SAVE_STEPS}
+LR=${LR:-"1e-4"}
+TRAIN_BS=${TRAIN_BS:-4}
+MEDUSA_NUM_HEADS=${MEDUSA_NUM_HEADS:-1}
+MEDUSA_NUM_LAYERS=${MEDUSA_NUM_LAYERS:-1}
+REDRAFTER_TOKENS=${REDRAFTER_TOKENS:-1}
+REDRAFTER_NUM_LAYERS=${REDRAFTER_NUM_LAYERS:-1}
+FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP:-"LlamaDecoderLayer"}
+NUM_GPU=${NUM_GPU:-1}
+TRAINING_SEQ_LEN=${TRAINING_SEQ_LEN:-512}
+
+if [[ "$MODE" == "medusa" ]]; then
+ SPECULATIVE_ARGS="--medusa_num_heads $MEDUSA_NUM_HEADS --medusa_num_layers $MEDUSA_NUM_LAYERS"
+elif [[ "$MODE" == "eagle1" || "$MODE" == "eagle3" ]]; then
+ if [[ -n "$EAGLE_CONFIG" ]]; then
+ SPECULATIVE_ARGS="--eagle_config $EAGLE_CONFIG"
+ else
+ SPECULATIVE_ARGS=""
+ fi
+else
+ echo "Only medusa, eagle1, eagle3 supported for now!"
+ exit 1
+fi
+
+if [[ "$NUM_GPU" == 1 ]]; then
+ MULTI_GPU=""
+else
+ MULTI_GPU="--multi_gpu"
+fi
+
+# Disable tokenizers parallelism to avoid warning
+export TOKENIZERS_PARALLELISM=False
+CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \
+ --mode $MODE \
+ --model_name_or_path $MODEL \
+ --training_seq_len $TRAINING_SEQ_LEN \
+ --dataloader_drop_last True \
+ --bf16 True \
+ --output_dir $OUTPUT_DIR \
+ --num_train_epochs $NUM_EPOCHS \
+ --per_device_train_batch_size $TRAIN_BS \
+ --per_device_eval_batch_size $TRAIN_BS \
+ --gradient_accumulation_steps 1 \
+ --do_eval False \
+ --eval_accumulation_steps 1 \
+ --save_strategy steps \
+ --save_steps $SAVE_STEPS \
+ --learning_rate $LR \
+ --weight_decay 0.0 \
+ --warmup_steps 100 \
+ --lr_scheduler_type linear \
+ --logging_steps 100 \
+ --tf32 True \
+ --data_path $DATA \
+ $SPECULATIVE_ARGS
+"
+
+start_time=$(date +%s)
+sh -c "$CMD"
+echo "Total time taken: $(( $(date +%s) - $start_time )) seconds"
diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py
index 28e388b4d..aae4177c6 100644
--- a/examples/speculative_decoding/main.py
+++ b/examples/speculative_decoding/main.py
@@ -47,6 +47,13 @@
import modelopt.torch.speculative as mtsp
from modelopt.torch.utils import print_rank_0
+try:
+ import wandb
+
+ wandb.init()
+except ImportError:
+ wandb = None
+
torch.manual_seed(0)
mto.enable_huggingface_checkpointing()
@@ -170,6 +177,8 @@ def train():
{
"hidden_size": model.config.hidden_size,
"vocab_size": model.config.vocab_size,
+ # we also overwrite max_pos_embedding for deployment compatibility
+ "max_position_embeddings": model.config.max_position_embeddings,
"draft_vocab_size": custom_config["draft_vocab_size"]
if eagle_args.eagle_config and "draft_vocab_size" in custom_config
else model.config.vocab_size,
@@ -213,6 +222,8 @@ def on_step_end(self, args, state, control, **kwargs):
device=kwargs["model"].device,
)
print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}")
+ if wandb:
+ wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
return control
trainer = Trainer(
diff --git a/examples/speculative_decoding/server_generate.py b/examples/speculative_decoding/server_generate.py
index 4541ecf42..0fb71a0a0 100644
--- a/examples/speculative_decoding/server_generate.py
+++ b/examples/speculative_decoding/server_generate.py
@@ -46,7 +46,7 @@
parser.add_argument(
"--max_tokens", type=int, default=2048, help="Maximum number of tokens to generate"
)
-parser.add_argument("--chat", action="store_true", help="Use chat mode")
+parser.add_argument("--chat", default=True, type=bool, help="Use chat mode")
parser.add_argument("--model", type=str, default="model", help="Model name")
parser.add_argument("--url", type=str, default="http://localhost:8000/v1", help="URL of the API")
parser.add_argument("--api_key", type=str, default="token-abc123", help="API key (if any)")
diff --git a/examples/speculative_decoding/train_eagle3_and_export.sh b/examples/speculative_decoding/train_eagle3_and_export.sh
new file mode 100644
index 000000000..042a27652
--- /dev/null
+++ b/examples/speculative_decoding/train_eagle3_and_export.sh
@@ -0,0 +1,75 @@
+#!/bin/bash
+
+# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -eo pipefail
+
+# Set default values for BASE_MODEL, NUM_GPU, and DATA
+BASE_MODEL=meta-llama/Llama-3.2-1B-Instruct
+NUM_GPU=1
+DATA=Daring-Anteater/train.jsonl
+
+# Parse input arguments --base_model, --num_gpu, and --data
+while [[ $# -gt 0 ]]; do
+ key="$1"
+ case $key in
+ --base_model)
+ BASE_MODEL="$2"
+ shift; shift
+ ;;
+ --num_gpu)
+ NUM_GPU="$2"
+ shift; shift
+ ;;
+ --data)
+ DATA="$2"
+ shift; shift
+ ;;
+ *)
+ echo "Unknown argument: $1"
+ exit 1
+ ;;
+ esac
+done
+
+
+if [[ "$NUM_GPU" == 1 ]]; then
+ export CUDA_VISIBLE_DEVICES=0
+else
+ # Export as 0,1,...,N-1 for NUM_GPU GPUs
+ devs="$(seq -s, 0 $((NUM_GPU-1)))"
+ export CUDA_VISIBLE_DEVICES="$devs"
+fi
+
+MODEL_BASENAME=$(basename "$BASE_MODEL")
+
+echo "==== [1/3] Training draft model ===="
+OUTPUT_DIR=ckpts/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M)
+mkdir -p "$(dirname "$OUTPUT_DIR")"
+./launch_train.sh --model $BASE_MODEL \
+ --output_dir $OUTPUT_DIR \
+ --data $DATA \
+ --num_gpu $NUM_GPU \
+ --num_epochs 2 \
+ --eagle_config eagle_config.json
+
+echo "==== [2/3] Evaluating ModelOpt checkpoint on MT-Bench ===="
+python ar_validate.py --model_path $OUTPUT_DIR
+
+echo "==== [3/3] Exporting checkpoint to deployment format ===="
+EXPORT_PATH=export/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M)
+mkdir -p "$(dirname "$EXPORT_PATH")"
+python export_hf_checkpoint.py --model_path $OUTPUT_DIR --export_path $EXPORT_PATH
diff --git a/modelopt/torch/export/plugins/__init__.py b/modelopt/torch/export/plugins/__init__.py
index a7bbc3fb3..0c5766f49 100644
--- a/modelopt/torch/export/plugins/__init__.py
+++ b/modelopt/torch/export/plugins/__init__.py
@@ -19,3 +19,5 @@
with import_plugin("megatron_importer"):
from .megatron_importer import *
+
+from .hf_spec_export import *
diff --git a/modelopt/torch/export/plugins/hf_spec_export.py b/modelopt/torch/export/plugins/hf_spec_export.py
new file mode 100644
index 000000000..0a5045f06
--- /dev/null
+++ b/modelopt/torch/export/plugins/hf_spec_export.py
@@ -0,0 +1,149 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Modify state_dict and config for exporting speculative decoding in official format."""
+
+import torch
+import torch.nn as nn
+
+EAGLE_MODELOPT_TO_OFFICIAL = {
+ "required": {
+ "layers.0.self_attn.q_proj.weight": "midlayer.self_attn.q_proj.weight",
+ "layers.0.self_attn.k_proj.weight": "midlayer.self_attn.k_proj.weight",
+ "layers.0.self_attn.v_proj.weight": "midlayer.self_attn.v_proj.weight",
+ "layers.0.self_attn.o_proj.weight": "midlayer.self_attn.o_proj.weight",
+ "layers.0.mlp.gate_proj.weight": "midlayer.mlp.gate_proj.weight",
+ "layers.0.mlp.up_proj.weight": "midlayer.mlp.up_proj.weight",
+ "layers.0.mlp.down_proj.weight": "midlayer.mlp.down_proj.weight",
+ "hidden_norm.weight": "midlayer.hidden_norm.weight",
+ "input_embeds_norm.weight": "midlayer.input_layernorm.weight",
+ "layers.0.post_attention_layernorm.weight": "midlayer.post_attention_layernorm.weight",
+ "norm.weight": "norm.weight",
+ "fc.weight": "fc.weight",
+ },
+ "optional": {
+ "d2t": "d2t",
+ "eagle_lm_head.weight": "lm_head.weight",
+ },
+}
+
+
+def _check_state_dict_keys_match(draft_model: nn.Module, required_items: dict):
+ """Check if the state dict keys match."""
+ draft_keys = set(draft_model.state_dict().keys())
+ for required_key in required_items:
+ if required_key not in draft_keys:
+ raise ValueError(f"State dict keys mismatch!\nMissing in draft model: {required_key}")
+
+
+def rename_and_prune_if_spec_decoding(model: nn.Module, post_state_dict: dict):
+ """Only return the state dict of the draft model in official format and ignore the base model."""
+ # check the model has only speculative decoding
+ opt_modes = getattr(model, "_modelopt_state", None)
+ if (
+ not isinstance(opt_modes, (list, tuple))
+ or len(opt_modes) != 1
+ or opt_modes[0][0] != "eagle"
+ ):
+ # if there's other opts, return as is
+ return post_state_dict
+
+ # Check if the state dict keys match
+ _check_state_dict_keys_match(model.eagle_module, EAGLE_MODELOPT_TO_OFFICIAL["required"])
+
+ # Convert key names and save the state dict
+ eagle_state = model.eagle_module.state_dict()
+ export_state_dict = {}
+ for ours_key, export_key in {
+ **EAGLE_MODELOPT_TO_OFFICIAL["required"],
+ **EAGLE_MODELOPT_TO_OFFICIAL["optional"],
+ }.items():
+ if ours_key in eagle_state:
+ export_state_dict[export_key] = eagle_state[ours_key]
+
+ # TODO: (hg) this is a temp fix. Find cleaner way to do this.
+ if "eagle_lm_head.weight" not in eagle_state:
+ export_state_dict["lm_head.weight"] = model.state_dict()["lm_head.weight"]
+
+ return export_state_dict
+
+
+def set_config_if_spec_decoding(model: nn.Module, config_data: dict):
+ """Return the config of draft model in official format."""
+ if len(model._modelopt_state) != 1 or model._modelopt_state[0][0] != "eagle":
+ # return as is
+ return config_data
+
+ # This is the config keys in official checkpoint.
+ template_config = {
+ "architectures": ["LlamaForCausalLMEagle3"],
+ "bos_token_id": None,
+ "eos_token_id": None,
+ "hidden_act": None,
+ "hidden_size": None,
+ "initializer_range": None,
+ "intermediate_size": None,
+ "max_position_embeddings": None,
+ "model_type": "llama",
+ "num_attention_heads": None,
+ "num_key_value_heads": None,
+ "num_hidden_layers": None,
+ "pad_token_id": None,
+ "rms_norm_eps": None,
+ "tie_word_embeddings": False,
+ "torch_dtype": None,
+ "transformers_version": None,
+ "use_cache": None,
+ "vocab_size": None,
+ "draft_vocab_size": None,
+ "rope_scaling": None,
+ "attention_bias": None,
+ "attention_dropout": None,
+ "head_dim": None,
+ "mlp_bias": None,
+ "pretraining_tp": None,
+ "rope_theta": None,
+ "eagle_config": {
+ "eagle_aux_hidden_state_layer_ids": None,
+ "use_aux_hidden_state": None,
+ "use_input_layernorm_in_first_layer": None,
+ "use_last_layernorm": None,
+ "use_mtp_layernorm": None,
+ },
+ }
+
+ def _get_config_from_eagle_config_or_base_config(key: str, model: nn.Module):
+ if getattr(model.eagle_config, key, None) is not None:
+ return getattr(model.eagle_config, key)
+ elif getattr(model.config, key, None) is not None:
+ return getattr(model.config, key)
+ else:
+ return None
+
+ for key in template_config:
+ value = template_config[key]
+ if isinstance(value, dict):
+ # for eagle config, we find it in model.eagle_config
+ for sub_key in value:
+ value[sub_key] = _get_config_from_eagle_config_or_base_config(sub_key, model)
+ elif value is None:
+            # First, we try to load from the eagle config.
+ new_value = _get_config_from_eagle_config_or_base_config(key, model)
+ # If the value is a torch.dtype, we convert to string for serialization.
+ if isinstance(new_value, torch.dtype):
+ new_value = str(new_value).replace("torch.", "")
+ template_config[key] = new_value
+
+ return template_config
diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py
old mode 100755
new mode 100644
index b18ae2619..f514e660d
--- a/modelopt/torch/export/unified_export_hf.py
+++ b/modelopt/torch/export/unified_export_hf.py
@@ -53,6 +53,7 @@
QUANTIZATION_W4A8_AWQ,
QUANTIZATION_W4A8_NVFP4_FP8,
)
+from .plugins import rename_and_prune_if_spec_decoding, set_config_if_spec_decoding
from .quant_utils import (
fuse_prequant_layernorm,
get_activation_scaling_factor,
@@ -509,12 +510,15 @@ def export_hf_checkpoint(
try:
post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype)
+ # NOTE: (hg) Should we save hf_quant_config when there's no quantization applied?
# Save hf_quant_config.json for backward compatibility
with open(f"{export_dir}/hf_quant_config.json", "w") as file:
json.dump(hf_quant_config, file, indent=4)
hf_quant_config = convert_hf_quant_config_format(hf_quant_config)
+ post_state_dict = rename_and_prune_if_spec_decoding(model, post_state_dict)
+
# Save model
model.save_pretrained(
export_dir, state_dict=post_state_dict, save_modelopt_state=save_modelopt_state
@@ -528,6 +532,8 @@ def export_hf_checkpoint(
config_data["quantization_config"] = hf_quant_config
+ config_data = set_config_if_spec_decoding(model, config_data)
+
with open(original_config, "w") as file:
json.dump(config_data, file, indent=4)
diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py
index ff425d63f..e1da326d1 100644
--- a/modelopt/torch/speculative/plugins/transformers.py
+++ b/modelopt/torch/speculative/plugins/transformers.py
@@ -182,17 +182,6 @@ def __init__(self, config, decoder_layer_cls, bias=False):
"""Init function for EagleModule."""
super().__init__()
self.config = config
-
- # NOTE:This is a temporary fix to support Qwen and Mixtral in current release.
- # This is refactored in following MR.
- config_overwrite = {
- "mlp_bias": False,
- "attention_bias": False,
- "head_dim": self.config.hidden_size // self.config.num_attention_heads,
- }
- for key, value in config_overwrite.items():
- setattr(self.config, key, value)
-
self.layers = nn.ModuleList(
[decoder_layer_cls(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
@@ -696,8 +685,10 @@ def _base_model_forward(
past_key_values,
freeze_base_model,
labels,
- kwargs,
+ **kwargs,
):
+        # TODO: This function still uses eagle_module. Ideally we should remove it,
+ # so we can del model.eagle_module on the base model ranks to save memory.
with torch.no_grad() if freeze_base_model else contextlib.nullcontext():
outputs = super().forward(
input_ids=input_ids,
@@ -708,8 +699,6 @@ def _base_model_forward(
**kwargs,
)
past_key_values = outputs.past_key_values
- if not isinstance(past_key_values, Cache):
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
base_model_hidden_states = outputs.hidden_states[-1]
base_model_logits = outputs.logits
@@ -723,14 +712,18 @@ def _base_model_forward(
# Map the base model logits to the draft vocab
if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size and self.training:
- reverse_mapping = (
- torch.arange(len(self.eagle_module.d2t)).to(self.eagle_module.d2t.device)
- + self.eagle_module.d2t
- )
- base_model_logits = base_model_logits[:, :, reverse_mapping]
+ assert hasattr(self.eagle_module, "d2t"), "d2t buffer not initialized"
+ base_model_logits = self._map_logits_to_draft_vocab(base_model_logits)
return base_model_hidden_states, base_model_logits, base_model_loss, past_key_values
+ def _map_logits_to_draft_vocab(self, full_logits):
+ reverse_mapping = (
+ torch.arange(len(self.eagle_module.d2t)).to(self.eagle_module.d2t.device)
+ + self.eagle_module.d2t
+ )
+ return full_logits[:, :, reverse_mapping]
+
def _eagle_forward(
self,
eagle_input_hidden_states,
@@ -795,17 +788,34 @@ def forward(
loss_mask = torch.ones_like(input_ids, dtype=torch.bool, device=input_ids.device)
# ====First, we run base model forward====
- base_model_hidden_states, base_model_logits, base_model_loss, past_key_values = (
- self._base_model_forward(
- input_ids,
- attention_mask,
- position_ids,
- past_key_values,
- self.eagle_freeze_base_model,
- labels,
- kwargs,
+ if "base_model_outputs" in kwargs:
+ # Parse base model outputs forwarded from teacher
+ base_outputs = kwargs["base_model_outputs"]
+ base_model_hidden_states = base_outputs["base_model_hidden_states"]
+ if "base_model_logits" in base_outputs:
+ base_model_logits = base_outputs["base_model_logits"]
+ else:
+ base_model_logits = self.lm_head(base_model_hidden_states)
+ if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
+ base_model_logits = self._map_logits_to_draft_vocab(base_model_logits)
+ base_model_loss = None
+ past_key_values = None
+
+ else:
+ base_model_hidden_states, base_model_logits, base_model_loss, past_key_values = (
+ self._base_model_forward(
+ input_ids,
+ attention_mask,
+ position_ids,
+ past_key_values,
+ self.eagle_freeze_base_model,
+ labels,
+ **kwargs,
+ )
)
- )
+
+ if not isinstance(past_key_values, Cache):
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
# ====Run eagle forward====
eagle_loss = None
@@ -813,9 +823,11 @@ def forward(
# In EAGLE-3, we have an additional FC layer to concentrate hidden states from multiple base model layers
batch_size, seq_length, _ = base_model_hidden_states.shape
if self.eagle_config.use_aux_hidden_state:
- eagle_input_hidden_states = self.eagle_module.fc(
- torch.cat(self.pop_aux_hidden_states(), dim=-1)
- )
+ if "base_model_outputs" in kwargs:
+ aux_hidden_states = kwargs["base_model_outputs"]["aux_hidden_states"]
+ else:
+ aux_hidden_states = torch.cat(self.pop_aux_hidden_states(), dim=-1)
+ eagle_input_hidden_states = self.eagle_module.fc(aux_hidden_states)
else:
eagle_input_hidden_states = base_model_hidden_states
@@ -842,10 +854,11 @@ def forward(
if not isinstance(eagle_cache, Cache):
eagle_cache = DynamicCache.from_legacy_cache(eagle_cache)
+
past_key_values.eagle_cache = eagle_cache
# Compute loss on the eagle modules
- regression_loss, classification_loss = self._eagle_loss(
+ regression_loss, classification_loss, accuracy_0 = self._eagle_loss(
base_model_hidden_states[:, 1:],
base_model_logits[:, 1:],
eagle_postnorm_h[:, :-1],
@@ -879,7 +892,7 @@ def forward(
position_embeddings,
)
- regression_loss, classification_loss = self._eagle_loss(
+ regression_loss, classification_loss, accuracy_1 = self._eagle_loss(
# base model predict +1 tok, while eagle predict +2
# so we shift base model outputs compared to eagle outputs
base_model_hidden_states[:, 1:],
@@ -927,7 +940,7 @@ def forward(
position_embeddings,
)
- regression_loss, classification_loss = self._eagle_loss(
+ regression_loss, classification_loss, accuracy_2 = self._eagle_loss(
base_model_hidden_states[:, 1:],
base_model_logits[:, 1:],
eagle_postnorm_h[:, -seq_length:-1, :],
@@ -969,7 +982,7 @@ def forward(
position_embeddings,
)
- regression_loss, classification_loss = self._eagle_loss(
+ regression_loss, classification_loss, accuracy_3 = self._eagle_loss(
base_model_hidden_states[:, 1:],
base_model_logits[:, 1:],
eagle_postnorm_h[
@@ -1006,11 +1019,14 @@ def forward(
"Both base_model_loss and eagle_loss are skipped. At least one loss must be computed."
)
+ train_acc = (accuracy_0, accuracy_1, accuracy_2, accuracy_3) if self.training else None
+
return ModelOutput(
loss=loss,
logits=base_model_logits,
past_key_values=past_key_values,
hidden_states=base_model_hidden_states,
+ train_acc=train_acc,
)
def _eagle_loss(
@@ -1034,7 +1050,15 @@ def _eagle_loss(
regression_loss = torch.sum(torch.mean(loss_mask * regression_loss, 2)) / (
loss_mask.sum() + 1e-5
)
- return regression_loss, classification_loss
+ # Compute accuracy
+ base_predict_tok = base_model_logits.clone().detach().argmax(dim=-1)
+ eagle_predict_tok = eagle_logits.clone().detach().argmax(dim=-1)
+ valid = loss_mask[:, :, 0].bool()
+ correct = (base_predict_tok == eagle_predict_tok) & valid
+ denom = valid.sum().clamp_min(1).float()
+ accuracy = round(correct.sum().float().div(denom).item(), 3)
+
+ return regression_loss, classification_loss, accuracy
@torch.no_grad()
def pseudo_speculative_generate(
@@ -1055,7 +1079,7 @@ def pseudo_speculative_generate(
base_model_hidden_states = base_model_outputs.hidden_states[-1]
base_model_logits = base_model_outputs.logits
- base_token = base_model_logits[:, -1:, :].argmax(dim=-1)
+ base_token = base_model_logits[:, -1:, :].argmax(dim=-1).to(input_ids.device)
# Early return
if steps < 1:
@@ -1132,7 +1156,7 @@ def get_ground_truth(self, input_ids, osl):
input_ids = copy.deepcopy(input_ids).to(torch.cuda.current_device())
for _ in range(osl):
input_id, _ = self.model.pseudo_speculative_generate(input_ids, steps=0)
- input_ids = torch.cat((input_ids, input_id), dim=-1)
+ input_ids = torch.cat((input_ids, input_id.to(input_ids.device)), dim=-1)
if input_id[0, 0] == self.end_token:
break
return input_ids