Skip to content

Commit 52dbdec

Browse files
committed
add file; only d2t when training
Signed-off-by: h-guo18 <[email protected]>
1 parent 47a0a50 commit 52dbdec

File tree

2 files changed

+158
-1
lines changed

2 files changed

+158
-1
lines changed
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
#!/bin/bash
2+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
set -eo pipefail
18+
19+
while [ $# -gt 0 ]; do
20+
case "$1" in
21+
--training_seq_len*)
22+
if [[ "$1" != *=* ]]; then shift; fi
23+
TRAINING_SEQ_LEN="${1#*=}"
24+
;;
25+
--model*)
26+
if [[ "$1" != *=* ]]; then shift; fi
27+
MODEL="${1#*=}"
28+
;;
29+
--data*)
30+
if [[ "$1" != *=* ]]; then shift; fi
31+
DATA="${1#*=}"
32+
;;
33+
--mode*)
34+
if [[ "$1" != *=* ]]; then shift; fi
35+
MODE="${1#*=}"
36+
;;
37+
--output_dir*)
38+
if [[ "$1" != *=* ]]; then shift; fi
39+
OUTPUT_DIR="${1#*=}"
40+
;;
41+
--num_epochs*)
42+
if [[ "$1" != *=* ]]; then shift; fi
43+
NUM_EPOCHS="${1#*=}"
44+
;;
45+
--save_steps*)
46+
if [[ "$1" != *=* ]]; then shift; fi
47+
SAVE_STEPS="${1#*=}"
48+
;;
49+
--lr*)
50+
if [[ "$1" != *=* ]]; then shift; fi
51+
LR="${1#*=}"
52+
;;
53+
--train_bs*)
54+
if [[ "$1" != *=* ]]; then shift; fi
55+
TRAIN_BS="${1#*=}"
56+
;;
57+
--medusa_num_heads*)
58+
if [[ "$1" != *=* ]]; then shift; fi
59+
MEDUSA_NUM_HEADS="${1#*=}"
60+
;;
61+
--medusa_num_layers*)
62+
if [[ "$1" != *=* ]]; then shift; fi
63+
MEDUSA_NUM_LAYERS="${1#*=}"
64+
;;
65+
--eagle_config*)
66+
if [[ "$1" != *=* ]]; then shift; fi
67+
EAGLE_CONFIG="${1#*=}"
68+
;;
69+
--fsdp_transformer_layer_cls_to_wrap*)
70+
if [[ "$1" != *=* ]]; then shift; fi
71+
FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP="${1#*=}"
72+
;;
73+
--num_gpu*)
74+
if [[ "$1" != *=* ]]; then shift; fi
75+
NUM_GPU="${1#*=}"
76+
;;
77+
*)
78+
>&2 printf "Error: Invalid argument ${1#*=}\n"
79+
exit 1
80+
;;
81+
esac
82+
shift
83+
done
84+
85+
set -x
86+
87+
# Get the default value for save_steps based on the available number of GPUs
88+
GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
89+
# Calculate save_steps
90+
DEFAULT_SAVE_STEPS=$((8192 / GPU_COUNT))
91+
92+
MODEL=${MODEL:-"TinyLlama/TinyLlama-1.1B-Chat-v1.0"}
93+
MODE=${MODE:-"eagle3"}
94+
# Set default OUTPUT_DIR to ckpts/{modelname}, where {modelname} is the last part of the model path
95+
MODEL_BASENAME=$(basename "$MODEL")
96+
OUTPUT_DIR=${OUTPUT_DIR:-"ckpts/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M)"}
97+
NUM_EPOCHS=${NUM_EPOCHS:-1}
98+
SAVE_STEPS=${SAVE_STEPS:-$DEFAULT_SAVE_STEPS}
99+
LR=${LR:-"1e-4"}
100+
TRAIN_BS=${TRAIN_BS:-4}
101+
MEDUSA_NUM_HEADS=${MEDUSA_NUM_HEADS:-1}
102+
MEDUSA_NUM_LAYERS=${MEDUSA_NUM_LAYERS:-1}
103+
REDRAFTER_TOKENS=${REDRAFTER_TOKENS:-1}
104+
REDRAFTER_NUM_LAYERS=${REDRAFTER_NUM_LAYERS:-1}
105+
FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP:-"LlamaDecoderLayer"}
106+
NUM_GPU=${NUM_GPU:-1}
107+
TRAINING_SEQ_LEN=${TRAINING_SEQ_LEN:-512}
108+
109+
if [[ "$MODE" == "medusa" ]]; then
110+
SPECULATIVE_ARGS="--medusa_num_heads $MEDUSA_NUM_HEADS --medusa_num_layers $MEDUSA_NUM_LAYERS"
111+
elif [[ "$MODE" == "eagle1" || "$MODE" == "eagle3" ]]; then
112+
if [[ -n "$EAGLE_CONFIG" ]]; then
113+
SPECULATIVE_ARGS="--eagle_config $EAGLE_CONFIG"
114+
else
115+
SPECULATIVE_ARGS=""
116+
fi
117+
else
118+
echo "Only medusa, eagle1, eagle3 supported for now!"
119+
exit 1
120+
fi
121+
122+
if [[ "$NUM_GPU" == 1 ]]; then
123+
MULTI_GPU=""
124+
else
125+
MULTI_GPU="--multi_gpu"
126+
fi
127+
128+
# Disable tokenizers parallelism to avoid warning
129+
export TOKENIZERS_PARALLELISM=False
130+
CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \
131+
--mode $MODE \
132+
--model_name_or_path $MODEL \
133+
--training_seq_len $TRAINING_SEQ_LEN \
134+
--dataloader_drop_last True \
135+
--bf16 True \
136+
--output_dir $OUTPUT_DIR \
137+
--num_train_epochs $NUM_EPOCHS \
138+
--per_device_train_batch_size $TRAIN_BS \
139+
--per_device_eval_batch_size $TRAIN_BS \
140+
--gradient_accumulation_steps 1 \
141+
--do_eval False \
142+
--eval_accumulation_steps 1 \
143+
--save_strategy steps \
144+
--save_steps $SAVE_STEPS \
145+
--learning_rate $LR \
146+
--weight_decay 0.0 \
147+
--warmup_steps 100 \
148+
--lr_scheduler_type linear \
149+
--logging_steps 100 \
150+
--tf32 True \
151+
--data_path $DATA \
152+
$SPECULATIVE_ARGS
153+
"
154+
155+
start_time=$(date +%s)
156+
sh -c "$CMD"
157+
echo "Total time taken: $(( $(date +%s) - $start_time )) seconds"

modelopt/torch/speculative/plugins/transformers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -713,7 +713,7 @@ def _base_model_forward(
713713
base_model_loss = loss_fct(loss_logits, labels)
714714

715715
# Map the base model logits to the draft vocab
716-
if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
716+
if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size and self.training:
717717
base_model_logits = self._map_logits_to_draft_vocab(base_model_logits)
718718

719719
return base_model_hidden_states, base_model_logits, base_model_loss, past_key_values

0 commit comments

Comments
 (0)