Commit 5d9f76e (parent: cf6f1d4)

feat: update eagle3 example; add export

Signed-off-by: h-guo18 <[email protected]>

12 files changed: +519 −191 lines

examples/speculative_decoding/README.md

Lines changed: 144 additions & 132 deletions
Large diffs are not rendered by default.

examples/speculative_decoding/ar_validate.py

Lines changed: 4 additions & 4 deletions
@@ -26,7 +26,7 @@
 mto.enable_huggingface_checkpointing()
 
 
-def validate_ar(model, tokenizer, ds, steps=3, osl=20, num_samples=20, device=None):
+def validate_ar(model, tokenizer, ds, steps=3, osl=20, num_samples=80, device=None):
     validator = HFARValidation(model, tokenizer)
     num_samples = min(num_samples, len(ds))
     ars = []
@@ -54,12 +54,12 @@ def validate_ar(model, tokenizer, ds, steps=3, osl=20, num_samples=20, device=No
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--model_path", type=str, required=True, help="Path to model directory")
-    parser.add_argument("--steps", type=int, default=1, help="Steps for AR validation")
+    parser.add_argument("--steps", type=int, default=3, help="Steps for AR validation")
     parser.add_argument(
-        "--osl", type=int, default=100, help="Output sequence length for AR validation"
+        "--osl", type=int, default=32, help="Output sequence length for AR validation"
     )
     parser.add_argument(
-        "--num_samples", type=int, default=20, help="Number of MT-Bench samples to use"
+        "--num_samples", type=int, default=80, help="Number of MT-Bench samples to use"
     )
     parser.add_argument(
         "--ar_lower_bound",

examples/speculative_decoding/calibrate_draft_vocab.py

Lines changed: 4 additions & 11 deletions
@@ -28,11 +28,10 @@ def main():
     parser.add_argument("--model", type=str, required=True, help="Model name or path for tokenizer")
     parser.add_argument("--data", type=str, required=True, help="Path to training data (jsonl)")
     parser.add_argument(
-        "--eagle_config",
-        type=str,
+        "--draft_vocab_size",
+        type=int,
         required=True,
-        default="eagle_config.json",
-        help="Path to eagle_config.json",
+        help="Draft vocab size",
     )
     parser.add_argument(
         "--calibrate_size",
@@ -45,12 +44,6 @@ def main():
     )
     args = parser.parse_args()
 
-    with open(args.eagle_config) as f:
-        eagle_config = json.load(f)
-    if "draft_vocab_size" not in eagle_config:
-        print("No draft vocab size specified in eagle_config.json, no need to calibrate for d2t.")
-        return
-
     print("Calibrating vocab...")
     tokenizer = AutoTokenizer.from_pretrained(args.model)
     with open(args.data) as f:
@@ -59,7 +52,7 @@ def main():
     conversations = conversations[: args.calibrate_size]
     conversations = [item for sublist in conversations for item in sublist]
 
-    d2t = calibrate_frequent_vocab(tokenizer, conversations, eagle_config["draft_vocab_size"])
+    d2t = calibrate_frequent_vocab(tokenizer, conversations, args.draft_vocab_size)
     model_name = os.path.basename(os.path.normpath(args.model))
     vocab_path = os.path.join(args.save_dir, model_name, "d2t.pt")
     os.makedirs(os.path.dirname(vocab_path), exist_ok=True)
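For reference, a minimal sketch of what frequency-based vocab calibration boils down to. This is an assumption about calibrate_frequent_vocab's behavior, not ModelOpt's actual implementation, and the function name below is hypothetical: count token frequencies over the calibration conversations, keep the draft_vocab_size most frequent ids as the draft vocab, and store a draft-to-target mapping d2t.

from collections import Counter

import torch


def sketch_calibrate_frequent_vocab(tokenizer, conversations, draft_vocab_size):
    # Count how often each token id occurs in the calibration text.
    counter = Counter()
    for text in conversations:
        counter.update(tokenizer(text)["input_ids"])
    # Keep the most frequent ids as the draft vocab, in ascending id order.
    top_ids = sorted(tid for tid, _ in counter.most_common(draft_vocab_size))
    # d2t[draft_id] = target_id - draft_id, so target_id = draft_id + d2t[draft_id].
    return torch.tensor(top_ids) - torch.arange(len(top_ids))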
examples/speculative_decoding/eagle_config.json

Lines changed: 8 additions & 1 deletion

@@ -1,3 +1,10 @@
 {
-    "draft_vocab_size": 32000
+    "rope_scaling": {
+        "factor": 32.0,
+        "low_freq_factor": 1.0,
+        "high_freq_factor": 4.0,
+        "original_max_position_embeddings": 8192,
+        "rope_type": "llama3"
+    },
+    "initializer_range": 0.02
 }
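The "llama3" rope_type follows the Llama 3 frequency-smoothing scheme used by HF transformers. As a rough sketch of how these four fields rescale the rotary inverse frequencies (paraphrased from my reading of the transformers implementation, not code from this commit):

import math

import torch


def llama3_scaled_inv_freq(inv_freq, factor=32.0, low_freq_factor=1.0,
                           high_freq_factor=4.0, original_max_position_embeddings=8192):
    low_freq_wavelen = original_max_position_embeddings / low_freq_factor
    high_freq_wavelen = original_max_position_embeddings / high_freq_factor
    wavelen = 2 * math.pi / inv_freq
    # Long wavelengths are scaled down by `factor`; short ones are kept as-is.
    scaled = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
    # In between, interpolate smoothly between the scaled and unscaled regimes.
    smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    smoothed = (1 - smooth) * inv_freq / factor + smooth * inv_freq
    is_medium = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)
    return torch.where(is_medium, smoothed, scaled)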
examples/speculative_decoding/export_hf_checkpoint.py (new file)

Lines changed: 41 additions & 0 deletions

# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

import torch
from transformers import AutoModelForCausalLM

import modelopt.torch.opt as mto
from modelopt.torch.export import export_hf_checkpoint


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="")
    parser.add_argument("--export_path", type=str, default="")
    return parser.parse_args()


mto.enable_huggingface_checkpointing()

args = parse_args()
model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype="auto")
with torch.inference_mode():
    export_hf_checkpoint(
        model,  # The ModelOpt-trained model.
        export_dir=args.export_path,  # The directory where the exported files will be stored.
    )
print(f"Exported checkpoint to {args.export_path}")

examples/speculative_decoding/main.py

Lines changed: 11 additions & 0 deletions
@@ -47,6 +47,13 @@
 import modelopt.torch.speculative as mtsp
 from modelopt.torch.utils import print_rank_0
 
+try:
+    import wandb
+
+    wandb.init()
+except ImportError:
+    wandb = None
+
 torch.manual_seed(0)
 mto.enable_huggingface_checkpointing()
 
@@ -170,6 +177,8 @@ def train():
             {
                 "hidden_size": model.config.hidden_size,
                 "vocab_size": model.config.vocab_size,
+                # We also overwrite max_position_embeddings for deployment compatibility.
+                "max_position_embeddings": model.config.max_position_embeddings,
                 "draft_vocab_size": custom_config["draft_vocab_size"]
                 if eagle_args.eagle_config and "draft_vocab_size" in custom_config
                 else model.config.vocab_size,
@@ -213,6 +222,8 @@ def on_step_end(self, args, state, control, **kwargs):
                 device=kwargs["model"].device,
             )
             print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}")
+            if wandb:
+                wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
             return control
 
     trainer = Trainer(

examples/speculative_decoding/server_generate.py

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@
 parser.add_argument(
     "--max_tokens", type=int, default=2048, help="Maximum number of tokens to generate"
 )
-parser.add_argument("--chat", action="store_true", help="Use chat mode")
+parser.add_argument("--chat", default=True, action=argparse.BooleanOptionalAction, help="Use chat mode")
 parser.add_argument("--model", type=str, default="model", help="Model name")
 parser.add_argument("--url", type=str, default="http://localhost:8000/v1", help="URL of the API")
 parser.add_argument("--api_key", type=str, default="token-abc123", help="API key (if any)")
Lines changed: 72 additions & 0 deletions
(new shell script)

#!/bin/bash

# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

set -eo pipefail

# Set default values for BASE_MODEL, NUM_GPU, and DATA
BASE_MODEL=meta-llama/Llama-3.2-1B-Instruct
NUM_GPU=1
DATA=Daring-Anteater/train.jsonl

# Parse input arguments --base_model, --num_gpu, and --data
while [[ $# -gt 0 ]]; do
  key="$1"
  case $key in
    --base_model)
      BASE_MODEL="$2"
      shift; shift
      ;;
    --num_gpu)
      NUM_GPU="$2"
      shift; shift
      ;;
    --data)
      DATA="$2"
      shift; shift
      ;;
    *)
      echo "Unknown argument: $1"
      exit 1
      ;;
  esac
done

if [[ "$NUM_GPU" == 1 ]]; then
  export CUDA_VISIBLE_DEVICES=0
else
  # Export as 0,1,...,N-1 for NUM_GPU GPUs
  export CUDA_VISIBLE_DEVICES=$(seq -s, 0 $((NUM_GPU-1)))
fi

MODEL_BASENAME=$(basename "$BASE_MODEL")

echo "==== [1/3] Training draft model ===="
OUTPUT_DIR=ckpts/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M)
./launch_train.sh --model $BASE_MODEL \
  --output_dir $OUTPUT_DIR \
  --data $DATA \
  --num_gpu $NUM_GPU \
  --num_epochs 2 \
  --eagle_config eagle_config.json

echo "==== [2/3] Evaluating ModelOpt checkpoint on MT-Bench ===="
python ar_validate.py --model_path $OUTPUT_DIR

echo "==== [3/3] Exporting checkpoint to deployment format ===="
EXPORT_PATH=export/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M)
python export_hf_checkpoint.py --model_path $OUTPUT_DIR --export_path $EXPORT_PATH

modelopt/torch/export/plugins/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -19,3 +19,5 @@
 
 with import_plugin("megatron_importer"):
     from .megatron_importer import *
+with import_plugin("transformers"):
+    from .hf_spec_export import *
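This guard pattern works because a context manager can suppress the ImportError raised by the import inside its body. A minimal sketch of an import_plugin-style helper, assuming (without having checked the real one) that it simply swallows missing optional dependencies:

from contextlib import contextmanager


@contextmanager
def optional_plugin(plugin_name):
    """Skip the guarded block when its optional dependency is missing."""
    try:
        yield
    except ImportError:
        # The import inside the `with` body failed; treat the plugin as disabled.
        pass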
modelopt/torch/export/plugins/hf_spec_export.py (new file)

Lines changed: 151 additions & 0 deletions

# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Modify state_dict and config for exporting speculative decoding in official format."""

import torch
import torch.nn as nn

from modelopt.torch.speculative.plugins.transformers import HFEagleModel

SPECULATIVE_DECODING_MODES = ["eagle", "medusa"]

EAGLE_MODELOPT_TO_OFFICIAL = {
    "required": {
        "layers.0.self_attn.q_proj.weight": "midlayer.self_attn.q_proj.weight",
        "layers.0.self_attn.k_proj.weight": "midlayer.self_attn.k_proj.weight",
        "layers.0.self_attn.v_proj.weight": "midlayer.self_attn.v_proj.weight",
        "layers.0.self_attn.o_proj.weight": "midlayer.self_attn.o_proj.weight",
        "layers.0.mlp.gate_proj.weight": "midlayer.mlp.gate_proj.weight",
        "layers.0.mlp.up_proj.weight": "midlayer.mlp.up_proj.weight",
        "layers.0.mlp.down_proj.weight": "midlayer.mlp.down_proj.weight",
        "hidden_norm.weight": "midlayer.hidden_norm.weight",
        "input_embeds_norm.weight": "midlayer.input_layernorm.weight",
        "layers.0.post_attention_layernorm.weight": "midlayer.post_attention_layernorm.weight",
        "norm.weight": "norm.weight",
        "fc.weight": "fc.weight",
    },
    "optional": {
        "d2t": "d2t",
        "eagle_lm_head.weight": "lm_head.weight",
    },
}


def _check_state_dict_keys_match(draft_model: nn.Module, required_items: dict):
    """Check if the state dict keys match."""
    draft_keys = set(draft_model.state_dict().keys())
    for required_key in required_items:
        if required_key not in draft_keys:
            raise ValueError(f"State dict keys mismatch!\nMissing in draft model: {required_key}")


def rename_and_prune_if_spec_decoding(model: nn.Module, post_state_dict: dict):
    """Only return the state dict of the draft model in official format and ignore the base model."""
    # Check that speculative decoding is the model's only ModelOpt mode.
    opt_modes = model._modelopt_state
    if len(opt_modes) != 1 or opt_modes[0][0] != "eagle":
        # If there are other opts, return as is.
        return post_state_dict

    assert isinstance(model, HFEagleModel)
    # Check if the state dict keys match
    _check_state_dict_keys_match(model.eagle_module, EAGLE_MODELOPT_TO_OFFICIAL["required"])

    # Convert key names and save the state dict
    export_state_dict = {}
    for ours_key, export_key in {
        **EAGLE_MODELOPT_TO_OFFICIAL["required"],
        **EAGLE_MODELOPT_TO_OFFICIAL["optional"],
    }.items():
        if ours_key in model.eagle_module.state_dict():
            export_state_dict[export_key] = model.eagle_module.state_dict()[ours_key]

    # TODO: (hg) this is a temp fix. Find cleaner way to do this.
    if "eagle_lm_head.weight" not in model.eagle_module.state_dict():
        export_state_dict["lm_head.weight"] = model.state_dict()["lm_head.weight"]

    return export_state_dict


def set_config_if_spec_decoding(model: nn.Module, config_data: dict):
    """Return the config of draft model in official format."""
    if len(model._modelopt_state) != 1 or model._modelopt_state[0][0] != "eagle":
        # Return as is.
        return config_data

    assert isinstance(model, HFEagleModel)

    # These are the config keys in the official checkpoint.
    template_config = {
        "architectures": ["LlamaForCausalLM"],
        "bos_token_id": None,
        "eos_token_id": None,
        "hidden_act": None,
        "hidden_size": None,
        "initializer_range": None,
        "intermediate_size": None,
        "max_position_embeddings": None,
        "model_type": "llama",
        "num_attention_heads": None,
        "num_key_value_heads": None,
        "num_hidden_layers": None,
        "pad_token_id": None,
        "rms_norm_eps": None,
        "tie_word_embeddings": False,
        "torch_dtype": None,
        "transformers_version": None,
        "use_cache": None,
        "vocab_size": None,
        "draft_vocab_size": None,
        "rope_scaling": None,
        "attention_bias": None,
        "attention_dropout": None,
        "head_dim": None,
        "mlp_bias": None,
        "pretraining_tp": None,
        "rope_theta": None,
        "eagle_config": {
            "eagle_aux_hidden_state_layer_ids": None,
            "use_aux_hidden_state": None,
            "use_input_layernorm_in_first_layer": None,
            "use_last_layernorm": None,
            "use_mtp_layernorm": None,
        },
    }

    def _get_config_from_eagle_config_or_base_config(key: str, model: nn.Module):
        if getattr(model.eagle_config, key, None) is not None:
            return getattr(model.eagle_config, key)
        elif getattr(model.config, key, None) is not None:
            return getattr(model.config, key)
        else:
            return None

    for key in template_config:
        value = template_config[key]
        if isinstance(value, dict):
            # For the nested eagle config, look the sub-keys up in model.eagle_config.
            for sub_key in value:
                value[sub_key] = _get_config_from_eagle_config_or_base_config(sub_key, model)
        elif value is None:
            # First, try to load the value from the eagle config, then the base config.
            new_value = _get_config_from_eagle_config_or_base_config(key, model)
            # If the value is a torch.dtype, convert to string for serialization.
            if isinstance(new_value, torch.dtype):
                new_value = str(new_value).replace("torch.", "")
            template_config[key] = new_value

    return template_config
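To illustrate how these two hooks compose, a hypothetical sketch of their use during export (the actual wiring lives inside export_hf_checkpoint; the variable names here are illustrative):

# model: a trained HFEagleModel whose only ModelOpt mode is "eagle".
export_state_dict = rename_and_prune_if_spec_decoding(model, model.state_dict())
config_data = set_config_if_spec_decoding(model, model.config.to_dict())
# export_state_dict now holds only the draft weights under official names
# (e.g. "midlayer.self_attn.q_proj.weight"), and config_data follows the
# official draft-model config template above.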
