Commit 264adbb

committed
test weight quantizer too
Signed-off-by: Jennifer Chen <[email protected]>
1 parent 42519cc commit 264adbb

9 files changed, +266 -172 lines changed

examples/speculative_decoding/README.md

Lines changed: 113 additions & 82 deletions
@@ -15,8 +15,11 @@ This example focuses on training with Hugging Face. To train with Megatron‑LM,
 | **Section** | **Description** | **Jump To** |
 | :------------: | :------------: | :------------: |
 | Pre-Requisites | Required & optional dependencies | \[[Link](#pre-requisites)\] |
-| Simplified Workflow | Train, evaluate, and export eagle model with one-line command | \[[Link](#getting-started-simplified-workflow)\] |
-| Complete Workflow | Full example with configurable training pipeline | \[[Link](#complete-workflow)\] |
+| Simplified Workflow | Train, evaluate, and export EAGLE model with one-line command | \[[Link](#getting-started-simplified-workflow)\] |
+| Online Training | Train draft model alongside base model in GPU memory | \[[Link](#training-draft-model-with-online-base-model)\] |
+| Offline Training | Train draft model using pre-computed hidden states | \[[Link](#training-draft-model-with-offline-base-model)\] |
+| After Training | Evaluation, export and deployment | \[[Link](#model-validation)\] |
+| Advanced Usage | Data synthesis, vocab compression, and configuration | \[[Link](#advanced-usage)\] |
 | Support Matrix | Supported models for speculative decoding training | \[[Link](#support-matrix)\] |
 | Speculation Module Checkpoints | View pre-trained speculation modules ready to deploy! | \[[Link](#speculation-module-checkpoints)\] |
 | Resources | Extra links to relevant resources | \[[Link](#resources)\] |
@@ -61,13 +64,113 @@ This one-line command runs a minimal example workflow of training and exporting
 - Evaluates the acceptance rate on [MT-Bench](https://huggingface.co/datasets/HuggingFaceH4/mt_bench_prompts)
 - Exports a checkpoint ready for deployment
 
-## Complete Workflow
+## Training Draft Model with Online Base Model
 
-This section presents a more comprehensive example for customizing speculative decoding training with Modelopt, including optional steps to enhance training quality and efficiency.
+For small base models that fit in GPU memory, we can collocate them with draft models and train with the following command:
 
-### (Optional) Data Synthesis
+```bash
+./launch_train.sh --model $BASE_MODEL \
+    --output_dir $OUTPUT_DIR \
+    --data Daring-Anteater/train.jsonl \
+    --num_gpu $NUM_GPU \
+    --num_epochs $NUM_EPOCH \
+    --eagle_config eagle_config.json
+```
+
+This command will launch `main.py` with `accelerate`. See [section: interact with modelopt.torch.speculative](#interact-with-modelopttorchspeculative) for more details.
+The saved modelopt checkpoint is similar in architecture to HF models. It can be further optimized through **ModelOpt**, e.g., PTQ and QAT.
+
+## Training Draft Model with Offline Base Model
+
+For large models, you can export intermediate hidden states to disk and train only the draft model. This significantly reduces GPU memory requirements, but requires several to tens of terabytes of storage depending on dataset size.
+
+First, dump the base model's hidden states with the following command:
+
+```bash
+python collect_hidden_states/compute_hidden_states_hf.py \
+    --model $BASE_MODEL \
+    --input-file Daring-Anteater/train.jsonl \
+    --output-dir $HIDDEN_STATES_DIR
+```
+
+See [`run_hf_compute_hiddens_dp.sh`](./collect_hidden_states/run_hf_compute_hiddens_dp.sh) for a simple example using data parallelism (DP) to accelerate hidden state generation.
+
+Then, train the draft model with the `--offline-data` argument:
+
+```bash
+./launch_train.sh --model $BASE_MODEL \
+    --output_dir $OUTPUT_DIR \
+    --data $DATA \
+    --num_gpu $NUM_GPU \
+    --num_epochs $NUM_EPOCH \
+    --eagle_config eagle_config.json \
+    --offline-data $HIDDEN_STATES_DIR
+```
+
+## Model Validation
+
+After training the draft model, we can evaluate the saved modelopt checkpoint on MT-Bench by:
+
+```bash
+python ar_validate.py --model_path $OUTPUT_DIR
+```
+
+Alternatively, we can export the checkpoint and run evaluation on serving frameworks. See sections below.
+
+## Export
+
+```bash
+python export_hf_checkpoint.py --model_path $OUTPUT_DIR --export_path $EXPORT_PATH
+```
+
+This exports the model from a ModelOpt checkpoint to a deployment-compatible format.
+
+## Deployment
+
+The exported checkpoint can be deployed on TRT-LLM or SGLang.
+
+### TRT-LLM
+
+To serve the checkpoint with TRT-LLM, run trtllm-serve with:
+
+```bash
+trtllm-serve <base_model_checkpoint> --host 0.0.0.0 --port 8000 --backend pytorch --max_batch_size 32 --max_num_tokens 8192 --max_seq_len 8192 --extra_llm_api_options extra-llm-api-config.yml
+```
+
+where `extra-llm-api-config.yml` contains:
+
+```yaml
+enable_attention_dp: false
+disable_overlap_scheduler: true
+enable_autotuner: false
+
+cuda_graph_config:
+  max_batch_size: 1
+
+speculative_config:
+  decoding_type: Eagle
+  max_draft_len: 3
+  speculative_model_dir: <draft_model_checkpoint>
+
+kv_cache_config:
+  enable_block_reuse: false
+```
+
+Please refer to [TRT-LLM Doc: Speculative Decoding](https://nvidia.github.io/TensorRT-LLM/examples/llm_speculative_decoding.html) for detailed usage.
+
+### SGLang
 
-To achieve higher acceptance rates during speculative decoding, it is beneficial to use conversations generated by the base model as training data, ensuring that the draft model’s output distribution closely aligns with that of the base model.
+Please refer to [SGLang Doc: Speculative Decoding](https://docs.sglang.ai/advanced_features/speculative_decoding.html#EAGLE-3-Decoding) for detailed usage.
+
+### Deploying Quantized Model
+
+See more details on deploying quantized models to TRT-LLM [here](../llm_ptq/README.md).
+
+## Advanced Usage
+
+### Data Synthesis
+
+To achieve higher acceptance rates during speculative decoding, it is beneficial to use conversations generated by the base model as training data. This ensures that the draft model's output distribution closely aligns with that of the base model.
 
 To prepare such data, we launch an inference server with the base model:
 
@@ -78,7 +181,7 @@ vllm serve meta-llama/Llama-3.2-1B-Instruct --api-key token-abc123 --port 8000
 
 Note: Add `--quantization=modelopt` flag for quantized models.
 
-Then, we generate conversations with base model and prompts from Daring-Anteater:
+Then, we generate conversations with the base model using prompts from Daring-Anteater:
 
 ```bash
 python server_generate.py --data_path Daring-Anteater/train.jsonl --output_path synthetic/train.jsonl
@@ -88,7 +191,7 @@ To add a system prompt, use the `--system_prompt <system_prompt_text>` argument.
 
 For large scale data generation, please see [SLURM prepare data](SLURM_prepare_data.md) for SLURM support.
 
-### (Optional) Draft Vocabulary Compression
+### Draft Vocabulary Compression
 
 We can optionally use smaller vocab size for the draft model for faster training and inference. E.g. Llama3.2-1B has a vocab size of 128256. In this example, we construct a draft vocab mapping of size 32k by finding the most commonly appeared vocabs in our training set:
 
@@ -98,7 +201,7 @@ python calibrate_draft_vocab.py --model meta-llama/Llama-3.2-1B-Instruct --data
 
 This will produce a `d2t.pt` file in `save_dir`, which is the mapping from draft token to target token. During inference, draft tokens can be mapped back to target tokens by `target_token = draft_token + d2t[draft_token]`.
 
-### (Optional) Configuring Draft Model
+### Configuring Draft Model
 
 For EAGLE‑1 and EAGLE‑3 we provide a [default model architecture config](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/modelopt/torch/speculative/config.py#L37) in ModelOpt. You can override default settings by providing an additional JSON dict. In this example, we override `draft_vocab_size` in `eagle_config.json`:
 
@@ -108,7 +211,7 @@ For EAGLE‑1 and EAGLE‑3 we provide a [default model architecture config](htt
 }
 ```
 
-### Training Draft Model with Modelopt
+### Interact with `modelopt.torch.speculative`
 
 `main.py` provides an example for converting a HF base model for speculative decoding and training it. It consists of a few simple steps:
 First, load the base model and tokenizer from Hugging Face:
@@ -162,78 +265,6 @@ trainer.save_state()
 trainer.save_model("<path to the output directory>")
 ```
 
-We omitted details like tokenizer initialization for simplicity. A complete training example is provided in `main.py`, along with a bash script to launch training with Hugging Face Accelerate in `launch_train.sh`, which can be run by:
-
-```bash
-./launch_train.sh --model $BASE_MODEL \
-    --output_dir $OUTPUT_DIR \
-    --data $DATA \
-    --num_gpu $NUM_GPU \
-    --num_epochs 10 \
-    --eagle_config eagle_config.json #This is where we optionally overwrite default eagle configs
-```
-
-The saved modelopt checkpoint is similar in architecture to HF models. It can be further optimized through **ModelOpt**, e.g., PTQ and QAT.
-
-### Model Validation
-
-After training draft model, we can evaluate the saved modelopt checkpoint on MT-bench by:
-
-```bash
-python ar_validate.py --model_path $OUTPUT_DIR
-```
-
-Alternatively, we can export the checkpoint and run evaluation on serving frameworks. See sections below.
-
-### Export
-
-```bash
-python export_hf_checkpoint.py --model_path $OUTPUT_DIR --export_path $EXPORT_PATH
-```
-
-This exports the model from a ModelOpt checkpoint to a deployment‑compatible format.
-
-### Deployment
-
-The exported checkpoint can be deployed on TRT-LLM or SGLang.
-
-#### TRT-LLM
-
-To serve the checkpoint with trtllm, run trtllm-serve with:
-
-```bash
-trtllm-serve <base_model_checkpoint> --host 0.0.0.0 --port 8000 --backend pytorch --max_batch_size 32 --max_num_tokens 8192 --max_seq_len 8192 --extra_llm_api_options extra-llm-api-config.yml
-```
-
-, with `extra-llm-api-config.yml` being
-
-```yaml
-enable_attention_dp: false
-disable_overlap_scheduler: true
-enable_autotuner: false
-
-cuda_graph_config:
-  max_batch_size: 1
-
-speculative_config:
-  decoding_type: Eagle
-  max_draft_len: 3
-  speculative_model_dir: <draft_model_checkpoint>
-
-kv_cache_config:
-  enable_block_reuse: false
-```
-
-Please refer to [TRT-LLM Doc: Speculative Decoding](https://nvidia.github.io/TensorRT-LLM/examples/llm_speculative_decoding.html) for detailed usage.
-
-#### SGLang
-
-Please refer to [SGLang Doc: Speculative Decoding](https://docs.sglang.ai/advanced_features/speculative_decoding.html#EAGLE-3-Decoding) for detailed usage.
-
-#### Deploying Quantized model
-
-See more details on deployment of quantized model to TRTLLM [here](../llm_ptq/README.md).
-
 ## Support Matrix
 
 | Model | Medusa | EAGLE1/2 | EAGLE3 |
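
The README above notes that the saved ModelOpt checkpoint can be further optimized with PTQ or QAT. As a rough illustration only (not part of this commit), post-training quantization of such a checkpoint with `modelopt.torch.quantization` might look like the sketch below; the checkpoint path and calibration prompts are placeholders, and the FP8 preset is just one of the available configs.

```python
# Hedged sketch: PTQ of an exported checkpoint with ModelOpt.
# Paths and calibration prompts are hypothetical placeholders.
import modelopt.torch.quantization as mtq
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt = "<exported_checkpoint_dir>"  # placeholder path
model = AutoModelForCausalLM.from_pretrained(ckpt)
tokenizer = AutoTokenizer.from_pretrained(ckpt)


def forward_loop(model):
    # Run a handful of calibration prompts so quantizers can collect
    # activation statistics before quantization parameters are frozen.
    for prompt in ["Hello, how are you?", "Explain speculative decoding briefly."]:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        model(**inputs)


# FP8 post-training quantization; QAT would continue training on the
# quantized model afterwards.
model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)
```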

examples/speculative_decoding/launch_train.sh

Lines changed: 1 addition & 1 deletion
@@ -129,7 +129,7 @@ if [[ "$OFFLINE_DATA_PATH" != "" ]]; then
         echo "Offline data path $OFFLINE_DATA_PATH does not exist or is not a directory."
         exit 1
     else
-        OFFLINE_TRAINING_ARGS="--offline-data-path $OFFLINE_DATA_PATH"
+        OFFLINE_TRAINING_ARGS="--offline-data-path $OFFLINE_DATA_PATH --ar_validate_steps -1"
     fi
 else
     OFFLINE_TRAINING_ARGS=""

modelopt/torch/distill/plugins/megatron.py

Lines changed: 23 additions & 6 deletions
@@ -59,7 +59,7 @@ class DistillationConfig:
         logit_kl_temperature: Temperature for the logit KL-divergence loss.
     """
 
-    intermediate_layer_pairs: list[tuple[str, str]] = field(default_factory=list)
+    intermediate_layer_pairs: list[tuple[str, ...]] = field(default_factory=list)
     logit_layers: tuple[str, str] = ("output_layer", "output_layer")
     skip_lm_loss: bool = True
     kd_loss_scale: float = 1.0
@@ -69,12 +69,28 @@ class DistillationConfig:
 
     def __post_init__(self):
         assert len(self.logit_layers) == 2, f"{self.logit_layers=}"
-        assert all(len(pair) == 2 for pair in self.intermediate_layer_pairs), (
+        assert all(len(pair) in (2, 3) for pair in self.intermediate_layer_pairs), (
             f"{self.intermediate_layer_pairs=}"
         )
         assert self.kd_loss_scale > 0, f"{self.kd_loss_scale=}"
         assert self.logit_kl_temperature > 0, f"{self.logit_kl_temperature=}"
 
+    @staticmethod
+    def parse_intermediate_entry(entry: tuple[str, ...]) -> tuple[str, str, Callable]:
+        """Parse an intermediate entry into a student layer, teacher layer, and loss function."""
+        if len(entry) == 3:
+            student_layer, teacher_layer, loss_fn_name = entry
+            if loss_fn_name == "cosine":
+                loss_fn = HiddenStateCosineLoss
+            elif loss_fn_name == "mse":
+                loss_fn = MSELoss
+            else:
+                raise ValueError(f"Unknown intermediate loss function: {loss_fn_name}")
+        else:
+            student_layer, teacher_layer = entry
+            loss_fn = HiddenStateCosineLoss  # default to cosine loss
+        return student_layer, teacher_layer, loss_fn
+
 
 def load_distillation_config(
     config_path: str | None, student_cfg: "TransformerConfig", teacher_cfg: "TransformerConfig"
@@ -105,7 +121,8 @@ def load_distillation_config(
     # NOTE: Projection layer shared among intermediate layer pairs.
     projection_layer = ProjectionLayer(student_cfg, teacher_cfg)
 
-    for student_layer, teacher_layer in cfg.intermediate_layer_pairs:
+    for entry in cfg.intermediate_layer_pairs:
+        student_layer, teacher_layer, loss_fn = cfg.parse_intermediate_entry(entry)
         if parallel_state.get_tensor_and_context_parallel_rank() == 0:
             logger.info(
                 "Distillation: Adding intermediate loss between"
@@ -114,7 +131,7 @@ def load_distillation_config(
             )
         student_layer = _adjust_layer_index_for_pp(student_layer, student_cfg)
         teacher_layer = _adjust_layer_index_for_pp(teacher_layer, teacher_cfg)
-        criterion[(student_layer, teacher_layer)] = HiddenStateCosineLoss(
+        criterion[(student_layer, teacher_layer)] = loss_fn(
             student_cfg, projection_layer=projection_layer
         )
 
@@ -202,9 +219,9 @@ def forward(self, predictions: Tensor, targets: Tensor) -> Tensor:
         predictions, targets = self.pre_forward(predictions, targets)
 
         loss = F.mse_loss(predictions, targets, reduction="none")
-        loss = loss.sum(dim=-1)
+        loss = loss.mean(dim=-1)
 
-        return self.post_forward(loss)
+        return self.post_forward(loss, is_sequence_parallel=self._config.sequence_parallel)
 
 
 class HiddenStateCosineLoss(BaseLoss):
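
For reference, the `megatron.py` change above lets each `intermediate_layer_pairs` entry optionally carry a third element naming the intermediate loss ("cosine" or "mse"), with plain 2-tuples keeping the previous cosine-loss behavior. Below is a small standalone sketch of how such entries resolve; the layer names are hypothetical and the resolution logic is a simplified stand-in for `DistillationConfig.parse_intermediate_entry`.

```python
# Simplified, self-contained illustration of 2-tuple vs. 3-tuple entries.
# Layer names are made up; real entries point at module paths in the model.
pairs = [
    ("decoder.layers.3", "decoder.layers.7"),          # 2-tuple: default cosine loss
    ("decoder.layers.5", "decoder.layers.11", "mse"),  # 3-tuple: explicit MSE loss
]

for entry in pairs:
    if len(entry) == 3:
        student_layer, teacher_layer, loss_fn_name = entry
    else:
        student_layer, teacher_layer = entry
        loss_fn_name = "cosine"  # default, matching the diff above
    print(student_layer, "->", teacher_layer, "via", loss_fn_name)
```

Existing 2-tuple configs therefore behave as before, while a 3-tuple opts a specific layer pair into MSE loss.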

0 commit comments
