@@ -100,14 +100,18 @@ while [ $# -gt 0 ]; do
100
100
if [[ " $1 " != * = * ]]; then shift ; fi
101
101
FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=" ${1#* =} "
102
102
;;
103
- --use_fsdp2* )
104
- if [[ " $1 " != * = * ]]; then shift ; fi
105
- USE_FSDP2=" ${1#* =} "
106
- ;;
107
103
--max_seq_length* )
108
104
if [[ " $1 " != * = * ]]; then shift ; fi
109
105
MAX_SEQ_LENGTH=" ${1#* =} "
110
106
;;
107
+ --backend* )
108
+ if [[ " $1 " != * = * ]]; then shift ; fi
109
+ BACKEND=" ${1#* =} "
110
+ ;;
111
+ --use_fsdp2* )
112
+ if [[ " $1 " != * = * ]]; then shift ; fi
113
+ USE_FSDP2=" ${1#* =} "
114
+ ;;
111
115
* )
112
116
>&2 printf " Error: Invalid argument ${1#* =} \n"
113
117
exit 1
@@ -142,6 +146,7 @@ COMPRESS=${COMPRESS:-"False"}
142
146
DISTILL=${DISTILL:- " False" }
143
147
TEACHER_MODEL=${TEACHER_MODEL:- $MODEL }
144
148
FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP:- " LlamaDecoderLayer" }
149
+ BACKEND=${BACKEND:- " fsdp1" }
145
150
146
151
if [ -z $QUANT_CFG ]; then
147
152
QUANT_ARGS=" "
@@ -154,31 +159,56 @@ if [ ! -z $MAX_STEPS ]; then
154
159
OPTIONAL_ARGS=" $OPTIONAL_ARGS --max_steps $MAX_STEPS "
155
160
fi
156
161
157
- CONFIG_FILE=" fsdp1.yaml"
158
- FSDP_ARGS=" --fsdp_transformer_layer_cls_to_wrap $FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP "
159
- GRADIENT_CHECKPOINTING_ARGS=" --gradient_checkpointing True"
160
-
162
+ # Set backend based on --backend parameter, with backward compatibility for --use_fsdp2
161
163
if [[ " ${USE_FSDP2,,} " == " true" ]]; then
162
- echo " Using FSDP2 instead of FSDP1. FSDP2 is not mature yet! Please use it with latest torch and transformers."
163
- CONFIG_FILE=" fsdp2.yaml"
164
- GRADIENT_CHECKPOINTING_ARGS=" "
164
+ echo " Warning: --use_fsdp2 is deprecated. Use --backend=fsdp2 instead."
165
+ BACKEND=" fsdp2"
166
+ fi
167
+
168
+ # if compress is true, set backend to ddp
169
+ if [[ " ${COMPRESS,,} " == " true" ]]; then
170
+ BACKEND=" ddp"
165
171
fi
166
172
173
+ # Configure backend-specific settings
174
+ case " ${BACKEND,,} " in
175
+ " fsdp1" |" fsdp" )
176
+ CONFIG_FILE=" fsdp1.yaml"
177
+ FSDP_ARGS=" --fsdp_transformer_layer_cls_to_wrap $FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP "
178
+ GRADIENT_CHECKPOINTING_ARGS=" --gradient_checkpointing True"
179
+ ;;
180
+ " fsdp2" )
181
+ echo " Using FSDP2 instead of FSDP1. FSDP2 is not mature yet! Please use it with latest torch and transformers."
182
+ CONFIG_FILE=" fsdp2.yaml"
183
+ FSDP_ARGS=" --fsdp_transformer_layer_cls_to_wrap $FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP "
184
+ GRADIENT_CHECKPOINTING_ARGS=" "
185
+ ;;
186
+ " ddp" )
187
+ CONFIG_FILE=" ddp.yaml"
188
+ FSDP_ARGS=" "
189
+ GRADIENT_CHECKPOINTING_ARGS=" --gradient_checkpointing True"
190
+ ;;
191
+ " deepspeed" )
192
+ CONFIG_FILE=" deepspeed.yaml"
193
+ FSDP_ARGS=" "
194
+ GRADIENT_CHECKPOINTING_ARGS=" --gradient_checkpointing True"
195
+ ;;
196
+ * )
197
+ echo " Error: Invalid backend '$BACKEND '. Supported backends: fsdp1, fsdp2, ddp, deepspeed"
198
+ exit 1
199
+ ;;
200
+ esac
201
+
202
+ # TODO: Remove this after simple distillation is supported
167
203
DISTILLATION_ARGS=" "
168
204
if [[ " ${DISTILL,,} " == " true" ]]; then
169
205
DISTILLATION_ARGS=" --distill $DISTILL --teacher_model $TEACHER_MODEL "
170
- # Distillation does not work with memory efficient loading
171
- FSDP_ARGS=" $FSDP_ARGS --fsdp_cpu_ram_efficient_loading False"
206
+ # Distillation does not work with memory efficient loading for FSDP
207
+ if [[ " ${BACKEND,,} " == " fsdp1" || " ${BACKEND,,} " == " fsdp2" ]]; then
208
+ FSDP_ARGS=" $FSDP_ARGS --fsdp_cpu_ram_efficient_loading False"
209
+ fi
172
210
fi
173
211
174
- # real quantization does not work with FSDP, only works with FSDP2
175
- if [[ " ${COMPRESS,,} " == " true" && " ${USE_FSDP2,,} " != " true" ]]; then
176
- echo " Compression is not supported with FSDP. Disabling FSDP and using DDP."
177
- FSDP_ARGS=" "
178
- CONFIG_FILE=" ddp.yaml"
179
- fi
180
-
181
-
182
212
CMD=" accelerate launch --config-file accelerate_config/$CONFIG_FILE $FSDP_ARGS \
183
213
main.py \
184
214
--model_name_or_path $MODEL \
@@ -214,5 +244,4 @@ CMD="accelerate launch --config-file accelerate_config/$CONFIG_FILE $FSDP_ARGS \
214
244
215
245
start_time=$( date +%s)
216
246
sh -c " $CMD "
217
- echo " Total time taken: $(( $(date +% s) - $start_time )) seconds"
218
- python convert_sharded_ckpt.py --hf_model_path $MODEL --sharded_ckpt_path $OUTPUT_DIR --output_path $OUTPUT_DIR
247
+ echo " Total time taken: $(( $(date +% s) - $start_time )) seconds"
0 commit comments