@@ -18,100 +18,37 @@ set -eo pipefail
18
18
19
19
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
20
20
21
# Helper function to parse a single argument value.
#
# Accepts the remaining command-line words (typically "$@") and prints the
# value belonging to the first one, which is the option itself:
#   --flag=value   -> prints "value"  (everything after the FIRST '=')
#   --flag value   -> prints "value"  (the following word, taken verbatim)
#
# Arguments: $1 - the option word; $2 - the next word (used only when $1
#            contains no '=')
# Outputs:   the value on stdout, followed by a newline; an empty line when
#            the space-separated form is used and no following word exists
#
# The caller remains responsible for shifting the extra word away when the
# space-separated form was used.
parse_value() {
  if [[ "$1" == *=* ]]; then
    # Inline form: strip the shortest prefix ending in '=' so that values
    # which themselves contain '=' are kept intact.
    printf '%s\n' "${1#*=}"
  else
    # Separate-word form: the next word is the value as-is. It must NOT be
    # stripped at '=', otherwise "--flag a=b" would be truncated to "b".
    # printf is used instead of echo so values such as "-n" or "-e" are
    # printed literally rather than being interpreted as echo options.
    printf '%s\n' "${2-}"
  fi
}
26
+
21
27
while [ $# -gt 0 ]; do
22
28
case " $1 " in
23
- --model* )
24
- if [[ " $1 " != * = * ]]; then shift ; fi
25
- MODEL=" ${1#* =} "
26
- ;;
27
- --output_dir* )
28
- if [[ " $1 " != * = * ]]; then shift ; fi
29
- OUTPUT_DIR=" ${1#* =} "
30
- ;;
31
- --dataset* )
32
- if [[ " $1 " != * = * ]]; then shift ; fi
33
- DATASET=" ${1#* =} "
34
- ;;
35
- --train_size* )
36
- if [[ " $1 " != * = * ]]; then shift ; fi
37
- TRAIN_SIZE=" ${1#* =} "
38
- ;;
39
- --eval_size* )
40
- if [[ " $1 " != * = * ]]; then shift ; fi
41
- EVAL_SIZE=" ${1#* =} "
42
- ;;
43
- --num_epochs* )
44
- if [[ " $1 " != * = * ]]; then shift ; fi
45
- NUM_EPOCHS=" ${1#* =} "
46
- ;;
47
- --max_steps* )
48
- if [[ " $1 " != * = * ]]; then shift ; fi
49
- MAX_STEPS=" ${1#* =} "
50
- ;;
51
- --save_steps* )
52
- if [[ " $1 " != * = * ]]; then shift ; fi
53
- SAVE_STEPS=" ${1#* =} "
54
- ;;
55
- --accum_steps* )
56
- if [[ " $1 " != * = * ]]; then shift ; fi
57
- ACCUM_STEPS=" ${1#* =} "
58
- ;;
59
- --lr* )
60
- if [[ " $1 " != * = * ]]; then shift ; fi
61
- LR=" ${1#* =} "
62
- ;;
63
- --quant_cfg* )
64
- if [[ " $1 " != * = * ]]; then shift ; fi
65
- QUANT_CFG=" ${1#* =} "
66
- ;;
67
- --compress* )
68
- if [[ " $1 " != * = * ]]; then shift ; fi
69
- COMPRESS=" ${1#* =} "
70
- ;;
71
- --calib_size* )
72
- if [[ " $1 " != * = * ]]; then shift ; fi
73
- CALIB_SIZE=" ${1#* =} "
74
- ;;
75
- --train_bs* )
76
- if [[ " $1 " != * = * ]]; then shift ; fi
77
- TRAIN_BS=" ${1#* =} "
78
- ;;
79
- --eval_bs* )
80
- if [[ " $1 " != * = * ]]; then shift ; fi
81
- EVAL_BS=" ${1#* =} "
82
- ;;
83
- --do_train* )
84
- if [[ " $1 " != * = * ]]; then shift ; fi
85
- DO_TRAIN=" ${1#* =} "
86
- ;;
87
- --lora* )
88
- if [[ " $1 " != * = * ]]; then shift ; fi
89
- LORA=" ${1#* =} "
90
- ;;
91
- --teacher_model* )
92
- if [[ " $1 " != * = * ]]; then shift ; fi
93
- TEACHER_MODEL=" ${1#* =} "
94
- ;;
95
- --distill* )
96
- if [[ " $1 " != * = * ]]; then shift ; fi
97
- DISTILL=" ${1#* =} "
98
- ;;
99
- --fsdp_transformer_layer_cls_to_wrap* )
100
- if [[ " $1 " != * = * ]]; then shift ; fi
101
- FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=" ${1#* =} "
102
- ;;
103
- --max_seq_length* )
104
- if [[ " $1 " != * = * ]]; then shift ; fi
105
- MAX_SEQ_LENGTH=" ${1#* =} "
106
- ;;
107
- --backend* )
108
- if [[ " $1 " != * = * ]]; then shift ; fi
109
- BACKEND=" ${1#* =} "
110
- ;;
111
- --use_fsdp2* )
112
- if [[ " $1 " != * = * ]]; then shift ; fi
113
- USE_FSDP2=" ${1#* =} "
114
- ;;
29
+ --model* ) MODEL=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
30
+ --output_dir* ) OUTPUT_DIR=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
31
+ --dataset* ) DATASET=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
32
+ --train_size* ) TRAIN_SIZE=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
33
+ --eval_size* ) EVAL_SIZE=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
34
+ --num_epochs* ) NUM_EPOCHS=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
35
+ --max_steps* ) MAX_STEPS=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
36
+ --save_steps* ) SAVE_STEPS=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
37
+ --accum_steps* ) ACCUM_STEPS=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
38
+ --lr* ) LR=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
39
+ --quant_cfg* ) QUANT_CFG=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
40
+ --compress* ) COMPRESS=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
41
+ --calib_size* ) CALIB_SIZE=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
42
+ --train_bs* ) TRAIN_BS=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
43
+ --eval_bs* ) EVAL_BS=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
44
+ --do_train* ) DO_TRAIN=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
45
+ --lora* ) LORA=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
46
+ --teacher_model* ) TEACHER_MODEL=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
47
+ --distill* ) DISTILL=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
48
+ --fsdp_transformer_layer_cls_to_wrap* ) FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
49
+ --max_seq_length* ) MAX_SEQ_LENGTH=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
50
+ --backend* ) BACKEND=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
51
+ --use_fsdp2* ) USE_FSDP2=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
115
52
* )
116
53
>&2 printf " Error: Invalid argument ${1#* =} \n"
117
54
exit 1
@@ -175,23 +112,19 @@ case "${BACKEND,,}" in
175
112
" fsdp1" |" fsdp" )
176
113
CONFIG_FILE=" fsdp1.yaml"
177
114
FSDP_ARGS=" --fsdp_transformer_layer_cls_to_wrap $FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP "
178
- GRADIENT_CHECKPOINTING_ARGS=" --gradient_checkpointing True"
179
115
;;
180
116
" fsdp2" )
181
117
echo " Using FSDP2 instead of FSDP1. FSDP2 is not mature yet! Please use it with latest torch and transformers."
182
118
CONFIG_FILE=" fsdp2.yaml"
183
119
FSDP_ARGS=" --fsdp_transformer_layer_cls_to_wrap $FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP "
184
- GRADIENT_CHECKPOINTING_ARGS=" "
185
120
;;
186
121
" ddp" )
187
122
CONFIG_FILE=" ddp.yaml"
188
123
FSDP_ARGS=" "
189
- GRADIENT_CHECKPOINTING_ARGS=" --gradient_checkpointing True"
190
124
;;
191
125
" deepspeed" )
192
126
CONFIG_FILE=" deepspeed.yaml"
193
127
FSDP_ARGS=" "
194
- GRADIENT_CHECKPOINTING_ARGS=" --gradient_checkpointing True"
195
128
;;
196
129
* )
197
130
echo " Error: Invalid backend '$BACKEND '. Supported backends: fsdp1, fsdp2, ddp, deepspeed"
@@ -239,7 +172,8 @@ CMD="accelerate launch --config-file accelerate_config/$CONFIG_FILE $FSDP_ARGS \
239
172
--report_to tensorboard \
240
173
--lora $LORA \
241
174
--compress $COMPRESS \
242
- $QUANT_ARGS $OPTIONAL_ARGS $GRADIENT_CHECKPOINTING_ARGS $DISTILLATION_ARGS
175
+ --gradient_checkpointing True \
176
+ $QUANT_ARGS $OPTIONAL_ARGS $DISTILLATION_ARGS
243
177
"
244
178
245
179
start_time=$( date +%s)
0 commit comments