@@ -154,35 +154,47 @@ Keep in mind the training loss of the distillation run is not directly comparabl
 ### Train teacher
 
 ``` bash
-accelerate launch --multi_gpu --mixed_precision bf16 main.py \
+accelerate launch \
+    --multi_gpu \
+    --mixed_precision bf16 \
+    --fsdp_version 2 \
+    --fsdp_reshard_after_forward True \
+    --fsdp_auto_wrap_policy 'TRANSFORMER_BASED_WRAP' \
+    --fsdp_transformer_layer_cls_to_wrap LlamaDecoderLayer \
+    \
+    main.py \
     --single_model \
     --teacher_name_or_path 'meta-llama/Llama-2-7b-hf' \
     --output_dir ./llama2-7b-sft \
     --logging_steps 5 \
     --max_steps 400 \
-    --max_seq_length 2048 \
+    --max_length 2048 \
     --per_device_train_batch_size 1 \
     --per_device_eval_batch_size 4 \
-    --gradient_checkpointing True \
-    --fsdp 'full_shard auto_wrap' \
-    --fsdp_transformer_layer_cls_to_wrap LlamaDecoderLayer
+    --gradient_checkpointing True
 ```
 ### Distill teacher into student
 
 ``` bash
-accelerate launch --multi_gpu --mixed_precision bf16 main.py \
+accelerate launch \
+    --multi_gpu \
+    --mixed_precision bf16 \
+    --fsdp_version 2 \
+    --fsdp_reshard_after_forward True \
+    --fsdp_auto_wrap_policy 'TRANSFORMER_BASED_WRAP' \
+    --fsdp_transformer_layer_cls_to_wrap LlamaDecoderLayer \
+    \
+    main.py \
     --teacher_name_or_path ./llama2-7b-sft \
     --student_name_or_path 'TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T' \
     --output_dir ./llama2-distill \
     --logging_steps 5 \
     --max_steps 200 \
-    --max_seq_length 2048 \
+    --max_length 2048 \
     --per_device_train_batch_size 1 \
     --per_device_eval_batch_size 4 \
-    --gradient_checkpointing False \
-    --fsdp 'full_shard auto_wrap' \
-    --fsdp_transformer_layer_cls_to_wrap LlamaDecoderLayer
+    --gradient_checkpointing False
 ```
> [!NOTE]
0 commit comments