283 changes: 149 additions & 134 deletions examples/speculative_decoding/README.md

Large diffs are not rendered by default.

37 changes: 37 additions & 0 deletions examples/speculative_decoding/SLURM_prepare_data.md
@@ -0,0 +1,37 @@
# SLURM Prepare Data

For basic parallelization of synthetic data generation, we provide some SLURM support.
Assuming a `$SLURM_JOB_ID` is present and the nodes n1, n2, n3, and n4 have been selected, the following workflow is possible.

Example of allocating 4 nodes for 120 minutes:

```sh
salloc -N4 -A <account> -p <partition> -J <account>-synthetic:data-gen -t 120
```

Create shards of a given size:

```sh
python3 distributed_generate/sharding_utils.py --input_path /data/train.jsonl --output_dir /data/train/ --max_lines_per_shard 10000
```
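
To make `--max_lines_per_shard` concrete, the following is a minimal sketch of the splitting step. It is not the actual `sharding_utils.py` implementation, and the shard file names are illustrative only.

```python
# Illustrative sketch only -- not the actual distributed_generate/sharding_utils.py.
# Shows what "--max_lines_per_shard 10000" means; shard file names are hypothetical.
from pathlib import Path


def shard_jsonl(input_path: str, output_dir: str, max_lines_per_shard: int) -> int:
    """Split a JSONL file into shards of at most `max_lines_per_shard` lines each."""
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)
    shard_idx, buffer = 0, []
    with open(input_path) as f:
        for line in f:
            buffer.append(line)
            if len(buffer) == max_lines_per_shard:
                (out / f"shard_{shard_idx:05d}.jsonl").write_text("".join(buffer))
                shard_idx, buffer = shard_idx + 1, []
    if buffer:  # final, possibly smaller shard
        (out / f"shard_{shard_idx:05d}.jsonl").write_text("".join(buffer))
        shard_idx += 1
    return shard_idx  # total number of shards written


# e.g. shard_jsonl("/data/train.jsonl", "/data/train/", 10_000)
```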

Run the workers on SLURM:

```sh
bash distributed_generate/launch.sh $SLURM_JOB_ID vllm TinyLlama/TinyLlama-1.1B-Chat-v1.0 /data/train/ /data/output /scripts/ 0 10 n1,n2,n3,n4 "\"You are a helpful assistant.\""
```

`/scripts/` is the absolute path to `modelopt/examples/speculative_decoding`, which contains `server_generate.py` and `distributed_generate`.
This launches a vLLM server (SGLang is also available) on each node, and each node works through 10 shards of data (10 \* `max_lines_per_shard` samples). In this case, the first 40 shards of data are processed.
To process the next 40 shards, set the start index to 40:

```sh
bash distributed_generate/launch.sh $SLURM_JOB_ID vllm TinyLlama/TinyLlama-1.1B-Chat-v1.0 /data/train/ /data/output /scripts/ 40 10 n1,n2,n3,n4
```
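
The shard-index arithmetic can be summarized with a small sketch. It assumes `launch.sh` hands each node a contiguous block of shards starting at the given offset; the helper below is hypothetical and not part of `distributed_generate`.

```python
# Hypothetical helper (not in the repo): maps <start_shard> and <shards_per_node>
# onto nodes, assuming contiguous per-node assignment.
def shard_ranges(start: int, shards_per_node: int, nodes: list[str]) -> dict[str, range]:
    return {
        node: range(start + i * shards_per_node, start + (i + 1) * shards_per_node)
        for i, node in enumerate(nodes)
    }


# First launch (start 0): shards 0-39 across n1..n4.
# Second launch (start 40): shards 40-79.
print(shard_ranges(40, 10, ["n1", "n2", "n3", "n4"]))
# {'n1': range(40, 50), 'n2': range(50, 60), 'n3': range(60, 70), 'n4': range(70, 80)}
```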

To combine the processed shards back into a single file:

```sh
python3 distributed_generate/sharding_utils.py --input_dir /data/output/ --output_path /data/output.jsonl --combine
```
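
Conceptually, the combine step is an ordered concatenation of the shard outputs. The snippet below illustrates that idea under that assumption; it is not the actual `--combine` implementation.

```python
# Illustrative only: merge shard outputs back into one JSONL file.
# Assumes shard file names sort correctly by index; the real --combine flag
# in sharding_utils.py may behave differently.
from pathlib import Path


def combine_shards(input_dir: str, output_path: str) -> None:
    shards = sorted(Path(input_dir).glob("*.jsonl"))
    with open(output_path, "w") as out:
        for shard in shards:
            out.write(shard.read_text())


# e.g. combine_shards("/data/output/", "/data/output.jsonl")
```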
8 changes: 4 additions & 4 deletions examples/speculative_decoding/ar_validate.py
@@ -26,7 +26,7 @@
mto.enable_huggingface_checkpointing()


def validate_ar(model, tokenizer, ds, steps=3, osl=20, num_samples=20, device=None):
def validate_ar(model, tokenizer, ds, steps=3, osl=20, num_samples=80, device=None):
validator = HFARValidation(model, tokenizer)
num_samples = min(num_samples, len(ds))
ars = []
@@ -54,12 +54,12 @@ def validate_ar(model, tokenizer, ds, steps=3, osl=20, num_samples=20, device=None):
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, required=True, help="Path to model directory")
parser.add_argument("--steps", type=int, default=1, help="Steps for AR validation")
parser.add_argument("--steps", type=int, default=3, help="Steps for AR validation")
parser.add_argument(
"--osl", type=int, default=100, help="Output sequence length for AR validation"
"--osl", type=int, default=32, help="Output sequence length for AR validation"
)
parser.add_argument(
"--num_samples", type=int, default=20, help="Number of MT-Bench samples to use"
"--num_samples", type=int, default=80, help="Number of MT-Bench samples to use"
)
parser.add_argument(
"--ar_lower_bound",
15 changes: 4 additions & 11 deletions examples/speculative_decoding/calibrate_draft_vocab.py
@@ -28,11 +28,10 @@ def main():
parser.add_argument("--model", type=str, required=True, help="Model name or path for tokenizer")
parser.add_argument("--data", type=str, required=True, help="Path to training data (jsonl)")
parser.add_argument(
"--eagle_config",
type=str,
"--draft_vocab_size",
type=int,
required=True,
default="eagle_config.json",
help="Path to eagle_config.json",
help="Draft vocab size",
)
parser.add_argument(
"--calibrate_size",
@@ -45,12 +44,6 @@
)
args = parser.parse_args()

with open(args.eagle_config) as f:
eagle_config = json.load(f)
if "draft_vocab_size" not in eagle_config:
print("No draft vocab size specified in eagle_config.json, no need to calibrate for d2t.")
return

print("Calibrating vocab...")
tokenizer = AutoTokenizer.from_pretrained(args.model)
with open(args.data) as f:
@@ -59,7 +52,7 @@
conversations = conversations[: args.calibrate_size]
conversations = [item for sublist in conversations for item in sublist]

d2t = calibrate_frequent_vocab(tokenizer, conversations, eagle_config["draft_vocab_size"])
d2t = calibrate_frequent_vocab(tokenizer, conversations, args.draft_vocab_size)
model_name = os.path.basename(os.path.normpath(args.model))
vocab_path = os.path.join(args.save_dir, model_name, "d2t.pt")
os.makedirs(os.path.dirname(vocab_path), exist_ok=True)
9 changes: 8 additions & 1 deletion examples/speculative_decoding/eagle_config.json
@@ -1,3 +1,10 @@
{
"draft_vocab_size": 32000
"rope_scaling": {
"factor": 32.0,
"low_freq_factor": 1.0,
"high_freq_factor": 4.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3"
},
"initializer_range": 0.02
}
48 changes: 48 additions & 0 deletions examples/speculative_decoding/export_hf_checkpoint.py
@@ -0,0 +1,48 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Export a HF checkpoint (with ModelOpt state) for deployment."""

import argparse

import torch
from transformers import AutoModelForCausalLM

import modelopt.torch.opt as mto
from modelopt.torch.export import export_hf_checkpoint


def parse_args():
parser = argparse.ArgumentParser(
description="Export a HF checkpoint (with ModelOpt state) for deployment."
)
parser.add_argument("--model_path", type=str, default="Path of the trained checkpoint.")
parser.add_argument(
"--export_path", type=str, default="Destination directory for exported files."
)
return parser.parse_args()


mto.enable_huggingface_checkpointing()

args = parse_args()
model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype="auto")
model.eval()
with torch.inference_mode():
export_hf_checkpoint(
model, # The quantized model.
export_dir=args.export_path, # The directory where the exported files will be stored.
)
print(f"Exported checkpoint to {args.export_path}")
157 changes: 157 additions & 0 deletions examples/speculative_decoding/launch_train.sh
@@ -0,0 +1,157 @@
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

set -eo pipefail

while [ $# -gt 0 ]; do
case "$1" in
--training_seq_len*)
if [[ "$1" != *=* ]]; then shift; fi
TRAINING_SEQ_LEN="${1#*=}"
;;
--model*)
if [[ "$1" != *=* ]]; then shift; fi
MODEL="${1#*=}"
;;
--data*)
if [[ "$1" != *=* ]]; then shift; fi
DATA="${1#*=}"
;;
--mode*)
if [[ "$1" != *=* ]]; then shift; fi
MODE="${1#*=}"
;;
--output_dir*)
if [[ "$1" != *=* ]]; then shift; fi
OUTPUT_DIR="${1#*=}"
;;
--num_epochs*)
if [[ "$1" != *=* ]]; then shift; fi
NUM_EPOCHS="${1#*=}"
;;
--save_steps*)
if [[ "$1" != *=* ]]; then shift; fi
SAVE_STEPS="${1#*=}"
;;
--lr*)
if [[ "$1" != *=* ]]; then shift; fi
LR="${1#*=}"
;;
--train_bs*)
if [[ "$1" != *=* ]]; then shift; fi
TRAIN_BS="${1#*=}"
;;
--medusa_num_heads*)
if [[ "$1" != *=* ]]; then shift; fi
MEDUSA_NUM_HEADS="${1#*=}"
;;
--medusa_num_layers*)
if [[ "$1" != *=* ]]; then shift; fi
MEDUSA_NUM_LAYERS="${1#*=}"
;;
--eagle_config*)
if [[ "$1" != *=* ]]; then shift; fi
EAGLE_CONFIG="${1#*=}"
;;
--fsdp_transformer_layer_cls_to_wrap*)
if [[ "$1" != *=* ]]; then shift; fi
FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP="${1#*=}"
;;
--num_gpu*)
if [[ "$1" != *=* ]]; then shift; fi
NUM_GPU="${1#*=}"
;;
Comment on lines +69 to +76

⚠️ Potential issue

Parsed FSDP arg is never forwarded to the training command.

--fsdp_transformer_layer_cls_to_wrap is parsed but not passed to main.py, so user input is ignored.

@@
     --num_gpu*)
       if [[ "$1" != *=* ]]; then shift; fi
       NUM_GPU="${1#*=}"
       ;;
@@
 fi
@@
 export TOKENIZERS_PARALLELISM=False
+FSDP_ARGS=""
+if [[ -n "${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP:-}" ]]; then
+  FSDP_ARGS="--fsdp_transformer_layer_cls_to_wrap ${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP}"
+fi
 CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \
@@
     --data_path $DATA \
-    $SPECULATIVE_ARGS
+    $SPECULATIVE_ARGS \
+    $FSDP_ARGS
 "

Also applies to: 130-153

*)
>&2 printf "Error: Invalid argument ${1#*=}\n"
exit 1
;;
esac
shift
done

set -x

# Get the default value for save_steps based on the available number of GPUs
GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
# Calculate save_steps
DEFAULT_SAVE_STEPS=$((8192 / GPU_COUNT))

MODEL=${MODEL:-"TinyLlama/TinyLlama-1.1B-Chat-v1.0"}
MODE=${MODE:-"eagle3"}
# Set default OUTPUT_DIR to ckpts/{modelname}, where {modelname} is the last part of the model path
MODEL_BASENAME=$(basename "$MODEL")
OUTPUT_DIR=${OUTPUT_DIR:-"ckpts/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M)"}
NUM_EPOCHS=${NUM_EPOCHS:-1}
SAVE_STEPS=${SAVE_STEPS:-$DEFAULT_SAVE_STEPS}
LR=${LR:-"1e-4"}
TRAIN_BS=${TRAIN_BS:-4}
MEDUSA_NUM_HEADS=${MEDUSA_NUM_HEADS:-1}
MEDUSA_NUM_LAYERS=${MEDUSA_NUM_LAYERS:-1}
REDRAFTER_TOKENS=${REDRAFTER_TOKENS:-1}
REDRAFTER_NUM_LAYERS=${REDRAFTER_NUM_LAYERS:-1}
FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP:-"LlamaDecoderLayer"}
NUM_GPU=${NUM_GPU:-1}
TRAINING_SEQ_LEN=${TRAINING_SEQ_LEN:-512}

if [[ "$MODE" == "medusa" ]]; then
SPECULATIVE_ARGS="--medusa_num_heads $MEDUSA_NUM_HEADS --medusa_num_layers $MEDUSA_NUM_LAYERS"
elif [[ "$MODE" == "eagle1" || "$MODE" == "eagle3" ]]; then
if [[ -n "$EAGLE_CONFIG" ]]; then
SPECULATIVE_ARGS="--eagle_config $EAGLE_CONFIG"
else
SPECULATIVE_ARGS=""
fi
else
echo "Only medusa, eagle1, eagle3 supported for now!"
exit 1
fi

if [[ "$NUM_GPU" == 1 ]]; then
MULTI_GPU=""
else
MULTI_GPU="--multi_gpu"
fi

Comment on lines +122 to +127

💡 Verification agent

❓ Verification inconclusive

Multi-GPU launch: pass explicit process count to Accelerate; --multi_gpu may be ignored on newer versions.

Safer to specify --num_processes "${NUM_GPU}" and drop the custom flag.

-if [[ "$NUM_GPU" == 1 ]]; then
-  MULTI_GPU=""
-else
-  MULTI_GPU="--multi_gpu"
-fi
+LAUNCH_OPTS="--mixed_precision bf16"
+if [[ "${NUM_GPU}" -gt 1 ]]; then
+  LAUNCH_OPTS+=" --num_processes ${NUM_GPU}"
+fi
@@
-CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \
+CMD="accelerate launch $LAUNCH_OPTS main.py \

Also applies to: 130-131


Use explicit --num_processes instead of --multi_gpu for multi-GPU runs

Accelerate launch supports --num_processes=<N> alone to spawn N GPUs (and implicitly use MULTI_GPU) without requiring --multi_gpu (huggingface.co, modeldatabase.com)

-if [[ "$NUM_GPU" == 1 ]]; then
-  MULTI_GPU=""
-else
-  MULTI_GPU="--multi_gpu"
-fi
+LAUNCH_OPTS="--mixed_precision bf16"
+if [[ "${NUM_GPU}" -gt 1 ]]; then
+  LAUNCH_OPTS+=" --num_processes ${NUM_GPU}"
+fi
@@
-CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \
+CMD="accelerate launch $LAUNCH_OPTS main.py \

Also update the same pattern at lines 130–131.

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In examples/speculative_decoding/launch_train.sh around lines 122–127 (and also
update the same pattern at lines 130–131), replace the current multi-GPU flag
logic that sets MULTI_GPU="--multi_gpu" with an explicit process-count flag:
when NUM_GPU==1 keep MULTI_GPU empty, otherwise set
MULTI_GPU="--num_processes=$NUM_GPU"; update any subsequent invocations that
previously relied on --multi_gpu to use this MULTI_GPU variable so Accelerate is
launched with --num_processes=<N> instead of --multi_gpu.

# Disable tokenizers parallelism to avoid warning
export TOKENIZERS_PARALLELISM=False
CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \
--mode $MODE \
--model_name_or_path $MODEL \
--training_seq_len $TRAINING_SEQ_LEN \
--dataloader_drop_last True \
--bf16 True \
--output_dir $OUTPUT_DIR \
--num_train_epochs $NUM_EPOCHS \
--per_device_train_batch_size $TRAIN_BS \
--per_device_eval_batch_size $TRAIN_BS \
--gradient_accumulation_steps 1 \
--do_eval False \
--eval_accumulation_steps 1 \
--save_strategy steps \
--save_steps $SAVE_STEPS \
--learning_rate $LR \
--weight_decay 0.0 \
--warmup_steps 100 \
--lr_scheduler_type linear \
--logging_steps 100 \
--tf32 True \
--data_path $DATA \
$SPECULATIVE_ARGS
"

start_time=$(date +%s)
sh -c "$CMD"
echo "Total time taken: $(( $(date +%s) - $start_time )) seconds"
11 changes: 11 additions & 0 deletions examples/speculative_decoding/main.py
@@ -47,6 +47,13 @@
import modelopt.torch.speculative as mtsp
from modelopt.torch.utils import print_rank_0

try:
import wandb

wandb.init()
except ImportError:
wandb = None

torch.manual_seed(0)
mto.enable_huggingface_checkpointing()

@@ -170,6 +177,8 @@ def train():
{
"hidden_size": model.config.hidden_size,
"vocab_size": model.config.vocab_size,
# we also overwrite max_pos_embedding for deployment compatibility
"max_position_embeddings": model.config.max_position_embeddings,
"draft_vocab_size": custom_config["draft_vocab_size"]
if eagle_args.eagle_config and "draft_vocab_size" in custom_config
else model.config.vocab_size,
@@ -213,6 +222,8 @@ def on_step_end(self, args, state, control, **kwargs):
device=kwargs["model"].device,
)
print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}")
if wandb:
wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
return control

trainer = Trainer(
2 changes: 1 addition & 1 deletion examples/speculative_decoding/server_generate.py
@@ -46,7 +46,7 @@
parser.add_argument(
"--max_tokens", type=int, default=2048, help="Maximum number of tokens to generate"
)
parser.add_argument("--chat", action="store_true", help="Use chat mode")
parser.add_argument("--chat", default=True, type=bool, help="Use chat mode")
parser.add_argument("--model", type=str, default="model", help="Model name")
parser.add_argument("--url", type=str, default="http://localhost:8000/v1", help="URL of the API")
parser.add_argument("--api_key", type=str, default="token-abc123", help="API key (if any)")