
Commit cf17899

address review comments

Signed-off-by: h-guo18 <[email protected]>
1 parent 52dbdec · commit cf17899

7 files changed (+112, -57 lines changed)

examples/speculative_decoding/README.md

Lines changed: 23 additions & 19 deletions
@@ -2,11 +2,11 @@

 [![Documentation](https://img.shields.io/badge/Docs-TensorRT--Model--Optimizer-blue?logo=readthedocs&style=flat-square)](https://nvidia.github.io/TensorRT-Model-Optimizer/guides/5_speculative_decoding.html)

-Speculative decoding accelerates auto-regressive generation in large language models (LLMs) by leveraging a lightweight draft model to predict the next γ tokens. The main LLM then verifies these candidate tokens in a single forward pass. If the draft model correctly predicts α tokens, the LLM can accept and generate α+1 tokens per verification step, significantly improving throughput.
+Speculative decoding accelerates auto-regressive generation in large language models (LLMs) by leveraging a lightweight draft model to predict the next γ tokens. The main LLM then verifies these candidate tokens in a single forward pass. If the draft model correctly predicts α tokens, the LLM can accept and generate α+1 tokens per verification step, significantly improving generation speed.

-This folder contains end-to-end runnable speculative decoding fine-tuning pipeline where Llama3.2-1B from huggingface is trained on Daring-Anteater dataset.
+This folder contains an end-to-end runnable speculative decoding fine-tuning pipeline in which Llama-3.2-1B (Hugging Face) is trained on the Daring-Anteater dataset.

-This example focus on training with HF. To train with Megatron-LM, please refer to [this link](https://github.com/NVIDIA/Megatron-LM/tree/main/examples/post_training/modelopt) in Megatron-LM repo.
+This example focuses on training with Hugging Face. To train with Megatron-LM, see the [Megatron-LM example](https://github.com/NVIDIA/Megatron-LM/tree/main/examples/post_training/modelopt).

 ## Contents

@@ -15,9 +15,9 @@ This example focus on training with HF. To train with Megatron-LM, please refer
 | **Section** | **Description** | **Jump To** |
 | :------------: | :------------: | :------------: |
 | Pre-Requisites | Required & optional dependencies | \[[Link](#pre-requisites)\] |
-| Simplified Workflow | Train, evaluate and export eagle model with one-line command | \[[Link](#getting-started-simplified-workflow)\] |
-| Complete Workflow | Full example with configurable traininig pipeline | \[[Link](#support-matrix)\] |
-| Support Matrix | Supported models for speculative decoding training | \[[Link](#deployment)\] |
+| Simplified Workflow | Train, evaluate, and export eagle model with one-line command | \[[Link](#getting-started-simplified-workflow)\] |
+| Complete Workflow | Full example with configurable training pipeline | \[[Link](#complete-workflow)\] |
+| Support Matrix | Supported models for speculative decoding training | \[[Link](#support-matrix)\] |
 | Speculation Module Checkpoints | View pre-trained speculation modules ready to deploy! | \[[Link](#speculation-module-checkpoints)\] |
 | Resources | Extra links to relevant resources | \[[Link](#resources)\] |
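
As an aside on the mechanism summarized at the top of this README, the draft-propose / target-verify loop can be sketched in a few lines of PyTorch. This is an editorial illustration only, with hypothetical model handles and greedy acceptance; production implementations such as TRT-LLM reuse KV caches and use probabilistic rejection sampling rather than exact-match acceptance.

```python
import torch

def speculative_step(target_model, draft_model, ctx, gamma=4):
    """One draft-then-verify step; returns ctx extended by up to gamma + 1 tokens."""
    # Draft model proposes the next `gamma` tokens greedily.
    draft_tokens, draft_ctx = [], ctx
    for _ in range(gamma):
        tok = draft_model(draft_ctx).logits[:, -1].argmax(dim=-1, keepdim=True)
        draft_tokens.append(tok)
        draft_ctx = torch.cat([draft_ctx, tok], dim=-1)

    # Target model scores all candidate tokens in a single forward pass.
    logits = target_model(draft_ctx).logits
    accepted = ctx
    for i, tok in enumerate(draft_tokens):
        pred = logits[:, ctx.shape[-1] - 1 + i].argmax(dim=-1, keepdim=True)
        if torch.equal(pred, tok):
            accepted = torch.cat([accepted, tok], dim=-1)   # draft token verified
        else:
            accepted = torch.cat([accepted, pred], dim=-1)  # replace with target token, stop
            break
    else:
        # All gamma draft tokens accepted: the target's last position yields the "+1" bonus token.
        accepted = torch.cat([accepted, logits[:, -1].argmax(dim=-1, keepdim=True)], dim=-1)
    return accepted
```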

@@ -75,19 +75,23 @@ Then, we generate conversations with base model and prompts from Daring-Anteater
 python server_generate.py --data_path Daring-Anteater/train.jsonl --output_path synthetic/train.jsonl
 ```

+To add a system prompt, use the `--system_prompt <system_prompt_text>` argument.
+
+For large-scale data generation, please see [SLURM prepare data](SLURM_prepare_data.md) for SLURM support.
+
 ### (Optional) Draft Vocabulary Compression

 We can optionally use a smaller vocab size for the draft model for faster training and inference. E.g., Llama3.2-1B has a vocab size of 128256. In this example, we construct a draft vocab mapping of size 32k by finding the most commonly appearing tokens in our training set:

 ```bash
-python calibrate_draft_vocab.py --model meta-llama/Llama-3.2-1B-Instruct --data Daring-Anteater/train.jsonl ----draft_vocab_size 32000 --save_dir draft_vocab_cache
+python calibrate_draft_vocab.py --model meta-llama/Llama-3.2-1B-Instruct --data Daring-Anteater/train.jsonl --draft_vocab_size 32000 --save_dir draft_vocab_cache
 ```

 This will produce a `d2t.pt` file in `save_dir`, which is the mapping from the draft vocab to the full vocab that will be read by our draft model later.
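
Conceptually, the calibration step counts how often each base-model token appears in the training data and keeps the most frequent ones. A rough sketch of building such a draft-to-target mapping is shown below; this is a hypothetical helper, the actual `calibrate_draft_vocab.py` and the on-disk format of `d2t.pt` may differ, and the conversation schema is an assumption.

```python
import json
from collections import Counter

import torch
from transformers import AutoTokenizer

def build_draft_to_target_map(model_name, data_path, draft_vocab_size):
    """Return a tensor where index = draft-vocab id and value = full-vocab id."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    counts = Counter()
    with open(data_path) as f:
        for line in f:
            sample = json.loads(line)
            for turn in sample.get("conversations", []):  # assumed schema
                counts.update(tokenizer.encode(turn.get("value", ""), add_special_tokens=False))
    top_ids = sorted(tok for tok, _ in counts.most_common(draft_vocab_size))
    return torch.tensor(top_ids, dtype=torch.long)

# Example usage (paths and sizes mirror the command above):
# torch.save(build_draft_to_target_map("meta-llama/Llama-3.2-1B-Instruct",
#                                      "Daring-Anteater/train.jsonl", 32000),
#            "draft_vocab_cache/d2t.pt")
```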

 ### (Optional) Configuring Draft Model

-For eagle1 and eagle3 we provide an [default model architecture config](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/modelopt/torch/speculative/eagle/default_config.py#L18) in modelopt. User can overwrite default settings by providing additional json dict. In this example, we overwrite the `draft_vocab_size` by in `eagle_config.json`:
+For EAGLE-1 and EAGLE-3 we provide a [default model architecture config](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/modelopt/torch/speculative/config.py#L37) in ModelOpt. You can override default settings by providing an additional JSON dict. In this example, we override `draft_vocab_size` in `eagle_config.json`:

 ```json
 {
@@ -97,8 +101,8 @@ For eagle1 and eagle3 we provide an [default model architecture config](https://

 ### Training Draft Model with Modelopt

-`main.py` provides a example for converting a base HF model for speculative decoding and training it. It consists of a few simple steps:
-First, load base model and tokenzier from hugginface:
+`main.py` provides an example for converting a HF base model for speculative decoding and training it. It consists of a few simple steps:
+First, load the base model and tokenizer from Hugging Face:

 ```python
 model = transformers.AutoModelForCausalLM.from_pretrained(
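
The code block above is cut off by the hunk boundary; for orientation, a typical Hugging Face loading step looks roughly like this (illustrative arguments, not necessarily the exact ones used in `main.py`):

```python
import transformers

# Load the base model and its tokenizer from the Hugging Face Hub.
model = transformers.AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct", torch_dtype="auto"
)
tokenizer = transformers.AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
```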
@@ -116,7 +120,7 @@ config = {
 }[training_args.mode]["config"]

 # overwrite config with custom config
-config["eagle_architecture_config"].update({"<overwrite_kyes>": "<overwrite_values>"})
+config["eagle_architecture_config"].update({"<overwrite_keys>": "<overwrite_values>"})

 # Mandatory: hidden size, vocab size and max position embeddings must match base model
 config["eagle_architecture_config"].update(
@@ -128,7 +132,7 @@ config["eagle_architecture_config"].update(
 )
 ```

-Then, we convert model to a speculative deocoding model:
+Then, we convert the model to a speculative decoding model:

 ```python
 mtsp.convert(model, [("eagle", config)])
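
Read together, the two snippets above amount to: fill in the draft-model config fields that must mirror the base model, then convert in place. A sketch of that sequence follows; the attribute names on the base model's config are assumptions, and the exact keys used in `main.py` may differ slightly.

```python
import modelopt.torch.speculative as mtsp

# Mandatory: the draft model must agree with the base model on these dimensions.
config["eagle_architecture_config"].update(
    {
        "hidden_size": model.config.hidden_size,
        "vocab_size": model.config.vocab_size,
        "max_position_embeddings": model.config.max_position_embeddings,
    }
)

# Convert the HF model in place into an EAGLE speculative-decoding model.
mtsp.convert(model, [("eagle", config)])
```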
@@ -149,15 +153,15 @@ trainer.save_state()
 trainer.save_model("<path to the output directory>")
 ```

-We omitted details like tokenizer initialization for simplicity. A complete training example is provided in `main.py`, along with a bash script to launch the training with huggingface accelrate in `launch_train.sh`, which can be runned by:
+We omitted details like tokenizer initialization for simplicity. A complete training example is provided in `main.py`, along with a bash script to launch training with Hugging Face Accelerate in `launch_train.sh`, which can be run by:

 ```bash
 ./launch_train.sh --model $BASE_MODEL \
     --output_dir $OUTPUT_DIR \
     --data $DATA \
     --num_gpu $NUM_GPU \
     --num_epochs 10 \
-    --eagle_config eagle_config.json #This is where we overwrite default eagle configs
+    --eagle_config eagle_config.json # This is where we optionally overwrite default eagle configs
 ```

 The saved modelopt checkpoint is similar in architecture to HF models. It can be further optimized through **ModelOpt**, e.g., PTQ and QAT.
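
For instance, a post-training-quantization pass over the trained checkpoint might look like the following sketch. The calibration loop and config choice are placeholders (`calib_dataloader` is assumed to exist); see the llm_ptq example for the supported flow.

```python
import modelopt.torch.quantization as mtq

def forward_loop(m):
    # Placeholder calibration: run a small number of representative batches through the model.
    for batch in calib_dataloader:  # assumed to exist
        m(**batch)

# Quantize with one of ModelOpt's built-in configs (FP8 shown as an example).
model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)
```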
@@ -166,27 +170,27 @@ The saved modelopt checkpoint is similar in architecture to HF models. It can be

 After training draft model, we can evaluate the saved modelopt checkpoint on MT-bench by:

-```python
+```bash
 python ar_validate.py --model_path $OUTPUT_DIR
 ```

 Alternatively, we can export the checkpoint and run evaluation on serving frameworks. See sections below.

 ### Export

-```python
+```bash
 python export_hf_checkpoint.py --model_path $OUTPUT_DIR --export_path $EXPORT_PATH
 ```

-This will export the model from a modelopt checkpoint to a deployment-compatible formart.
+This exports the model from a ModelOpt checkpoint to a deployment-compatible format.

 ### Deployment

 The exported checkpoint can be deployed on TRT-LLM or vLLM.

 #### TRT-LLM

-To serve the checkpoint with trtllm, we can run trtllm-serve with:
+To serve the checkpoint with trtllm, run trtllm-serve with:

 ```bash
 trtllm-serve <base_model_checkpoint> --host 0.0.0.0 --port 8000 --backend pytorch --max_batch_size 32 --max_num_tokens 8192 --max_seq_len 8192 --extra_llm_api_options extra-llm-api-config.yml
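
Once trtllm-serve is up, the endpoint behaves like a standard OpenAI-compatible server, so a quick smoke test can be run from Python. The endpoint path and payload shape below assume the usual chat-completions API; adjust the model name to your checkpoint.

```python
import requests

resp = requests.post(
    "http://0.0.0.0:8000/v1/chat/completions",
    json={
        "model": "<base_model_checkpoint>",
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_tokens": 64,
    },
    timeout=60,
)
print(resp.json()["choices"][0]["message"]["content"])
```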
@@ -233,7 +237,7 @@ See more details on deployment of quantized model to TRTLLM [here](../llm_ptq/RE

 ## Speculation Module Checkpoints

-Ready-to-deploy speculation module checkpoints \[[🤗 Hugging Face - Nvidia TensorRT Model Optimizer Collection](https://huggingface.co/collections/nvidia/model-optimizer-66aa84f7966b3150262481a4)\]
+Ready-to-deploy speculation module checkpoints \[[🤗 Hugging Face - NVIDIA TensorRT Model Optimizer Collection](https://huggingface.co/collections/nvidia/model-optimizer-66aa84f7966b3150262481a4)\]
 Deployable on [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [vLLM](https://github.com/vllm-project/vllm) and [SGLang](https://github.com/sgl-project/sglang)!\
 More models coming soon!

examples/speculative_decoding/SLURM_prepare_data.md

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+# SLURM Prepare Data
+
+For basic parallelization of synthetic data generation we provide some SLURM support.
+Assuming a `$SLURM_JOB_ID` is present and nodes n1, n2, n3, n4 are selected, the following is achievable.
+
+Example of allocating 4 nodes for 120 minutes:
+
+```sh
+salloc -N4 -A <account> -p <partition> -J <account>-synthetic:data-gen -t 120
+```
+
+Create shards of some given size:
+
+```sh
+python3 distributed_generate/sharding_utils.py --input_path /data/train.jsonl --output_dir /data/train/ --max_lines_per_shard 10000
+```
+
+Run workers on SLURM:
+
+```sh
+bash distributed_generate/launch.sh $SLURM_JOB_ID vllm TinyLlama/TinyLlama-1.1B-Chat-v1.0 /data/train/ /data/output /scripts/ 0 10 n1,n2,n3,n4 "\"You are a helpful assistant.\""
+```
+
+`/scripts/` is the absolute path to `modelopt/examples/speculative_decoding`, which contains `server_generate.py` and `distributed_generate`.
+This will launch a vllm server (sglang is also available) on each node. Each node will work through 10 shards of data (10\*max_lines_per_shard samples).
+In this case, the first 40 shards of data will be processed.
+To process the next 40 shards:
+
+```sh
+bash distributed_generate/launch.sh $SLURM_JOB_ID vllm TinyLlama/TinyLlama-1.1B-Chat-v1.0 /data/train/ /data/output /scripts/ 40 10 n1,n2,n3,n4
+```
+
+To combine the shards back:
+
+```sh
+python3 distributed_generate/sharding_utils.py --input_dir /data/output/ --output_path /data/output.jsonl --combine
+```
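
For readers who want to see what the shard/combine round trip does without the CLI, here is a minimal sketch of the two operations. It is illustrative only; the real `sharding_utils.py` handles more options and error cases.

```python
from pathlib import Path

def split_jsonl(input_path, output_dir, max_lines_per_shard):
    """Write input lines into shard files of at most `max_lines_per_shard` lines each."""
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)
    lines = Path(input_path).read_text().splitlines(keepends=True)
    for i in range(0, len(lines), max_lines_per_shard):
        shard = out / f"shard_{i // max_lines_per_shard:05d}.jsonl"
        shard.write_text("".join(lines[i : i + max_lines_per_shard]))

def combine_jsonl(input_dir, output_path):
    """Concatenate all shard outputs back into a single .jsonl file."""
    with open(output_path, "w") as out:
        for shard in sorted(Path(input_dir).glob("*.jsonl")):
            out.write(shard.read_text())
```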

examples/speculative_decoding/export_hf_checkpoint.py

Lines changed: 10 additions & 3 deletions
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+"""Export a HF checkpoint (with ModelOpt state) for deployment."""
+
 import argparse

 import torch
@@ -23,16 +25,21 @@


 def parse_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--model_path", type=str, default="")
-    parser.add_argument("--export_path", type=str, default="")
+    parser = argparse.ArgumentParser(
+        description="Export a HF checkpoint (with ModelOpt state) for deployment."
+    )
+    parser.add_argument("--model_path", type=str, help="Path of the trained checkpoint.")
+    parser.add_argument(
+        "--export_path", type=str, help="Destination directory for exported files."
+    )
     return parser.parse_args()


 mto.enable_huggingface_checkpointing()

 args = parse_args()
 model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype="auto")
+model.eval()
 with torch.inference_mode():
     export_hf_checkpoint(
         model,  # The quantized model.

examples/speculative_decoding/train_eagle3_and_export.sh

Lines changed: 5 additions & 2 deletions
@@ -22,7 +22,7 @@ BASE_MODEL=meta-llama/Llama-3.2-1B-Instruct
 NUM_GPU=1
 DATA=Daring-Anteater/train.jsonl

-# Parse input arguments --base-model, --num_gpu, and --data
+# Parse input arguments --base_model, --num_gpu, and --data
 while [[ $# -gt 0 ]]; do
     key="$1"
     case $key in
@@ -50,13 +50,15 @@ if [[ "$NUM_GPU" == 1 ]]; then
     export CUDA_VISIBLE_DEVICES=0
 else
     # Export as 0,1,...,N-1 for NUM_GPU GPUs
-    export CUDA_VISIBLE_DEVICES=$(seq -s, 0 $((NUM_GPU-1)))
+    devs="$(seq -s, 0 $((NUM_GPU-1)))"
+    export CUDA_VISIBLE_DEVICES="$devs"
 fi

 MODEL_BASENAME=$(basename "$BASE_MODEL")

 echo "==== [1/3] Training draft model ===="
 OUTPUT_DIR=ckpts/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M)
+mkdir -p "$(dirname "$OUTPUT_DIR")"
 ./launch_train.sh --model $BASE_MODEL \
     --output_dir $OUTPUT_DIR \
     --data $DATA \
@@ -69,4 +71,5 @@ python ar_validate.py --model_path $OUTPUT_DIR

 echo "==== [3/3] Exporting checkpoint to deployment format ===="
 EXPORT_PATH=export/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M)
+mkdir -p "$(dirname "$EXPORT_PATH")"
 python export_hf_checkpoint.py --model_path $OUTPUT_DIR --export_path $EXPORT_PATH

modelopt/torch/export/plugins/hf_spec_export.py

Lines changed: 18 additions & 14 deletions
@@ -13,16 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-"""Modifiy stated_dict and config for exporting speculative decoding in official format."""
+"""Modify state_dict and config for exporting speculative decoding in official format."""

 import torch
 import torch.nn as nn
+import transformers

 from modelopt.torch.speculative.plugins.transformers import HFEagleModel

-SPECULATIVE_DECODING_MODES = ["eagle", "medusa"]
-
-EALGE_MODELOPT_TO_OFFICIAL = {
+EAGLE_MODELOPT_TO_OFFICIAL = {
     "required": {
         "layers.0.self_attn.q_proj.weight": "midlayer.self_attn.q_proj.weight",
         "layers.0.self_attn.k_proj.weight": "midlayer.self_attn.k_proj.weight",
@@ -55,26 +54,31 @@ def _check_state_dict_keys_match(draft_model: nn.Module, required_items: dict):
 def rename_and_prune_if_spec_decoding(model: nn.Module, post_state_dict: dict):
     """Only return the state dict of the draft model in official format and ignore the base model."""
     # check the model has only speculative decoding
-    opt_modes = model._modelopt_state
-    if len(opt_modes) != 1 or opt_modes[0][0] != "eagle":
+    opt_modes = getattr(model, "_modelopt_state", None)
+    if (
+        not isinstance(opt_modes, (list, tuple))
+        or len(opt_modes) != 1
+        or opt_modes[0][0] != "eagle"
+    ):
         # if there's other opts, return as is
         return post_state_dict

     assert isinstance(model, HFEagleModel)
     # Check if the state dict keys match
-    _check_state_dict_keys_match(model.eagle_module, EALGE_MODELOPT_TO_OFFICIAL["required"])
+    _check_state_dict_keys_match(model.eagle_module, EAGLE_MODELOPT_TO_OFFICIAL["required"])

     # Convert key names and save the state dict
+    eagle_state = model.eagle_module.state_dict()
     export_state_dict = {}
     for ours_key, export_key in {
-        **EALGE_MODELOPT_TO_OFFICIAL["required"],
-        **EALGE_MODELOPT_TO_OFFICIAL["optional"],
+        **EAGLE_MODELOPT_TO_OFFICIAL["required"],
+        **EAGLE_MODELOPT_TO_OFFICIAL["optional"],
     }.items():
-        if ours_key in model.eagle_module.state_dict():
-            export_state_dict[export_key] = model.eagle_module.state_dict()[ours_key]
+        if ours_key in eagle_state:
+            export_state_dict[export_key] = eagle_state[ours_key]

     # TODO: (hg) this is a temp fix. Find cleaner way to do this.
-    if "eagle_lm_head.weight" not in model.eagle_module.state_dict():
+    if "eagle_lm_head.weight" not in eagle_state:
         export_state_dict["lm_head.weight"] = model.state_dict()["lm_head.weight"]

     return export_state_dict
@@ -90,7 +94,7 @@ def set_config_if_spec_decoding(model: nn.Module, config_data: dict):

     # This is the config keys in official checkpoint.
     template_config = {
-        "architectures": ["LlamaForCausalLM"],
+        "architectures": ["LlamaForCausalLMEagle3"],
         "bos_token_id": None,
         "eos_token_id": None,
         "hidden_act": None,
@@ -106,7 +110,7 @@ def set_config_if_spec_decoding(model: nn.Module, config_data: dict):
         "rms_norm_eps": None,
         "tie_word_embeddings": False,
         "torch_dtype": None,
-        "transformers_version": None,
+        "transformers_version": transformers.__version__,
         "use_cache": None,
         "vocab_size": None,
         "draft_vocab_size": None,

modelopt/torch/export/unified_export_hf.py

Lines changed: 6 additions & 9 deletions
@@ -496,10 +496,8 @@ def _export_hf_checkpoint(

 def _quant_applied(hf_quant_config: dict) -> bool:
     """Check if any quantization is applied."""
-    return not (
-        hf_quant_config["quantization"]["quant_algo"] == QUANTIZATION_NONE
-        and not hf_quant_config["quantization"]["quantized_layers"]
-    )
+    q = hf_quant_config.get("quantization", {})
+    return not (q.get("quant_algo") == QUANTIZATION_NONE and not q.get("quantized_layers"))


 def export_hf_checkpoint(
@@ -521,11 +519,10 @@ def export_hf_checkpoint(
     try:
         post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype)

-        # When there's no quantization applied, we avoid saving hf_quant_config.json for compatibility
-        if _quant_applied(hf_quant_config):
-            # Save hf_quant_config.json for backward compatibility
-            with open(f"{export_dir}/hf_quant_config.json", "w") as file:
-                json.dump(hf_quant_config, file, indent=4)
+        # NOTE: (hg) Should we save hf_quant_config when there's no quantization applied?
+        # Save hf_quant_config.json for backward compatibility
+        with open(f"{export_dir}/hf_quant_config.json", "w") as file:
+            json.dump(hf_quant_config, file, indent=4)

         hf_quant_config = convert_hf_quant_config_format(hf_quant_config)
