@@ -18,96 +18,37 @@ set -eo pipefail
18
18
19
19
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
20
20
21
# Helper function to parse a single argument value.
#
# Accepts the remaining command-line words, starting at the flag itself, in
# either of the two supported spellings:
#   --flag=value   -> prints everything after the first '='
#   --flag value   -> prints the following word verbatim
# Arguments: "$@" of the caller, with the flag as $1.
# Outputs:   the value on stdout; a diagnostic on stderr on failure.
# Returns:   0 on success, 1 when the space-separated form has no value word.
parse_value() {
  if [[ "$1" == *=* ]]; then
    # Strip the shortest prefix ending in '=' so values may contain '='.
    printf '%s\n' "${1#*=}"
    return 0
  fi

  if (( $# < 2 )); then
    printf 'Error: Missing value for argument %s\n' "$1" >&2
    return 1
  fi

  # The value is a separate word: emit it untouched. It must not go through
  # the '=' stripping above, and printf (unlike echo) keeps values such as
  # "-n" or "-e" intact.
  printf '%s\n' "$2"
}
26
+
21
27
while [ $# -gt 0 ]; do
22
28
case " $1 " in
23
- --model* )
24
- if [[ " $1 " != * = * ]]; then shift ; fi
25
- MODEL=" ${1#* =} "
26
- ;;
27
- --output_dir* )
28
- if [[ " $1 " != * = * ]]; then shift ; fi
29
- OUTPUT_DIR=" ${1#* =} "
30
- ;;
31
- --dataset* )
32
- if [[ " $1 " != * = * ]]; then shift ; fi
33
- DATASET=" ${1#* =} "
34
- ;;
35
- --train_size* )
36
- if [[ " $1 " != * = * ]]; then shift ; fi
37
- TRAIN_SIZE=" ${1#* =} "
38
- ;;
39
- --eval_size* )
40
- if [[ " $1 " != * = * ]]; then shift ; fi
41
- EVAL_SIZE=" ${1#* =} "
42
- ;;
43
- --num_epochs* )
44
- if [[ " $1 " != * = * ]]; then shift ; fi
45
- NUM_EPOCHS=" ${1#* =} "
46
- ;;
47
- --max_steps* )
48
- if [[ " $1 " != * = * ]]; then shift ; fi
49
- MAX_STEPS=" ${1#* =} "
50
- ;;
51
- --save_steps* )
52
- if [[ " $1 " != * = * ]]; then shift ; fi
53
- SAVE_STEPS=" ${1#* =} "
54
- ;;
55
- --accum_steps* )
56
- if [[ " $1 " != * = * ]]; then shift ; fi
57
- ACCUM_STEPS=" ${1#* =} "
58
- ;;
59
- --lr* )
60
- if [[ " $1 " != * = * ]]; then shift ; fi
61
- LR=" ${1#* =} "
62
- ;;
63
- --quant_cfg* )
64
- if [[ " $1 " != * = * ]]; then shift ; fi
65
- QUANT_CFG=" ${1#* =} "
66
- ;;
67
- --compress* )
68
- if [[ " $1 " != * = * ]]; then shift ; fi
69
- COMPRESS=" ${1#* =} "
70
- ;;
71
- --calib_size* )
72
- if [[ " $1 " != * = * ]]; then shift ; fi
73
- CALIB_SIZE=" ${1#* =} "
74
- ;;
75
- --train_bs* )
76
- if [[ " $1 " != * = * ]]; then shift ; fi
77
- TRAIN_BS=" ${1#* =} "
78
- ;;
79
- --eval_bs* )
80
- if [[ " $1 " != * = * ]]; then shift ; fi
81
- EVAL_BS=" ${1#* =} "
82
- ;;
83
- --do_train* )
84
- if [[ " $1 " != * = * ]]; then shift ; fi
85
- DO_TRAIN=" ${1#* =} "
86
- ;;
87
- --lora* )
88
- if [[ " $1 " != * = * ]]; then shift ; fi
89
- LORA=" ${1#* =} "
90
- ;;
91
- --teacher_model* )
92
- if [[ " $1 " != * = * ]]; then shift ; fi
93
- TEACHER_MODEL=" ${1#* =} "
94
- ;;
95
- --distill* )
96
- if [[ " $1 " != * = * ]]; then shift ; fi
97
- DISTILL=" ${1#* =} "
98
- ;;
99
- --fsdp_transformer_layer_cls_to_wrap* )
100
- if [[ " $1 " != * = * ]]; then shift ; fi
101
- FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=" ${1#* =} "
102
- ;;
103
- --use_fsdp2* )
104
- if [[ " $1 " != * = * ]]; then shift ; fi
105
- USE_FSDP2=" ${1#* =} "
106
- ;;
107
- --max_seq_length* )
108
- if [[ " $1 " != * = * ]]; then shift ; fi
109
- MAX_SEQ_LENGTH=" ${1#* =} "
110
- ;;
29
+ --model* ) MODEL=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
30
+ --output_dir* ) OUTPUT_DIR=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
31
+ --dataset* ) DATASET=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
32
+ --train_size* ) TRAIN_SIZE=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
33
+ --eval_size* ) EVAL_SIZE=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
34
+ --num_epochs* ) NUM_EPOCHS=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
35
+ --max_steps* ) MAX_STEPS=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
36
+ --save_steps* ) SAVE_STEPS=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
37
+ --accum_steps* ) ACCUM_STEPS=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
38
+ --lr* ) LR=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
39
+ --quant_cfg* ) QUANT_CFG=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
40
+ --compress* ) COMPRESS=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
41
+ --calib_size* ) CALIB_SIZE=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
42
+ --train_bs* ) TRAIN_BS=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
43
+ --eval_bs* ) EVAL_BS=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
44
+ --do_train* ) DO_TRAIN=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
45
+ --lora* ) LORA=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
46
+ --teacher_model* ) TEACHER_MODEL=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
47
+ --distill* ) DISTILL=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
48
+ --fsdp_transformer_layer_cls_to_wrap* ) FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
49
+ --max_seq_length* ) MAX_SEQ_LENGTH=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
50
+ --backend* ) BACKEND=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
51
+ --use_fsdp2* ) USE_FSDP2=$( parse_value " $@ " ) ; [[ " $1 " != * = * ]] && shift ;;
111
52
* )
112
53
>&2 printf " Error: Invalid argument ${1#* =} \n"
113
54
exit 1
@@ -142,6 +83,7 @@ COMPRESS=${COMPRESS:-"False"}
142
83
DISTILL=${DISTILL:- " False" }
143
84
TEACHER_MODEL=${TEACHER_MODEL:- $MODEL }
144
85
FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP:- " LlamaDecoderLayer" }
86
+ BACKEND=${BACKEND:- " fsdp1" }
145
87
146
88
if [ -z $QUANT_CFG ]; then
147
89
QUANT_ARGS=" "
@@ -154,31 +96,55 @@ if [ ! -z $MAX_STEPS ]; then
154
96
OPTIONAL_ARGS=" $OPTIONAL_ARGS --max_steps $MAX_STEPS "
155
97
fi
156
98
157
# Select the accelerate launch configuration for the requested backend.
#
# Reads:   USE_FSDP2, COMPRESS, BACKEND, DISTILL, TEACHER_MODEL,
#          FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP
# Writes:  BACKEND, CONFIG_FILE, FSDP_ARGS, GRADIENT_CHECKPOINTING_ARGS,
#          DISTILLATION_ARGS
# Outputs: informational notices on stdout, warnings/errors on stderr.
# Exits with status 1 when BACKEND names an unsupported backend.
configure_backend() {
  # Same default the script applies to BACKEND elsewhere; a no-op once it is set.
  BACKEND=${BACKEND:-"fsdp1"}

  # Set backend based on --backend parameter, with backward compatibility for --use_fsdp2
  if [[ "${USE_FSDP2,,}" == "true" ]]; then
    echo "Warning: --use_fsdp2 is deprecated. Use --backend=fsdp2 instead."
    BACKEND="fsdp2"
  fi

  # if compress is true, set backend to ddp
  if [[ "${COMPRESS,,}" == "true" ]]; then
    # Tell the user when their backend choice is being replaced instead of
    # switching silently.
    if [[ "${BACKEND,,}" != "ddp" ]]; then
      printf "Warning: --compress is enabled; overriding backend '%s' with 'ddp'.\n" "$BACKEND" >&2
    fi
    BACKEND="ddp"
  fi

  # Configure backend-specific settings
  GRADIENT_CHECKPOINTING_ARGS=""
  # Tracks whether an FSDP flavour was selected, so every accepted alias
  # ("fsdp", "fsdp1", "fsdp2") is treated the same way further down.
  local uses_fsdp=false
  case "${BACKEND,,}" in
    "fsdp1"|"fsdp")
      CONFIG_FILE="fsdp1.yaml"
      FSDP_ARGS="--fsdp_transformer_layer_cls_to_wrap $FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP"
      uses_fsdp=true
      ;;
    "fsdp2")
      echo "Using FSDP2 instead of FSDP1. FSDP2 is not mature yet! Please use it with latest torch and transformers."
      CONFIG_FILE="fsdp2.yaml"
      FSDP_ARGS="--fsdp_transformer_layer_cls_to_wrap $FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP"
      uses_fsdp=true
      ;;
    "ddp")
      CONFIG_FILE="ddp.yaml"
      FSDP_ARGS=""
      GRADIENT_CHECKPOINTING_ARGS="--gradient_checkpointing True"
      ;;
    "deepspeed")
      CONFIG_FILE="deepspeed.yaml"
      FSDP_ARGS=""
      GRADIENT_CHECKPOINTING_ARGS="--gradient_checkpointing True"
      ;;
    *)
      # Diagnostics go to stderr, like the invalid-argument error in the
      # option parser.
      printf "Error: Invalid backend '%s'. Supported backends: fsdp1, fsdp2, ddp, deepspeed\n" "$BACKEND" >&2
      exit 1
      ;;
  esac

  # TODO: Remove this after simple distillation is supported
  DISTILLATION_ARGS=""
  if [[ "${DISTILL,,}" == "true" ]]; then
    DISTILLATION_ARGS="--distill $DISTILL --teacher_model $TEACHER_MODEL"
    # Distillation does not work with memory efficient loading for FSDP
    if [[ "$uses_fsdp" == "true" ]]; then
      FSDP_ARGS="$FSDP_ARGS --fsdp_cpu_ram_efficient_loading False"
    fi
  fi
}

configure_backend
182
148
CMD=" accelerate launch --config-file accelerate_config/$CONFIG_FILE $FSDP_ARGS \
183
149
main.py \
184
150
--model_name_or_path $MODEL \
@@ -209,10 +175,10 @@ CMD="accelerate launch --config-file accelerate_config/$CONFIG_FILE $FSDP_ARGS \
209
175
--report_to tensorboard \
210
176
--lora $LORA \
211
177
--compress $COMPRESS \
212
- $QUANT_ARGS $OPTIONAL_ARGS $GRADIENT_CHECKPOINTING_ARGS $DISTILLATION_ARGS
178
+ $GRADIENT_CHECKPOINTING_ARGS $QUANT_ARGS $OPTIONAL_ARGS $DISTILLATION_ARGS
213
179
"
214
180
215
181
# Run the assembled launch command through sh, report the wall-clock time it
# took, then invoke convert_sharded_ckpt.py on the output directory.
start_time=$(date +%s)
sh -c "$CMD"
echo "Total time taken: $(( $(date +%s) - start_time )) seconds"
# Quote the paths so a model name or output directory containing spaces or
# glob characters is passed to the converter as a single argument.
python convert_sharded_ckpt.py --hf_model_path "$MODEL" --sharded_ckpt_path "$OUTPUT_DIR" --output_path "$OUTPUT_DIR"
0 commit comments