Skip to content

Commit c79e8e2

Browse files
committed
update eagle3 example; add export
1 parent 8a07376 commit c79e8e2

File tree

14 files changed

+696
-349
lines changed

14 files changed

+696
-349
lines changed

examples/speculative_decoding/README.md

Lines changed: 150 additions & 132 deletions
Large diffs are not rendered by default.

examples/speculative_decoding/ar_validate.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
mto.enable_huggingface_checkpointing()
2727

2828

29-
def validate_ar(model, tokenizer, ds, steps=3, osl=20, num_samples=20, device=None):
29+
def validate_ar(model, tokenizer, ds, steps=3, osl=20, num_samples=80, device=None):
3030
validator = HFARValidation(model, tokenizer)
3131
num_samples = min(num_samples, len(ds))
3232
ars = []
@@ -54,12 +54,12 @@ def validate_ar(model, tokenizer, ds, steps=3, osl=20, num_samples=20, device=No
5454
def main():
5555
parser = argparse.ArgumentParser()
5656
parser.add_argument("--model_path", type=str, required=True, help="Path to model directory")
57-
parser.add_argument("--steps", type=int, default=1, help="Steps for AR validation")
57+
parser.add_argument("--steps", type=int, default=3, help="Steps for AR validation")
5858
parser.add_argument(
59-
"--osl", type=int, default=100, help="Output sequence length for AR validation"
59+
"--osl", type=int, default=32, help="Output sequence length for AR validation"
6060
)
6161
parser.add_argument(
62-
"--num_samples", type=int, default=20, help="Number of MT-Bench samples to use"
62+
"--num_samples", type=int, default=80, help="Number of MT-Bench samples to use"
6363
)
6464
parser.add_argument(
6565
"--ar_lower_bound",

examples/speculative_decoding/calibrate_draft_vocab.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,10 @@ def main():
2828
parser.add_argument("--model", type=str, required=True, help="Model name or path for tokenizer")
2929
parser.add_argument("--data", type=str, required=True, help="Path to training data (jsonl)")
3030
parser.add_argument(
31-
"--eagle_config",
32-
type=str,
31+
"--draft_vocab_size",
32+
type=int,
3333
required=True,
34-
default="eagle_config.json",
35-
help="Path to eagle_config.json",
34+
help="Draft vocab size",
3635
)
3736
parser.add_argument(
3837
"--calibrate_size",
@@ -45,12 +44,6 @@ def main():
4544
)
4645
args = parser.parse_args()
4746

48-
with open(args.eagle_config) as f:
49-
eagle_config = json.load(f)
50-
if "draft_vocab_size" not in eagle_config:
51-
print("No draft vocab size specified in eagle_config.json, no need to calibrate for d2t.")
52-
return
53-
5447
print("Calibrating vocab...")
5548
tokenizer = AutoTokenizer.from_pretrained(args.model)
5649
with open(args.data) as f:
@@ -59,7 +52,7 @@ def main():
5952
conversations = conversations[: args.calibrate_size]
6053
conversations = [item for sublist in conversations for item in sublist]
6154

62-
d2t = calibrate_frequent_vocab(tokenizer, conversations, eagle_config["draft_vocab_size"])
55+
d2t = calibrate_frequent_vocab(tokenizer, conversations, args.draft_vocab_size)
6356
model_name = os.path.basename(os.path.normpath(args.model))
6457
vocab_path = os.path.join(args.save_dir, model_name, "d2t.pt")
6558
os.makedirs(os.path.dirname(vocab_path), exist_ok=True)
Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
11
{
2-
"draft_vocab_size": 32000
2+
"rope_scaling": {
3+
"factor": 32.0,
4+
"low_freq_factor": 1.0,
5+
"high_freq_factor": 4.0,
6+
"original_max_position_embeddings": 8192,
7+
"rope_type": "llama3"
8+
},
9+
"initializer_range": 0.02
310
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import argparse
17+
18+
import torch
19+
from transformers import AutoModelForCausalLM
20+
21+
import modelopt.torch.opt as mto
22+
from modelopt.torch.export import export_hf_checkpoint
23+
24+
25+
def parse_args():
26+
parser = argparse.ArgumentParser()
27+
parser.add_argument("--model_path", type=str, default="")
28+
parser.add_argument("--export_path", type=str, default="")
29+
return parser.parse_args()
30+
31+
32+
mto.enable_huggingface_checkpointing()
33+
34+
args = parse_args()
35+
model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype="auto")
36+
with torch.inference_mode():
37+
export_hf_checkpoint(
38+
model, # The model loaded from --model_path that will be exported.
39+
export_dir=args.export_path, # The directory where the exported files will be stored.
40+
)
41+
print(f"Exported checkpoint to {args.export_path}")

examples/speculative_decoding/launch.sh

Lines changed: 0 additions & 158 deletions
This file was deleted.

examples/speculative_decoding/main.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
import modelopt.torch.opt as mto
4747
import modelopt.torch.speculative as mtsp
4848
from modelopt.torch.utils import print_rank_0
49+
from modelopt.torch.utils.distributed import is_master
4950

5051
torch.manual_seed(0)
5152
mto.enable_huggingface_checkpointing()
@@ -170,6 +171,8 @@ def train():
170171
{
171172
"hidden_size": model.config.hidden_size,
172173
"vocab_size": model.config.vocab_size,
174+
# We also overwrite max_position_embeddings with the base model's value for deployment compatibility.
175+
"max_position_embeddings": model.config.max_position_embeddings,
173176
"draft_vocab_size": custom_config["draft_vocab_size"]
174177
if eagle_args.eagle_config and "draft_vocab_size" in custom_config
175178
else model.config.vocab_size,
@@ -202,6 +205,15 @@ def train():
202205
class ARValidationCallback(TrainerCallback):
203206
def __init__(self, ar_validate_steps: int = 500):
204207
self.ar_validate_steps = ar_validate_steps
208+
self.wandb = None
209+
if is_master():
210+
try:
211+
import wandb
212+
213+
self.wandb = wandb
214+
self.wandb.init()
215+
except ImportError:
216+
pass
205217

206218
def on_step_end(self, args, state, control, **kwargs):
207219
if state.global_step % self.ar_validate_steps == 0 and state.global_step > 0:
@@ -213,6 +225,8 @@ def on_step_end(self, args, state, control, **kwargs):
213225
device=kwargs["model"].device,
214226
)
215227
print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}")
228+
if self.wandb:
229+
self.wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
216230
return control
217231

218232
trainer = Trainer(

examples/speculative_decoding/server_generate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
parser.add_argument(
4747
"--max_tokens", type=int, default=2048, help="Maximum number of tokens to generate"
4848
)
49-
parser.add_argument("--chat", action="store_true", help="Use chat mode")
49+
parser.add_argument("--chat", action="store_true", default=True, help="Use chat mode")
5050
parser.add_argument("--model", type=str, default="model", help="Model name")
5151
parser.add_argument("--url", type=str, default="http://localhost:8000/v1", help="URL of the API")
5252
parser.add_argument("--api_key", type=str, default="token-abc123", help="API key (if any)")

0 commit comments

Comments
 (0)