Skip to content

Commit a11ace9

Browse files
author
Your Name
committed
minor
1 parent 78f3aaf commit a11ace9

File tree

1 file changed

+31
-97
lines changed

1 file changed

+31
-97
lines changed

examples/llm_qat/launch.sh

Lines changed: 31 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -18,100 +18,37 @@ set -eo pipefail
1818

1919
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
2020

21+
# Helper function to parse a single argument value
22+
parse_value() {
23+
if [[ "$1" != *=* ]]; then shift; fi
24+
echo "${1#*=}"
25+
}
26+
2127
while [ $# -gt 0 ]; do
2228
case "$1" in
23-
--model*)
24-
if [[ "$1" != *=* ]]; then shift; fi
25-
MODEL="${1#*=}"
26-
;;
27-
--output_dir*)
28-
if [[ "$1" != *=* ]]; then shift; fi
29-
OUTPUT_DIR="${1#*=}"
30-
;;
31-
--dataset*)
32-
if [[ "$1" != *=* ]]; then shift; fi
33-
DATASET="${1#*=}"
34-
;;
35-
--train_size*)
36-
if [[ "$1" != *=* ]]; then shift; fi
37-
TRAIN_SIZE="${1#*=}"
38-
;;
39-
--eval_size*)
40-
if [[ "$1" != *=* ]]; then shift; fi
41-
EVAL_SIZE="${1#*=}"
42-
;;
43-
--num_epochs*)
44-
if [[ "$1" != *=* ]]; then shift; fi
45-
NUM_EPOCHS="${1#*=}"
46-
;;
47-
--max_steps*)
48-
if [[ "$1" != *=* ]]; then shift; fi
49-
MAX_STEPS="${1#*=}"
50-
;;
51-
--save_steps*)
52-
if [[ "$1" != *=* ]]; then shift; fi
53-
SAVE_STEPS="${1#*=}"
54-
;;
55-
--accum_steps*)
56-
if [[ "$1" != *=* ]]; then shift; fi
57-
ACCUM_STEPS="${1#*=}"
58-
;;
59-
--lr*)
60-
if [[ "$1" != *=* ]]; then shift; fi
61-
LR="${1#*=}"
62-
;;
63-
--quant_cfg*)
64-
if [[ "$1" != *=* ]]; then shift; fi
65-
QUANT_CFG="${1#*=}"
66-
;;
67-
--compress*)
68-
if [[ "$1" != *=* ]]; then shift; fi
69-
COMPRESS="${1#*=}"
70-
;;
71-
--calib_size*)
72-
if [[ "$1" != *=* ]]; then shift; fi
73-
CALIB_SIZE="${1#*=}"
74-
;;
75-
--train_bs*)
76-
if [[ "$1" != *=* ]]; then shift; fi
77-
TRAIN_BS="${1#*=}"
78-
;;
79-
--eval_bs*)
80-
if [[ "$1" != *=* ]]; then shift; fi
81-
EVAL_BS="${1#*=}"
82-
;;
83-
--do_train*)
84-
if [[ "$1" != *=* ]]; then shift; fi
85-
DO_TRAIN="${1#*=}"
86-
;;
87-
--lora*)
88-
if [[ "$1" != *=* ]]; then shift; fi
89-
LORA="${1#*=}"
90-
;;
91-
--teacher_model*)
92-
if [[ "$1" != *=* ]]; then shift; fi
93-
TEACHER_MODEL="${1#*=}"
94-
;;
95-
--distill*)
96-
if [[ "$1" != *=* ]]; then shift; fi
97-
DISTILL="${1#*=}"
98-
;;
99-
--fsdp_transformer_layer_cls_to_wrap*)
100-
if [[ "$1" != *=* ]]; then shift; fi
101-
FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP="${1#*=}"
102-
;;
103-
--max_seq_length*)
104-
if [[ "$1" != *=* ]]; then shift; fi
105-
MAX_SEQ_LENGTH="${1#*=}"
106-
;;
107-
--backend*)
108-
if [[ "$1" != *=* ]]; then shift; fi
109-
BACKEND="${1#*=}"
110-
;;
111-
--use_fsdp2*)
112-
if [[ "$1" != *=* ]]; then shift; fi
113-
USE_FSDP2="${1#*=}"
114-
;;
29+
--model*) MODEL=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
30+
--output_dir*) OUTPUT_DIR=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
31+
--dataset*) DATASET=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
32+
--train_size*) TRAIN_SIZE=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
33+
--eval_size*) EVAL_SIZE=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
34+
--num_epochs*) NUM_EPOCHS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
35+
--max_steps*) MAX_STEPS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
36+
--save_steps*) SAVE_STEPS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
37+
--accum_steps*) ACCUM_STEPS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
38+
--lr*) LR=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
39+
--quant_cfg*) QUANT_CFG=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
40+
--compress*) COMPRESS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
41+
--calib_size*) CALIB_SIZE=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
42+
--train_bs*) TRAIN_BS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
43+
--eval_bs*) EVAL_BS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
44+
--do_train*) DO_TRAIN=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
45+
--lora*) LORA=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
46+
--teacher_model*) TEACHER_MODEL=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
47+
--distill*) DISTILL=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
48+
--fsdp_transformer_layer_cls_to_wrap*) FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
49+
--max_seq_length*) MAX_SEQ_LENGTH=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
50+
--backend*) BACKEND=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
51+
--use_fsdp2*) USE_FSDP2=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
11552
*)
11653
>&2 printf "Error: Invalid argument ${1#*=}\n"
11754
exit 1
@@ -175,23 +112,19 @@ case "${BACKEND,,}" in
175112
"fsdp1"|"fsdp")
176113
CONFIG_FILE="fsdp1.yaml"
177114
FSDP_ARGS="--fsdp_transformer_layer_cls_to_wrap $FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP"
178-
GRADIENT_CHECKPOINTING_ARGS="--gradient_checkpointing True"
179115
;;
180116
"fsdp2")
181117
echo "Using FSDP2 instead of FSDP1. FSDP2 is not mature yet! Please use it with latest torch and transformers."
182118
CONFIG_FILE="fsdp2.yaml"
183119
FSDP_ARGS="--fsdp_transformer_layer_cls_to_wrap $FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP"
184-
GRADIENT_CHECKPOINTING_ARGS=""
185120
;;
186121
"ddp")
187122
CONFIG_FILE="ddp.yaml"
188123
FSDP_ARGS=""
189-
GRADIENT_CHECKPOINTING_ARGS="--gradient_checkpointing True"
190124
;;
191125
"deepspeed")
192126
CONFIG_FILE="deepspeed.yaml"
193127
FSDP_ARGS=""
194-
GRADIENT_CHECKPOINTING_ARGS="--gradient_checkpointing True"
195128
;;
196129
*)
197130
echo "Error: Invalid backend '$BACKEND'. Supported backends: fsdp1, fsdp2, ddp, deepspeed"
@@ -239,7 +172,8 @@ CMD="accelerate launch --config-file accelerate_config/$CONFIG_FILE $FSDP_ARGS \
239172
--report_to tensorboard \
240173
--lora $LORA \
241174
--compress $COMPRESS \
242-
$QUANT_ARGS $OPTIONAL_ARGS $GRADIENT_CHECKPOINTING_ARGS $DISTILLATION_ARGS
175+
--gradient_checkpointing True \
176+
$QUANT_ARGS $OPTIONAL_ARGS $DISTILLATION_ARGS
243177
"
244178

245179
start_time=$(date +%s)

0 commit comments

Comments
 (0)