|
| 1 | +#!/bin/bash |
| 2 | +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 3 | +# SPDX-License-Identifier: Apache-2.0 |
| 4 | +# |
| 5 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 6 | +# you may not use this file except in compliance with the License. |
| 7 | +# You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, software |
| 12 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | +# See the License for the specific language governing permissions and |
| 15 | +# limitations under the License. |
| 16 | + |
| 17 | +set -eo pipefail |
| 18 | + |
| 19 | +while [ $# -gt 0 ]; do |
| 20 | + case "$1" in |
| 21 | + --training_seq_len*) |
| 22 | + if [[ "$1" != *=* ]]; then shift; fi |
| 23 | + TRAINING_SEQ_LEN="${1#*=}" |
| 24 | + ;; |
| 25 | + --model*) |
| 26 | + if [[ "$1" != *=* ]]; then shift; fi |
| 27 | + MODEL="${1#*=}" |
| 28 | + ;; |
| 29 | + --data*) |
| 30 | + if [[ "$1" != *=* ]]; then shift; fi |
| 31 | + DATA="${1#*=}" |
| 32 | + ;; |
| 33 | + --mode*) |
| 34 | + if [[ "$1" != *=* ]]; then shift; fi |
| 35 | + MODE="${1#*=}" |
| 36 | + ;; |
| 37 | + --output_dir*) |
| 38 | + if [[ "$1" != *=* ]]; then shift; fi |
| 39 | + OUTPUT_DIR="${1#*=}" |
| 40 | + ;; |
| 41 | + --num_epochs*) |
| 42 | + if [[ "$1" != *=* ]]; then shift; fi |
| 43 | + NUM_EPOCHS="${1#*=}" |
| 44 | + ;; |
| 45 | + --save_steps*) |
| 46 | + if [[ "$1" != *=* ]]; then shift; fi |
| 47 | + SAVE_STEPS="${1#*=}" |
| 48 | + ;; |
| 49 | + --lr*) |
| 50 | + if [[ "$1" != *=* ]]; then shift; fi |
| 51 | + LR="${1#*=}" |
| 52 | + ;; |
| 53 | + --train_bs*) |
| 54 | + if [[ "$1" != *=* ]]; then shift; fi |
| 55 | + TRAIN_BS="${1#*=}" |
| 56 | + ;; |
| 57 | + --medusa_num_heads*) |
| 58 | + if [[ "$1" != *=* ]]; then shift; fi |
| 59 | + MEDUSA_NUM_HEADS="${1#*=}" |
| 60 | + ;; |
| 61 | + --medusa_num_layers*) |
| 62 | + if [[ "$1" != *=* ]]; then shift; fi |
| 63 | + MEDUSA_NUM_LAYERS="${1#*=}" |
| 64 | + ;; |
| 65 | + --eagle_config*) |
| 66 | + if [[ "$1" != *=* ]]; then shift; fi |
| 67 | + EAGLE_CONFIG="${1#*=}" |
| 68 | + ;; |
| 69 | + --fsdp_transformer_layer_cls_to_wrap*) |
| 70 | + if [[ "$1" != *=* ]]; then shift; fi |
| 71 | + FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP="${1#*=}" |
| 72 | + ;; |
| 73 | + --num_gpu*) |
| 74 | + if [[ "$1" != *=* ]]; then shift; fi |
| 75 | + NUM_GPU="${1#*=}" |
| 76 | + ;; |
| 77 | + *) |
| 78 | + >&2 printf "Error: Invalid argument ${1#*=}\n" |
| 79 | + exit 1 |
| 80 | + ;; |
| 81 | + esac |
| 82 | + shift |
| 83 | +done |
| 84 | + |
| 85 | +set -x |
| 86 | + |
| 87 | +# Get the default value for save_steps based on the available number of GPUs |
| 88 | +GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())") |
| 89 | +# Calculate save_steps |
| 90 | +DEFAULT_SAVE_STEPS=$((8192 / GPU_COUNT)) |
| 91 | + |
| 92 | +MODEL=${MODEL:-"TinyLlama/TinyLlama-1.1B-Chat-v1.0"} |
| 93 | +MODE=${MODE:-"eagle3"} |
| 94 | +# Set default OUTPUT_DIR to ckpts/{modelname}, where {modelname} is the last part of the model path |
| 95 | +MODEL_BASENAME=$(basename "$MODEL") |
| 96 | +OUTPUT_DIR=${OUTPUT_DIR:-"ckpts/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M)"} |
| 97 | +NUM_EPOCHS=${NUM_EPOCHS:-1} |
| 98 | +SAVE_STEPS=${SAVE_STEPS:-$DEFAULT_SAVE_STEPS} |
| 99 | +LR=${LR:-"1e-4"} |
| 100 | +TRAIN_BS=${TRAIN_BS:-4} |
| 101 | +MEDUSA_NUM_HEADS=${MEDUSA_NUM_HEADS:-1} |
| 102 | +MEDUSA_NUM_LAYERS=${MEDUSA_NUM_LAYERS:-1} |
| 103 | +REDRAFTER_TOKENS=${REDRAFTER_TOKENS:-1} |
| 104 | +REDRAFTER_NUM_LAYERS=${REDRAFTER_NUM_LAYERS:-1} |
| 105 | +FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP:-"LlamaDecoderLayer"} |
| 106 | +NUM_GPU=${NUM_GPU:-1} |
| 107 | +TRAINING_SEQ_LEN=${TRAINING_SEQ_LEN:-512} |
| 108 | + |
| 109 | +if [[ "$MODE" == "medusa" ]]; then |
| 110 | + SPECULATIVE_ARGS="--medusa_num_heads $MEDUSA_NUM_HEADS --medusa_num_layers $MEDUSA_NUM_LAYERS" |
| 111 | +elif [[ "$MODE" == "eagle1" || "$MODE" == "eagle3" ]]; then |
| 112 | + if [[ -n "$EAGLE_CONFIG" ]]; then |
| 113 | + SPECULATIVE_ARGS="--eagle_config $EAGLE_CONFIG" |
| 114 | + else |
| 115 | + SPECULATIVE_ARGS="" |
| 116 | + fi |
| 117 | +else |
| 118 | + echo "Only medusa, eagle1, eagle3 supported for now!" |
| 119 | + exit 1 |
| 120 | +fi |
| 121 | + |
| 122 | +if [[ "$NUM_GPU" == 1 ]]; then |
| 123 | + MULTI_GPU="" |
| 124 | +else |
| 125 | + MULTI_GPU="--multi_gpu" |
| 126 | +fi |
| 127 | + |
| 128 | +# Disable tokenizers parallelism to avoid warning |
| 129 | +export TOKENIZERS_PARALLELISM=False |
| 130 | +CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \ |
| 131 | + --mode $MODE \ |
| 132 | + --model_name_or_path $MODEL \ |
| 133 | + --training_seq_len $TRAINING_SEQ_LEN \ |
| 134 | + --dataloader_drop_last True \ |
| 135 | + --bf16 True \ |
| 136 | + --output_dir $OUTPUT_DIR \ |
| 137 | + --num_train_epochs $NUM_EPOCHS \ |
| 138 | + --per_device_train_batch_size $TRAIN_BS \ |
| 139 | + --per_device_eval_batch_size $TRAIN_BS \ |
| 140 | + --gradient_accumulation_steps 1 \ |
| 141 | + --do_eval False \ |
| 142 | + --eval_accumulation_steps 1 \ |
| 143 | + --save_strategy steps \ |
| 144 | + --save_steps $SAVE_STEPS \ |
| 145 | + --learning_rate $LR \ |
| 146 | + --weight_decay 0.0 \ |
| 147 | + --warmup_steps 100 \ |
| 148 | + --lr_scheduler_type linear \ |
| 149 | + --logging_steps 100 \ |
| 150 | + --tf32 True \ |
| 151 | + --data_path $DATA \ |
| 152 | + $SPECULATIVE_ARGS |
| 153 | +" |
| 154 | + |
| 155 | +start_time=$(date +%s) |
| 156 | +sh -c "$CMD" |
| 157 | +echo "Total time taken: $(( $(date +%s) - $start_time )) seconds" |
0 commit comments