@@ -18,96 +18,37 @@ set -eo pipefail
1818
1919export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
2020
21+ # Helper function to parse a single argument value
22+ parse_value () {
23+ if [[ " $1 " != * = * ]]; then shift ; fi
24+ echo " ${1#* =} "
25+ }
26+
2127while [ $# -gt 0 ]; do
2228 case " $1 " in
23- --model* )
24- if [[ " $1 " != * = * ]]; then shift ; fi
25- MODEL=" ${1#* =} "
26- ;;
27- --output_dir* )
28- if [[ " $1 " != * = * ]]; then shift ; fi
29- OUTPUT_DIR=" ${1#* =} "
30- ;;
31- --dataset* )
32- if [[ " $1 " != * = * ]]; then shift ; fi
33- DATASET=" ${1#* =} "
34- ;;
35- --train_size* )
36- if [[ " $1 " != * = * ]]; then shift ; fi
37- TRAIN_SIZE=" ${1#* =} "
38- ;;
39- --eval_size* )
40- if [[ " $1 " != * = * ]]; then shift ; fi
41- EVAL_SIZE=" ${1#* =} "
42- ;;
43- --num_epochs* )
44- if [[ " $1 " != * = * ]]; then shift ; fi
45- NUM_EPOCHS=" ${1#* =} "
46- ;;
47- --max_steps* )
48- if [[ " $1 " != * = * ]]; then shift ; fi
49- MAX_STEPS=" ${1#* =} "
50- ;;
51- --save_steps* )
52- if [[ " $1 " != * = * ]]; then shift ; fi
53- SAVE_STEPS=" ${1#* =} "
54- ;;
55- --accum_steps* )
56- if [[ " $1 " != * = * ]]; then shift ; fi
57- ACCUM_STEPS=" ${1#* =} "
58- ;;
59- --lr* )
60- if [[ " $1 " != * = * ]]; then shift ; fi
61- LR=" ${1#* =} "
62- ;;
63- --quant_cfg* )
64- if [[ " $1 " != * = * ]]; then shift ; fi
65- QUANT_CFG=" ${1#* =} "
66- ;;
67- --compress* )
68- if [[ " $1 " != * = * ]]; then shift ; fi
69- COMPRESS=" ${1#* =} "
70- ;;
71- --calib_size* )
72- if [[ " $1 " != * = * ]]; then shift ; fi
73- CALIB_SIZE=" ${1#* =} "
74- ;;
75- --train_bs* )
76- if [[ " $1 " != * = * ]]; then shift ; fi
77- TRAIN_BS=" ${1#* =} "
78- ;;
79- --eval_bs* )
80- if [[ " $1 " != * = * ]]; then shift ; fi
81- EVAL_BS=" ${1#* =} "
82- ;;
83- --do_train* )
84- if [[ " $1 " != * = * ]]; then shift ; fi
85- DO_TRAIN=" ${1#* =} "
86- ;;
87- --lora* )
88- if [[ " $1 " != * = * ]]; then shift ; fi
89- LORA=" ${1#* =} "
90- ;;
91- --teacher_model* )
92- if [[ " $1 " != * = * ]]; then shift ; fi
93- TEACHER_MODEL=" ${1#* =} "
94- ;;
95- --distill* )
96- if [[ " $1 " != * = * ]]; then shift ; fi
97- DISTILL=" ${1#* =} "
98- ;;
99- --fsdp_transformer_layer_cls_to_wrap* )
100- if [[ " $1 " != * = * ]]; then shift ; fi
101- FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=" ${1#* =} "
102- ;;
103- --use_fsdp2* )
104- if [[ " $1 " != * = * ]]; then shift ; fi
105- USE_FSDP2=" ${1#* =} "
106- ;;
107- --max_seq_length* )
108- if [[ " $1 " != * = * ]]; then shift ; fi
109- MAX_SEQ_LENGTH=" ${1#* =} "
110- ;;
29+ --model* ) MODEL=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
30+ --output_dir* ) OUTPUT_DIR=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
31+ --dataset* ) DATASET=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
32+ --train_size* ) TRAIN_SIZE=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
33+ --eval_size* ) EVAL_SIZE=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
34+ --num_epochs* ) NUM_EPOCHS=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
35+ --max_steps* ) MAX_STEPS=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
36+ --save_steps* ) SAVE_STEPS=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
37+ --accum_steps* ) ACCUM_STEPS=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
38+ --lr* ) LR=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
39+ --quant_cfg* ) QUANT_CFG=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
40+ --compress* ) COMPRESS=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
41+ --calib_size* ) CALIB_SIZE=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
42+ --train_bs* ) TRAIN_BS=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
43+ --eval_bs* ) EVAL_BS=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
44+ --do_train* ) DO_TRAIN=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
45+ --lora* ) LORA=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
46+ --teacher_model* ) TEACHER_MODEL=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
47+ --distill* ) DISTILL=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
48+ --fsdp_transformer_layer_cls_to_wrap* ) FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
49+ --max_seq_length* ) MAX_SEQ_LENGTH=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
50+ --backend* ) BACKEND=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
51+ --use_fsdp2* ) USE_FSDP2=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
11152 * )
11253 >&2 printf " Error: Invalid argument ${1#* =} \n"
11354 exit 1
@@ -142,6 +83,7 @@ COMPRESS=${COMPRESS:-"False"}
14283DISTILL=${DISTILL:- " False" }
14384TEACHER_MODEL=${TEACHER_MODEL:- $MODEL }
14485FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP:- " LlamaDecoderLayer" }
86+ BACKEND=${BACKEND:- " fsdp1" }
14587
14688if [ -z $QUANT_CFG ]; then
14789 QUANT_ARGS=" "
@@ -154,31 +96,55 @@ if [ ! -z $MAX_STEPS ]; then
15496 OPTIONAL_ARGS=" $OPTIONAL_ARGS --max_steps $MAX_STEPS "
15597fi
15698
157- CONFIG_FILE=" fsdp1.yaml"
158- FSDP_ARGS=" --fsdp_transformer_layer_cls_to_wrap $FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP "
159- GRADIENT_CHECKPOINTING_ARGS=" --gradient_checkpointing True"
160-
99+ # Set backend based on --backend parameter, with backward compatibility for --use_fsdp2
161100if [[ " ${USE_FSDP2,,} " == " true" ]]; then
162- echo " Using FSDP2 instead of FSDP1. FSDP2 is not mature yet! Please use it with latest torch and transformers."
163- CONFIG_FILE=" fsdp2.yaml"
164- GRADIENT_CHECKPOINTING_ARGS=" "
101+ echo " Warning: --use_fsdp2 is deprecated. Use --backend=fsdp2 instead."
102+ BACKEND=" fsdp2"
103+ fi
104+
105+ # if compress is true, set backend to ddp
106+ if [[ " ${COMPRESS,,} " == " true" ]]; then
107+ BACKEND=" ddp"
165108fi
166109
110+ # Configure backend-specific settings
111+ GRADIENT_CHECKPOINTING_ARGS=" "
112+ case " ${BACKEND,,} " in
113+ " fsdp1" |" fsdp" )
114+ CONFIG_FILE=" fsdp1.yaml"
115+ FSDP_ARGS=" --fsdp_transformer_layer_cls_to_wrap $FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP "
116+ ;;
117+ " fsdp2" )
118+ echo " Using FSDP2 instead of FSDP1. FSDP2 is not mature yet! Please use it with latest torch and transformers."
119+ CONFIG_FILE=" fsdp2.yaml"
120+ FSDP_ARGS=" --fsdp_transformer_layer_cls_to_wrap $FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP "
121+ ;;
122+ " ddp" )
123+ CONFIG_FILE=" ddp.yaml"
124+ FSDP_ARGS=" "
125+ GRADIENT_CHECKPOINTING_ARGS=" --gradient_checkpointing True"
126+ ;;
127+ " deepspeed" )
128+ CONFIG_FILE=" deepspeed.yaml"
129+ FSDP_ARGS=" "
130+ GRADIENT_CHECKPOINTING_ARGS=" --gradient_checkpointing True"
131+ ;;
132+ * )
133+ echo " Error: Invalid backend '$BACKEND '. Supported backends: fsdp1, fsdp2, ddp, deepspeed"
134+ exit 1
135+ ;;
136+ esac
137+
138+ # TODO: Remove this after simple distillation is supported
167139DISTILLATION_ARGS=" "
168140if [[ " ${DISTILL,,} " == " true" ]]; then
169141 DISTILLATION_ARGS=" --distill $DISTILL --teacher_model $TEACHER_MODEL "
170- # Distillation does not work with memory efficient loading
171- FSDP_ARGS=" $FSDP_ARGS --fsdp_cpu_ram_efficient_loading False"
142+ # Distillation does not work with memory efficient loading for FSDP
143+ if [[ " ${BACKEND,,} " == " fsdp1" || " ${BACKEND,,} " == " fsdp2" ]]; then
144+ FSDP_ARGS=" $FSDP_ARGS --fsdp_cpu_ram_efficient_loading False"
145+ fi
172146fi
173147
174- # real quantization does not work with FSDP, only works with FSDP2
175- if [[ " ${COMPRESS,,} " == " true" && " ${USE_FSDP2,,} " != " true" ]]; then
176- echo " Compression is not supported with FSDP. Disabling FSDP and using DDP."
177- FSDP_ARGS=" "
178- CONFIG_FILE=" ddp.yaml"
179- fi
180-
181-
182148CMD=" accelerate launch --config-file accelerate_config/$CONFIG_FILE $FSDP_ARGS \
183149 main.py \
184150 --model_name_or_path $MODEL \
@@ -209,10 +175,9 @@ CMD="accelerate launch --config-file accelerate_config/$CONFIG_FILE $FSDP_ARGS \
209175 --report_to tensorboard \
210176 --lora $LORA \
211177 --compress $COMPRESS \
212- $QUANT_ARGS $OPTIONAL_ARGS $GRADIENT_CHECKPOINTING_ARGS $DISTILLATION_ARGS
178+ $GRADIENT_CHECKPOINTING_ARGS $QUANT_ARGS $OPTIONAL_ARGS $DISTILLATION_ARGS
213179"
214180
215181start_time=$( date +%s)
216182sh -c " $CMD "
217- echo " Total time taken: $(( $(date +% s) - $start_time )) seconds"
218- python convert_sharded_ckpt.py --hf_model_path $MODEL --sharded_ckpt_path $OUTPUT_DIR --output_path $OUTPUT_DIR
183+ echo " Total time taken: $(( $(date +% s) - $start_time )) seconds"
0 commit comments