
Commit a6fa34c

Feat: update eagle3 example; add export (#293)
Signed-off-by: h-guo18 <[email protected]>
Parent: 512dbb7

File tree: 14 files changed (+714 −190 lines)

examples/speculative_decoding/README.md

Lines changed: 149 additions & 134 deletions
Large diffs are not rendered by default.
Lines changed: 37 additions & 0 deletions (new file)

# SLURM Prepare Data

For basic parallelization of synthetic data generation, we provide some SLURM support.
Assuming a `$SLURM_JOB_ID` is present and nodes n1, n2, n3, n4 are selected, the following is achievable.

Example of allocating 4 nodes for 120 minutes:

```sh
salloc -N4 -A <account> -p <partition> -J <account>-synthetic:data-gen -t 120
```

Create shards of a given size:

```sh
python3 distributed_generate/sharding_utils.py --input_path /data/train.jsonl --output_dir /data/train/ --max_lines_per_shard 10000
```

Run workers on SLURM:

```sh
bash distributed_generate/launch.sh $SLURM_JOB_ID vllm TinyLlama/TinyLlama-1.1B-Chat-v1.0 /data/train/ /data/output /scripts/ 0 10 n1,n2,n3,n4 "\"You are a helpful assistant.\""
```

`/scripts/` is the absolute path to `modelopt/examples/speculative_decoding`, which contains `server_generate.py` and `distributed_generate`.
This will launch a vLLM server (sglang is also available) on each node. Each node works through 10 shards of data (10 \* max_lines_per_shard samples), so the first 40 shards are processed here.
To process the next 40 shards:

```sh
bash distributed_generate/launch.sh $SLURM_JOB_ID vllm TinyLlama/TinyLlama-1.1B-Chat-v1.0 /data/train/ /data/output /scripts/ 40 10 n1,n2,n3,n4
```

To combine the shards back:

```sh
python3 distributed_generate/sharding_utils.py --input_dir /data/output/ --output_path /data/output.jsonl --combine
```
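
For reference, the sharding step is a plain line-level split of the input JSONL. A minimal sketch of the idea, assuming `shard_{i}.jsonl` naming (the real `sharding_utils.py` may differ in details):

```python
import os


def write_shards(input_path: str, output_dir: str, max_lines_per_shard: int) -> None:
    """Split a JSONL file into fixed-size shards named shard_0.jsonl, shard_1.jsonl, ..."""
    os.makedirs(output_dir, exist_ok=True)
    shard_idx, buffer = 0, []
    with open(input_path) as f:
        for line in f:
            buffer.append(line)
            if len(buffer) == max_lines_per_shard:
                with open(os.path.join(output_dir, f"shard_{shard_idx}.jsonl"), "w") as out:
                    out.writelines(buffer)
                shard_idx, buffer = shard_idx + 1, []
    if buffer:  # trailing partial shard
        with open(os.path.join(output_dir, f"shard_{shard_idx}.jsonl"), "w") as out:
            out.writelines(buffer)
```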

examples/speculative_decoding/ar_validate.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -26,7 +26,7 @@
 mto.enable_huggingface_checkpointing()
 
 
-def validate_ar(model, tokenizer, ds, steps=3, osl=20, num_samples=20, device=None):
+def validate_ar(model, tokenizer, ds, steps=3, osl=20, num_samples=80, device=None):
     validator = HFARValidation(model, tokenizer)
     num_samples = min(num_samples, len(ds))
     ars = []
@@ -54,12 +54,12 @@ def validate_ar(model, tokenizer, ds, steps=3, osl=20, num_samples=20, device=No
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--model_path", type=str, required=True, help="Path to model directory")
-    parser.add_argument("--steps", type=int, default=1, help="Steps for AR validation")
+    parser.add_argument("--steps", type=int, default=3, help="Steps for AR validation")
     parser.add_argument(
-        "--osl", type=int, default=100, help="Output sequence length for AR validation"
+        "--osl", type=int, default=32, help="Output sequence length for AR validation"
     )
     parser.add_argument(
-        "--num_samples", type=int, default=20, help="Number of MT-Bench samples to use"
+        "--num_samples", type=int, default=80, help="Number of MT-Bench samples to use"
     )
     parser.add_argument(
         "--ar_lower_bound",
```

examples/speculative_decoding/calibrate_draft_vocab.py

Lines changed: 4 additions & 11 deletions
```diff
@@ -28,11 +28,10 @@ def main():
     parser.add_argument("--model", type=str, required=True, help="Model name or path for tokenizer")
     parser.add_argument("--data", type=str, required=True, help="Path to training data (jsonl)")
     parser.add_argument(
-        "--eagle_config",
-        type=str,
+        "--draft_vocab_size",
+        type=int,
         required=True,
-        default="eagle_config.json",
-        help="Path to eagle_config.json",
+        help="Draft vocab size",
     )
     parser.add_argument(
         "--calibrate_size",
@@ -45,12 +44,6 @@ def main():
     )
     args = parser.parse_args()
 
-    with open(args.eagle_config) as f:
-        eagle_config = json.load(f)
-    if "draft_vocab_size" not in eagle_config:
-        print("No draft vocab size specified in eagle_config.json, no need to calibrate for d2t.")
-        return
-
     print("Calibrating vocab...")
     tokenizer = AutoTokenizer.from_pretrained(args.model)
     with open(args.data) as f:
@@ -59,7 +52,7 @@ def main():
     conversations = conversations[: args.calibrate_size]
     conversations = [item for sublist in conversations for item in sublist]
 
-    d2t = calibrate_frequent_vocab(tokenizer, conversations, eagle_config["draft_vocab_size"])
+    d2t = calibrate_frequent_vocab(tokenizer, conversations, args.draft_vocab_size)
     model_name = os.path.basename(os.path.normpath(args.model))
     vocab_path = os.path.join(args.save_dir, model_name, "d2t.pt")
     os.makedirs(os.path.dirname(vocab_path), exist_ok=True)
```
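
The calibration itself is frequency-based: tokenize the calibration conversations, keep the `draft_vocab_size` most frequent target tokens as the draft vocabulary, and save a draft-to-target mapping (`d2t.pt`). A minimal sketch of that idea, assuming an offset-style `d2t` (the real `calibrate_frequent_vocab` may differ):

```python
from collections import Counter

import torch


def calibrate_frequent_vocab_sketch(tokenizer, conversations, draft_vocab_size):
    """Keep the most frequent target-vocab ids; map draft id i -> target id top_ids[i]."""
    counts = Counter()
    for text in conversations:
        counts.update(tokenizer.encode(text))
    top_ids = sorted(tid for tid, _ in counts.most_common(draft_vocab_size))
    # Offset form: target_id = draft_id + d2t[draft_id]
    d2t = torch.tensor(top_ids) - torch.arange(len(top_ids))
    return d2t
```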
Lines changed: 8 additions & 1 deletion
```diff
@@ -1,3 +1,10 @@
 {
-    "draft_vocab_size": 32000
+    "rope_scaling": {
+        "factor": 32.0,
+        "low_freq_factor": 1.0,
+        "high_freq_factor": 4.0,
+        "original_max_position_embeddings": 8192,
+        "rope_type": "llama3"
+    },
+    "initializer_range": 0.02
 }
```
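
As the main.py change below shows, fields from this JSON are overlaid onto the default EAGLE architecture config at training time; schematically (illustrative defaults, not the library's actual ones):

```python
import json

# Illustrative defaults only; the real EAGLE defaults live in modelopt
default_architecture_config = {"initializer_range": 0.02, "max_position_embeddings": 4096}

with open("eagle_config.json") as f:
    custom_config = json.load(f)

# User-supplied fields win over defaults
architecture_config = {**default_architecture_config, **custom_config}
```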
Lines changed: 48 additions & 0 deletions
(new file)

```python
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Export a HF checkpoint (with ModelOpt state) for deployment."""

import argparse

import torch
from transformers import AutoModelForCausalLM

import modelopt.torch.opt as mto
from modelopt.torch.export import export_hf_checkpoint


def parse_args():
    parser = argparse.ArgumentParser(
        description="Export a HF checkpoint (with ModelOpt state) for deployment."
    )
    parser.add_argument("--model_path", type=str, default="Path of the trained checkpoint.")
    parser.add_argument(
        "--export_path", type=str, default="Destination directory for exported files."
    )
    return parser.parse_args()


mto.enable_huggingface_checkpointing()

args = parse_args()
model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype="auto")
model.eval()
with torch.inference_mode():
    export_hf_checkpoint(
        model,  # The quantized model.
        export_dir=args.export_path,  # The directory where the exported files will be stored.
    )
print(f"Exported checkpoint to {args.export_path}")
```
Lines changed: 157 additions & 0 deletions
(new file)

```bash
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

set -eo pipefail

while [ $# -gt 0 ]; do
  case "$1" in
    --training_seq_len*)
      if [[ "$1" != *=* ]]; then shift; fi
      TRAINING_SEQ_LEN="${1#*=}"
      ;;
    --model*)
      if [[ "$1" != *=* ]]; then shift; fi
      MODEL="${1#*=}"
      ;;
    --data*)
      if [[ "$1" != *=* ]]; then shift; fi
      DATA="${1#*=}"
      ;;
    --mode*)
      if [[ "$1" != *=* ]]; then shift; fi
      MODE="${1#*=}"
      ;;
    --output_dir*)
      if [[ "$1" != *=* ]]; then shift; fi
      OUTPUT_DIR="${1#*=}"
      ;;
    --num_epochs*)
      if [[ "$1" != *=* ]]; then shift; fi
      NUM_EPOCHS="${1#*=}"
      ;;
    --save_steps*)
      if [[ "$1" != *=* ]]; then shift; fi
      SAVE_STEPS="${1#*=}"
      ;;
    --lr*)
      if [[ "$1" != *=* ]]; then shift; fi
      LR="${1#*=}"
      ;;
    --train_bs*)
      if [[ "$1" != *=* ]]; then shift; fi
      TRAIN_BS="${1#*=}"
      ;;
    --medusa_num_heads*)
      if [[ "$1" != *=* ]]; then shift; fi
      MEDUSA_NUM_HEADS="${1#*=}"
      ;;
    --medusa_num_layers*)
      if [[ "$1" != *=* ]]; then shift; fi
      MEDUSA_NUM_LAYERS="${1#*=}"
      ;;
    --eagle_config*)
      if [[ "$1" != *=* ]]; then shift; fi
      EAGLE_CONFIG="${1#*=}"
      ;;
    --fsdp_transformer_layer_cls_to_wrap*)
      if [[ "$1" != *=* ]]; then shift; fi
      FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP="${1#*=}"
      ;;
    --num_gpu*)
      if [[ "$1" != *=* ]]; then shift; fi
      NUM_GPU="${1#*=}"
      ;;
    *)
      >&2 printf "Error: Invalid argument ${1#*=}\n"
      exit 1
      ;;
  esac
  shift
done

set -x

# Get the default value for save_steps based on the available number of GPUs
GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
# Calculate save_steps
DEFAULT_SAVE_STEPS=$((8192 / GPU_COUNT))

MODEL=${MODEL:-"TinyLlama/TinyLlama-1.1B-Chat-v1.0"}
MODE=${MODE:-"eagle3"}
# Set default OUTPUT_DIR to ckpts/{modelname}, where {modelname} is the last part of the model path
MODEL_BASENAME=$(basename "$MODEL")
OUTPUT_DIR=${OUTPUT_DIR:-"ckpts/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M)"}
NUM_EPOCHS=${NUM_EPOCHS:-1}
SAVE_STEPS=${SAVE_STEPS:-$DEFAULT_SAVE_STEPS}
LR=${LR:-"1e-4"}
TRAIN_BS=${TRAIN_BS:-4}
MEDUSA_NUM_HEADS=${MEDUSA_NUM_HEADS:-1}
MEDUSA_NUM_LAYERS=${MEDUSA_NUM_LAYERS:-1}
REDRAFTER_TOKENS=${REDRAFTER_TOKENS:-1}
REDRAFTER_NUM_LAYERS=${REDRAFTER_NUM_LAYERS:-1}
FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP:-"LlamaDecoderLayer"}
NUM_GPU=${NUM_GPU:-1}
TRAINING_SEQ_LEN=${TRAINING_SEQ_LEN:-512}

if [[ "$MODE" == "medusa" ]]; then
  SPECULATIVE_ARGS="--medusa_num_heads $MEDUSA_NUM_HEADS --medusa_num_layers $MEDUSA_NUM_LAYERS"
elif [[ "$MODE" == "eagle1" || "$MODE" == "eagle3" ]]; then
  if [[ -n "$EAGLE_CONFIG" ]]; then
    SPECULATIVE_ARGS="--eagle_config $EAGLE_CONFIG"
  else
    SPECULATIVE_ARGS=""
  fi
else
  echo "Only medusa, eagle1, eagle3 supported for now!"
  exit 1
fi

if [[ "$NUM_GPU" == 1 ]]; then
  MULTI_GPU=""
else
  MULTI_GPU="--multi_gpu"
fi

# Disable tokenizers parallelism to avoid warning
export TOKENIZERS_PARALLELISM=False
CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \
    --mode $MODE \
    --model_name_or_path $MODEL \
    --training_seq_len $TRAINING_SEQ_LEN \
    --dataloader_drop_last True \
    --bf16 True \
    --output_dir $OUTPUT_DIR \
    --num_train_epochs $NUM_EPOCHS \
    --per_device_train_batch_size $TRAIN_BS \
    --per_device_eval_batch_size $TRAIN_BS \
    --gradient_accumulation_steps 1 \
    --do_eval False \
    --eval_accumulation_steps 1 \
    --save_strategy steps \
    --save_steps $SAVE_STEPS \
    --learning_rate $LR \
    --weight_decay 0.0 \
    --warmup_steps 100 \
    --lr_scheduler_type linear \
    --logging_steps 100 \
    --tf32 True \
    --data_path $DATA \
    $SPECULATIVE_ARGS
"

start_time=$(date +%s)
sh -c "$CMD"
echo "Total time taken: $(( $(date +%s) - $start_time )) seconds"
```

examples/speculative_decoding/main.py

Lines changed: 11 additions & 0 deletions
```diff
@@ -47,6 +47,13 @@
 import modelopt.torch.speculative as mtsp
 from modelopt.torch.utils import print_rank_0
 
+try:
+    import wandb
+
+    wandb.init()
+except ImportError:
+    wandb = None
+
 torch.manual_seed(0)
 mto.enable_huggingface_checkpointing()
 
@@ -170,6 +177,8 @@ def train():
             {
                 "hidden_size": model.config.hidden_size,
                 "vocab_size": model.config.vocab_size,
+                # we also overwrite max_pos_embedding for deployment compatibility
+                "max_position_embeddings": model.config.max_position_embeddings,
                 "draft_vocab_size": custom_config["draft_vocab_size"]
                 if eagle_args.eagle_config and "draft_vocab_size" in custom_config
                 else model.config.vocab_size,
@@ -213,6 +222,8 @@ def on_step_end(self, args, state, control, **kwargs):
                 device=kwargs["model"].device,
             )
             print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}")
+            if wandb:
+                wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
             return control
 
     trainer = Trainer(
```

examples/speculative_decoding/server_generate.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -46,7 +46,7 @@
 parser.add_argument(
     "--max_tokens", type=int, default=2048, help="Maximum number of tokens to generate"
 )
-parser.add_argument("--chat", action="store_true", help="Use chat mode")
+parser.add_argument("--chat", default=True, type=bool, help="Use chat mode")
 parser.add_argument("--model", type=str, default="model", help="Model name")
 parser.add_argument("--url", type=str, default="http://localhost:8000/v1", help="URL of the API")
 parser.add_argument("--api_key", type=str, default="token-abc123", help="API key (if any)")
```
