
Commit eb51261

Update TRL version in example
Signed-off-by: Asha Anoosheh <[email protected]>
1 parent: 7df1237

3 files changed: +29 -21 lines

examples/llm_distill/README.md

Lines changed: 22 additions & 10 deletions

````diff
@@ -154,35 +154,47 @@ Keep in mind the training loss of the distillation run is not directly comparabl
 ### Train teacher
 
 ```bash
-accelerate launch --multi_gpu --mixed_precision bf16 main.py \
+accelerate launch \
+    --multi_gpu \
+    --mixed_precision bf16 \
+    --fsdp_version 2 \
+    --fsdp_reshard_after_forward True \
+    --fsdp_auto_wrap_policy 'TRANSFORMER_BASED_WRAP' \
+    --fsdp_transformer_layer_cls_to_wrap LlamaDecoderLayer \
+    \
+    main.py \
     --single_model \
     --teacher_name_or_path 'meta-llama/Llama-2-7b-hf' \
     --output_dir ./llama2-7b-sft \
     --logging_steps 5 \
     --max_steps 400 \
-    --max_seq_length 2048 \
+    --max_length 2048 \
     --per_device_train_batch_size 1 \
     --per_device_eval_batch_size 4 \
-    --gradient_checkpointing True \
-    --fsdp 'full_shard auto_wrap' \
-    --fsdp_transformer_layer_cls_to_wrap LlamaDecoderLayer
+    --gradient_checkpointing True
 ```
 
 ### Distill teacher into student
 
 ```bash
-accelerate launch --multi_gpu --mixed_precision bf16 main.py \
+accelerate launch \
+    --multi_gpu \
+    --mixed_precision bf16 \
+    --fsdp_version 2 \
+    --fsdp_reshard_after_forward True \
+    --fsdp_auto_wrap_policy 'TRANSFORMER_BASED_WRAP' \
+    --fsdp_transformer_layer_cls_to_wrap LlamaDecoderLayer \
+    \
+    main.py \
     --teacher_name_or_path ./llama2-7b-sft \
     --student_name_or_path 'TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T' \
     --output_dir ./llama2-distill \
     --logging_steps 5 \
     --max_steps 200 \
-    --max_seq_length 2048 \
+    --max_length 2048 \
     --per_device_train_batch_size 1 \
     --per_device_eval_batch_size 4 \
-    --gradient_checkpointing False \
-    --fsdp 'full_shard auto_wrap' \
-    --fsdp_transformer_layer_cls_to_wrap LlamaDecoderLayer
+    --gradient_checkpointing False
 ```
 
 > [!NOTE]
````
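The `--max_seq_length` to `--max_length` rename in these commands goes hand in hand with the field rename in the `TrainingArguments` dataclass of `main.py` below, since the CLI flags are parsed into that dataclass. A minimal sketch of that parsing path, assuming `transformers.HfArgumentParser` is wired up the way typical Hugging Face example scripts do it (the parser setup itself is not part of this diff, and the snippet assumes `transformers`, `torch`, and `accelerate` are installed):

```python
# Illustrative sketch only: shows how a renamed CLI flag maps onto a renamed
# dataclass field when parsed with transformers.HfArgumentParser. Field names
# mirror the diff; the parser wiring here is assumed, not copied from main.py.
from dataclasses import dataclass

import transformers


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    max_length: int = 1024  # renamed from max_seq_length in this commit


parser = transformers.HfArgumentParser(TrainingArguments)
(args,) = parser.parse_args_into_dataclasses(
    args=["--output_dir", "./llama2-7b-sft", "--max_length", "2048"]
)
print(args.max_length)  # -> 2048
```

With this kind of wiring an out-of-date flag name simply fails to parse, which is why the README commands and the dataclass field change in the same commit.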

examples/llm_distill/main.py

Lines changed: 6 additions & 10 deletions

```diff
@@ -48,25 +48,22 @@ class TrainingArguments(transformers.TrainingArguments):
     do_train: bool = True
     do_eval: bool = True
     save_strategy: str = "no"
-    max_seq_length: int = 1024
+    max_length: int = 1024
     optim: str = "adamw_torch"
     learning_rate: float = 1e-5
     lr_scheduler_type: str = "cosine"
     dataloader_drop_last: bool = True
     dataset_num_proc: int = 8
-    dataset_batch_size: int = 500
     bf16: bool = True
     tf32: bool = True
 
 
 def llama_text_format_func(sample):
-    texts = []
-    for p, q, r in zip(sample["system_prompt"], sample["question"], sample["response"]):
-        if not p:
-            texts.append(f"<s>[INST] {q}[/INST]\n{r}</s>")
-        else:
-            texts.append(f"<s>[INST] <<SYS>>{p}<</SYS>>\n{q}[/INST]\n{r}</s>")
-    return texts
+    p, q, r = sample["system_prompt"], sample["question"], sample["response"]
+    if not p:
+        return f"<s>[INST] {q}[/INST]\n{r}</s>"
+    else:
+        return f"<s>[INST] <<SYS>>{p}<</SYS>>\n{q}[/INST]\n{r}</s>"
 
 
 class KDSFTTrainer(SFTTrainer, KDTrainer):
@@ -130,7 +127,6 @@ def train():
     kd_config = {
         "teacher_model": teacher_model,
         "criterion": LMLogitsLoss(),
-        "expose_minimal_state_dict": False,  # FSDP forces us to disable this
     }
     model = mtd.convert(model, mode=[("kd_loss", kd_config)])
     logger.info("Models converted.")
```
Lines changed: 1 addition & 1 deletion

```diff
@@ -1,2 +1,2 @@
 pyarrow
-trl==0.13.0
+trl==0.23.0
```
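Since the example now relies on the newer TRL behavior above (`max_length` instead of `max_seq_length`, per-example formatting functions), a quick guard against an older installed TRL can save a confusing failure later. A small sketch, not part of the commit; the `0.23.0` bound simply mirrors the pin:

```python
# Hedged sketch: fail fast if the installed TRL predates the pinned version.
# Uses importlib.metadata plus packaging (already pulled in by transformers).
from importlib.metadata import version

from packaging.version import Version

installed = Version(version("trl"))
if installed < Version("0.23.0"):
    raise RuntimeError(
        f"Found trl=={installed}; this example expects the trl==0.23.0 pin above, "
        "which this commit adapts the training script and README commands to."
    )
```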
