Skip to content

Commit 726b848

Browse files
committed
feat: update eagle3 example; add export
Signed-off-by: h-guo18 <[email protected]>
1 parent d5c88e7 commit 726b848

File tree

13 files changed

+519
-349
lines changed

13 files changed

+519
-349
lines changed

examples/speculative_decoding/README.md

Lines changed: 144 additions & 132 deletions
Large diffs are not rendered by default.

examples/speculative_decoding/ar_validate.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
mto.enable_huggingface_checkpointing()
2727

2828

29-
def validate_ar(model, tokenizer, ds, steps=3, osl=20, num_samples=20, device=None):
29+
def validate_ar(model, tokenizer, ds, steps=3, osl=20, num_samples=80, device=None):
3030
validator = HFARValidation(model, tokenizer)
3131
num_samples = min(num_samples, len(ds))
3232
ars = []
@@ -54,12 +54,12 @@ def validate_ar(model, tokenizer, ds, steps=3, osl=20, num_samples=20, device=No
5454
def main():
5555
parser = argparse.ArgumentParser()
5656
parser.add_argument("--model_path", type=str, required=True, help="Path to model directory")
57-
parser.add_argument("--steps", type=int, default=1, help="Steps for AR validation")
57+
parser.add_argument("--steps", type=int, default=3, help="Steps for AR validation")
5858
parser.add_argument(
59-
"--osl", type=int, default=100, help="Output sequence length for AR validation"
59+
"--osl", type=int, default=32, help="Output sequence length for AR validation"
6060
)
6161
parser.add_argument(
62-
"--num_samples", type=int, default=20, help="Number of MT-Bench samples to use"
62+
"--num_samples", type=int, default=80, help="Number of MT-Bench samples to use"
6363
)
6464
parser.add_argument(
6565
"--ar_lower_bound",

examples/speculative_decoding/calibrate_draft_vocab.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,10 @@ def main():
2828
parser.add_argument("--model", type=str, required=True, help="Model name or path for tokenizer")
2929
parser.add_argument("--data", type=str, required=True, help="Path to training data (jsonl)")
3030
parser.add_argument(
31-
"--eagle_config",
32-
type=str,
31+
"--draft_vocab_size",
32+
type=int,
3333
required=True,
34-
default="eagle_config.json",
35-
help="Path to eagle_config.json",
34+
help="Draft vocab size",
3635
)
3736
parser.add_argument(
3837
"--calibrate_size",
@@ -45,12 +44,6 @@ def main():
4544
)
4645
args = parser.parse_args()
4746

48-
with open(args.eagle_config) as f:
49-
eagle_config = json.load(f)
50-
if "draft_vocab_size" not in eagle_config:
51-
print("No draft vocab size specified in eagle_config.json, no need to calibrate for d2t.")
52-
return
53-
5447
print("Calibrating vocab...")
5548
tokenizer = AutoTokenizer.from_pretrained(args.model)
5649
with open(args.data) as f:
@@ -59,7 +52,7 @@ def main():
5952
conversations = conversations[: args.calibrate_size]
6053
conversations = [item for sublist in conversations for item in sublist]
6154

62-
d2t = calibrate_frequent_vocab(tokenizer, conversations, eagle_config["draft_vocab_size"])
55+
d2t = calibrate_frequent_vocab(tokenizer, conversations, args.draft_vocab_size)
6356
model_name = os.path.basename(os.path.normpath(args.model))
6457
vocab_path = os.path.join(args.save_dir, model_name, "d2t.pt")
6558
os.makedirs(os.path.dirname(vocab_path), exist_ok=True)
Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
11
{
2-
"draft_vocab_size": 32000
2+
"rope_scaling": {
3+
"factor": 32.0,
4+
"low_freq_factor": 1.0,
5+
"high_freq_factor": 4.0,
6+
"original_max_position_embeddings": 8192,
7+
"rope_type": "llama3"
8+
},
9+
"initializer_range": 0.02
310
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import argparse
17+
18+
import torch
19+
from transformers import AutoModelForCausalLM
20+
21+
import modelopt.torch.opt as mto
22+
from modelopt.torch.export import export_hf_checkpoint
23+
24+
25+
def parse_args():
26+
parser = argparse.ArgumentParser()
27+
parser.add_argument("--model_path", type=str, default="")
28+
parser.add_argument("--export_path", type=str, default="")
29+
return parser.parse_args()
30+
31+
32+
mto.enable_huggingface_checkpointing()
33+
34+
args = parse_args()
35+
model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype="auto")
36+
with torch.inference_mode():
37+
export_hf_checkpoint(
38+
model, # The quantized model.
39+
export_dir=args.export_path, # The directory where the exported files will be stored.
40+
)
41+
print(f"Exported checkpoint to {args.export_path}")

examples/speculative_decoding/launch.sh

Lines changed: 0 additions & 158 deletions
This file was deleted.

examples/speculative_decoding/main.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,13 @@
4747
import modelopt.torch.speculative as mtsp
4848
from modelopt.torch.utils import print_rank_0
4949

50+
try:
51+
import wandb
52+
53+
wandb.init()
54+
except ImportError:
55+
wandb = None
56+
5057
torch.manual_seed(0)
5158
mto.enable_huggingface_checkpointing()
5259

@@ -170,6 +177,8 @@ def train():
170177
{
171178
"hidden_size": model.config.hidden_size,
172179
"vocab_size": model.config.vocab_size,
180+
# we also overwrite max_pos_embedding for deployment compatibility
181+
"max_position_embeddings": model.config.max_position_embeddings,
173182
"draft_vocab_size": custom_config["draft_vocab_size"]
174183
if eagle_args.eagle_config and "draft_vocab_size" in custom_config
175184
else model.config.vocab_size,
@@ -213,6 +222,8 @@ def on_step_end(self, args, state, control, **kwargs):
213222
device=kwargs["model"].device,
214223
)
215224
print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}")
225+
if wandb:
226+
wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
216227
return control
217228

218229
trainer = Trainer(

examples/speculative_decoding/server_generate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
parser.add_argument(
4747
"--max_tokens", type=int, default=2048, help="Maximum number of tokens to generate"
4848
)
49-
parser.add_argument("--chat", action="store_true", help="Use chat mode")
49+
parser.add_argument("--chat", default=True, type=bool, help="Use chat mode")
5050
parser.add_argument("--model", type=str, default="model", help="Model name")
5151
parser.add_argument("--url", type=str, default="http://localhost:8000/v1", help="URL of the API")
5252
parser.add_argument("--api_key", type=str, default="token-abc123", help="API key (if any)")
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
#!/bin/bash
2+
3+
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4+
# SPDX-License-Identifier: Apache-2.0
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
18+
set -eo pipefail
19+
20+
# Set default values for BASE_MODEL, NUM_GPU, and DATA
21+
BASE_MODEL=meta-llama/Llama-3.2-1B-Instruct
22+
NUM_GPU=1
23+
DATA=Daring-Anteater/train.jsonl
24+
25+
# Parse input arguments --base-model, --num_gpu, and --data
26+
while [[ $# -gt 0 ]]; do
27+
key="$1"
28+
case $key in
29+
--base_model)
30+
BASE_MODEL="$2"
31+
shift; shift
32+
;;
33+
--num_gpu)
34+
NUM_GPU="$2"
35+
shift; shift
36+
;;
37+
--data)
38+
DATA="$2"
39+
shift; shift
40+
;;
41+
*)
42+
echo "Unknown argument: $1"
43+
exit 1
44+
;;
45+
esac
46+
done
47+
48+
49+
if [[ "$NUM_GPU" == 1 ]]; then
50+
export CUDA_VISIBLE_DEVICES=0
51+
else
52+
# Export as 0,1,...,N-1 for NUM_GPU GPUs
53+
export CUDA_VISIBLE_DEVICES=$(seq -s, 0 $((NUM_GPU-1)))
54+
fi
55+
56+
MODEL_BASENAME=$(basename "$BASE_MODEL")
57+
58+
echo "==== [1/3] Training draft model ===="
59+
OUTPUT_DIR=ckpts/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M)
60+
./launch_train.sh --model $BASE_MODEL \
61+
--output_dir $OUTPUT_DIR \
62+
--data $DATA \
63+
--num_gpu $NUM_GPU \
64+
--num_epochs 2 \
65+
--eagle_config eagle_config.json
66+
67+
echo "==== [2/3] Evaluating ModelOpt checkpoint on MT-Bench ===="
68+
python ar_validate.py --model_path $OUTPUT_DIR
69+
70+
echo "==== [3/3] Exporting checkpoint to deployment format ===="
71+
EXPORT_PATH=export/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M)
72+
python export_hf_checkpoint.py --model_path $OUTPUT_DIR --export_path $EXPORT_PATH

modelopt/torch/export/plugins/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,5 @@
1919

2020
with import_plugin("megatron_importer"):
2121
from .megatron_importer import *
22+
with import_plugin("transformers"):
23+
from .hf_spec_export import *

0 commit comments

Comments
 (0)