Commit 557633c

Fix: supporting gpt-oss HF eagle (#398)
Signed-off-by: h-guo18 <[email protected]>
1 parent 3a76d28 commit 557633c

File tree

7 files changed: +62 -37 lines changed


examples/speculative_decoding/README.md

Lines changed: 34 additions & 10 deletions
@@ -43,14 +43,16 @@ pip install -U nvidia-modelopt[hf]
 pip install -r requirements.txt
 ```
 
-We use [Daring-Anteater](https://huggingface.co/datasets/nvidia/Daring-Anteater) dataset in this example. Download by:
+### Data Preparation
+
+We use [Daring-Anteater](https://huggingface.co/datasets/nvidia/Daring-Anteater) dataset in this example. Prepare data by:
 
 ```bash
-apt-get update && apt-get install -y git-lfs
-git lfs install --system
-git clone https://huggingface.co/datasets/nvidia/Daring-Anteater
+python prepare_input_conversations/add_daring_anteater.py
 ```
 
+See [other-datasets](#other-datasets) section for other dataset options and instruction for user-provided data.
+
 ## Getting Started: Simplified Workflow
 
 ```bash
@@ -71,7 +73,7 @@ For small base models that fit in GPU memory, we can collocate them with draft m
 ```bash
 ./launch_train.sh --model $BASE_MODEL \
     --output_dir $OUTPUT_DIR \
-    --data Daring-Anteater/train.jsonl \
+    --data input_conversations/daring-anteater.jsonl \
     --num_gpu $NUM_GPU \
     --num_epochs $NUM_EPOCH \
     --eagle_config eagle_config.json
@@ -91,7 +93,7 @@ We support two backends for generating base model hidden states. For better effc
 ```bash
 python collect_hidden_states/compute_hidden_states_trtllm.py \
     --model $BASE_MODEL \
-    --input-file Daring-Anteater/train.jsonl \
+    --input-file input_conversations/daring-anteater.jsonl \
     --output-dir $HIDDEN_STATES_DIR
 ```
 
@@ -102,7 +104,7 @@ Alternatively, you can generate the same hidden states with HF:
 ```bash
 python collect_hidden_states/compute_hidden_states_hf.py \
     --model $BASE_MODEL \
-    --input-file Daring-Anteater/train.jsonl \
+    --input-file input_conversations/daring-anteater.jsonl \
     --output-dir $HIDDEN_STATES_DIR
 ```
 
@@ -130,7 +132,7 @@ For online training checkpoints, we can run in-framework evaluation on MT-bench:
 python ar_validate.py --model_path $ONLINE_CKPT
 ```
 
-Offline checkpoints does not support this evaluation due to missing of base model modules. To evaluate offline checkpoint, please export first and evaluate with serving frameworks.
+**Note**: In-framework evaluation is supported only for online training. For offline training checkpoints, please export the model and evaluate it using serving frameworks.
 
 ## Export
 
@@ -183,6 +185,28 @@ See more details on deployment of quantized model to TRTLLM [here](../llm_ptq/RE
 
 ## Advanced Usage
 
+### Other Datasets
+
+In addition to `daring-anteater`, we provide scripts for adding several other commonly used datasets in `prepare_input_conversations`:
+
+```text
+prepare_input_conversations/
+├── add_daring_anteater.py
+├── add_mtbench.py
+├── add_sharegpt.py
+├── add_ultrachat.py
+└── example_make_prompt_dataset.sh
+```
+
+To use your own datasets, please preprocess your data into a `.jsonl` file with each line in the format:
+
+```json
+{
+    "conversation_id": <unique id>,
+    "conversations": [{"role":<user or assistant>, "content":<content>}]
+}
+```
+
 ### Data Synthesis
 
 To achieve higher acceptance rates during speculative decoding, it is beneficial to use conversations generated by the base model as training data. This ensures that the draft model's output distribution closely aligns with that of the base model.
@@ -199,7 +223,7 @@ Note: Add `--quantization=modelopt` flag for quantized models.
 Then, we generate conversations with the base model using prompts from Daring-Anteater:
 
 ```bash
-python server_generate.py --data_path Daring-Anteater/train.jsonl --output_path synthetic/train.jsonl
+python server_generate.py --data_path input_conversations/daring-anteater.jsonl --output_path synthetic/train.jsonl
 ```
 
 To add a system prompt, use the `--system_prompt <system_prompt_text>` argument.
@@ -211,7 +235,7 @@ For large scale data generation, please see [SLURM prepare data](SLURM_prepare_d
 We can optionally use smaller vocab size for the draft model for faster training and inference. E.g. Llama3.2-1B has a vocab size of 128256. In this example, we construct a draft vocab mapping of size 32k by finding the most commonly appeared vocabs in our training set:
 
 ```bash
-python calibrate_draft_vocab.py --model meta-llama/Llama-3.2-1B-Instruct --data Daring-Anteater/train.jsonl --draft_vocab_size 32000 --save_dir draft_vocab_cache
+python calibrate_draft_vocab.py --model meta-llama/Llama-3.2-1B-Instruct --data input_conversations/daring-anteater.jsonl --draft_vocab_size 32000 --save_dir draft_vocab_cache
 ```
 
 This will produce a `d2t.pt` file in `save_dir`, which is the mapping from draft token to target token. During inference, draft tokens can be mapped back to target tokens by `target_token = draft_token + d2t[draft_token]`.
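
For readers unfamiliar with this mapping, here is a minimal sketch (not part of the commit) of how `d2t.pt` could be applied at inference time. It assumes `d2t` is saved as a 1-D integer tensor of per-token offsets indexed by draft-token id, which is how the formula above reads but is not shown explicitly in this diff:

```python
# Hypothetical illustration of target_token = draft_token + d2t[draft_token].
import torch

# Assumed format: a 1-D LongTensor of offsets, indexed by draft-token id.
d2t = torch.load("draft_vocab_cache/d2t.pt")

draft_tokens = torch.tensor([0, 5, 17])           # example draft-token ids
target_tokens = draft_tokens + d2t[draft_tokens]  # map back to the target vocabulary
print(target_tokens)
```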

examples/speculative_decoding/eagle_utils.py

Lines changed: 2 additions & 2 deletions
@@ -29,8 +29,6 @@
 
 try:
     import wandb
-
-    wandb.init()
 except ImportError:
     wandb = None
 
@@ -397,6 +395,8 @@ def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
 class ARValidationCallback(TrainerCallback):
     def __init__(self, ar_validate_steps: int = 1000):
         self.ar_validate_steps = ar_validate_steps
+        if wandb:
+            wandb.init()
 
     def on_step_end(self, args, state, control, **kwargs):
         if self.ar_validate_steps <= 0:

examples/speculative_decoding/train_eagle3_and_export.sh

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ set -eo pipefail
 # Set default values for BASE_MODEL, NUM_GPU, and DATA
 BASE_MODEL=meta-llama/Llama-3.2-1B-Instruct
 NUM_GPU=1
-DATA=Daring-Anteater/train.jsonl
+DATA=input_conversations/daring-anteater.jsonl
 
 # Parse input arguments --base_model, --num_gpu, and --data
 while [[ $# -gt 0 ]]; do

modelopt/torch/export/plugins/hf_spec_export.py

Lines changed: 12 additions & 18 deletions
@@ -48,17 +48,18 @@ def _check_state_dict_keys_match(draft_model: nn.Module, required_items: dict):
             raise ValueError(f"State dict keys mismatch!\nMissing in draft model: {required_key}")
 
 
-def rename_and_prune_if_spec_decoding(model: nn.Module, post_state_dict: dict):
+def spec_opt_only(model: nn.Module):
+    """Check if the model have only speculative decoding optimization."""
+    opt_modes = getattr(model, "_modelopt_state", None)
+    return (
+        isinstance(opt_modes, (list, tuple)) and len(opt_modes) == 1 and opt_modes[0][0] == "eagle"
+    )
+
+
+def export_spec_ckpt_state_dict(model: nn.Module):
     """Only return the state dict of the draft model in official format and ignore the base model."""
     # check the model has only speculative decoding
-    opt_modes = getattr(model, "_modelopt_state", None)
-    if (
-        not isinstance(opt_modes, (list, tuple))
-        or len(opt_modes) != 1
-        or opt_modes[0][0] != "eagle"
-    ):
-        # if there's other opts, return as is
-        return post_state_dict
+    assert spec_opt_only(model), "Not purely eagle model."
 
     # Check if the state dict keys match
     _check_state_dict_keys_match(model.eagle_module, EAGLE_MODELOPT_TO_OFFICIAL["required"])
@@ -80,16 +81,9 @@ def rename_and_prune_if_spec_decoding(model: nn.Module, post_state_dict: dict):
     return export_state_dict
 
 
-def set_config_if_spec_decoding(model: nn.Module, config_data: dict):
+def export_spec_ckpt_config(model: nn.Module):
     """Return the config of draft model in official format."""
-    opt_modes = getattr(model, "_modelopt_state", None)
-    if (
-        not isinstance(opt_modes, (list, tuple))
-        or len(opt_modes) != 1
-        or opt_modes[0][0] != "eagle"
-    ):
-        # return as is
-        return config_data
+    assert spec_opt_only(model), "Not purely eagle model."
 
     # This is the config keys in official checkpoint.
     template_config = {
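
The renamed helpers are consumed by the HF export path in the next file. As a rough usage sketch of how they combine (the wrapper name `export_eagle_draft` is hypothetical; the import path follows the package layout shown in this commit):

```python
import json

from safetensors.torch import save_file

from modelopt.torch.export.plugins import (
    export_spec_ckpt_config,
    export_spec_ckpt_state_dict,
    spec_opt_only,
)


def export_eagle_draft(model, export_dir: str) -> None:
    """Hypothetical wrapper mirroring the early exit added in unified_export_hf.py below."""
    # Only valid when EAGLE speculative decoding is the sole ModelOpt optimization on the model.
    assert spec_opt_only(model), "Not purely eagle model."
    # Draft-model weights in the official format, without the base model.
    save_file(export_spec_ckpt_state_dict(model), f"{export_dir}/model.safetensors")
    # Draft-model config in the official format.
    with open(f"{export_dir}/config.json", "w") as f:
        json.dump(export_spec_ckpt_config(model), f, indent=4)
```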

modelopt/torch/export/unified_export_hf.py

Lines changed: 11 additions & 6 deletions
@@ -26,6 +26,7 @@
 
 import torch
 import torch.nn as nn
+from safetensors.torch import save_file
 
 from modelopt.torch.quantization import set_quantizer_by_cfg_context
 from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer
@@ -53,7 +54,7 @@
     QUANTIZATION_W4A8_AWQ,
     QUANTIZATION_W4A8_NVFP4_FP8,
 )
-from .plugins import rename_and_prune_if_spec_decoding, set_config_if_spec_decoding
+from .plugins import export_spec_ckpt_config, export_spec_ckpt_state_dict, spec_opt_only
 from .quant_utils import (
     fuse_prequant_layernorm,
     get_activation_scaling_factor,
@@ -511,18 +512,24 @@ def export_hf_checkpoint(
     """
     export_dir = Path(export_dir)
     export_dir.mkdir(parents=True, exist_ok=True)
+
+    # NOTE: (hg) Early exit for speculative decoding models
+    # This is a temp workaround to avoid error with offline spec ckpt during _export_hf_checkpoint
+    if spec_opt_only(model):
+        save_file(export_spec_ckpt_state_dict(model), f"{export_dir}/model.safetensors")
+        with open(f"{export_dir}/config.json", "w") as file:
+            json.dump(export_spec_ckpt_config(model), file, indent=4)
+        return
+
     try:
         post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype)
 
-        # NOTE: (hg) Should we save hf_quant_config when there's no quantization applied?
         # Save hf_quant_config.json for backward compatibility
         with open(f"{export_dir}/hf_quant_config.json", "w") as file:
            json.dump(hf_quant_config, file, indent=4)
 
        hf_quant_config = convert_hf_quant_config_format(hf_quant_config)
 
-        post_state_dict = rename_and_prune_if_spec_decoding(model, post_state_dict)
-
        # Save model
        model.save_pretrained(
            export_dir, state_dict=post_state_dict, save_modelopt_state=save_modelopt_state
@@ -536,8 +543,6 @@ def export_hf_checkpoint(
 
         config_data["quantization_config"] = hf_quant_config
 
-        config_data = set_config_if_spec_decoding(model, config_data)
-
         with open(original_config, "w") as file:
             json.dump(config_data, file, indent=4)
 
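With the early exit above, an EAGLE-only checkpoint is exported as just `model.safetensors` plus `config.json`. A small sketch of inspecting such an export (the directory name is a placeholder, and the exact keys depend on the official draft-checkpoint format):

```python
import json

from safetensors import safe_open

export_dir = "exported_eagle_ckpt"  # placeholder path

# The draft model's config in the official format.
with open(f"{export_dir}/config.json") as f:
    config = json.load(f)
print(sorted(config.keys()))

# The draft model's weights only; base-model weights are intentionally absent.
with safe_open(f"{export_dir}/model.safetensors", framework="pt") as f:
    print(list(f.keys())[:5])
```
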
modelopt/torch/speculative/eagle/default_config.py

Lines changed: 1 addition & 0 deletions
@@ -47,4 +47,5 @@
     "use_mtp_layernorm": False,
     "parallel_draft_step": 1,
     "has_lm_head": False,
+    "head_dim": 128,
 }
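
The new `head_dim` entry matters for architectures like gpt-oss, whose attention head dimension is set explicitly in the HF config rather than derived from `hidden_size // num_attention_heads`. A purely illustrative sketch of that distinction (the numbers are made up, not taken from any real config):

```python
def resolve_head_dim(cfg: dict) -> int:
    """Prefer an explicit head_dim; otherwise fall back to the conventional derivation."""
    if "head_dim" in cfg:
        return cfg["head_dim"]
    return cfg["hidden_size"] // cfg["num_attention_heads"]


# Derived: 2048 // 32 == 64
print(resolve_head_dim({"hidden_size": 2048, "num_attention_heads": 32}))
# Explicit value wins, even when it differs from hidden_size // num_attention_heads.
print(resolve_head_dim({"hidden_size": 2880, "num_attention_heads": 64, "head_dim": 64}))
```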

tests/examples/speculative_decoding/test_eagle.py

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@ def test_llama_eagle3(tiny_llama_path, num_gpus, tiny_daring_anteater_path, tmp_
         "intermediate_size": 64,
         "num_attention_heads": 2,
         "num_key_value_heads": 2,
+        "head_dim": 64,
     }
 
     # Write the tiny config to a temporary file
