
Commit 2c6e0d9

Add note about special tokens in chat templates for LoRA SFT (huggingface#2414)
1 parent e1d7813 commit 2c6e0d9

File tree

1 file changed: +10 −5 lines


docs/source/sft_trainer.mdx

Lines changed: 10 additions & 5 deletions
````diff
@@ -331,33 +331,38 @@ Note that all keyword arguments of `from_pretrained()` are supported.
 
 ### Training adapters
 
-We also support tight integration with 🤗 PEFT library so that any user can conveniently train adapters and share them on the Hub instead of training the entire model
+We also support tight integration with 🤗 PEFT library so that any user can conveniently train adapters and share them on the Hub instead of training the entire model.
 
 ```python
 from datasets import load_dataset
 from trl import SFTConfig, SFTTrainer
 from peft import LoraConfig
 
-dataset = load_dataset("stanfordnlp/imdb", split="train")
+dataset = load_dataset("trl-lib/Capybara", split="train")
 
 peft_config = LoraConfig(
     r=16,
     lora_alpha=32,
     lora_dropout=0.05,
-    bias="none",
+    target_modules="all-linear",
+    modules_to_save=["lm_head", "embed_token"],
     task_type="CAUSAL_LM",
 )
 
 trainer = SFTTrainer(
-    "EleutherAI/gpt-neo-125m",
+    "Qwen/Qwen2.5-0.5B",
     train_dataset=dataset,
-    args=SFTConfig(output_dir="/tmp"),
+    args=SFTConfig(output_dir="Qwen2.5-0.5B-SFT"),
     peft_config=peft_config
 )
 
 trainer.train()
 ```
 
+> [!WARNING]
+> If the chat template contains special tokens like `<|im_start|>` (ChatML) or `<|eot_id|>` (Llama), the embedding layer and LM head must be included in the trainable parameters via the `modules_to_save` argument. Without this, the fine-tuned model will produce unbounded or nonsense generations. If the chat template doesn't contain special tokens (e.g. Alpaca), then the `modules_to_save` argument can be ignored or set to `None`.
+
+
 You can also continue training your `PeftModel`. For that, first load a `PeftModel` outside `SFTTrainer` and pass it directly to the trainer without the `peft_config` argument being passed.
 
 ### Training adapters with base 8 bit models
````
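A quick way to sanity-check a setup like the one added in this diff is to wrap the model with the LoRA config and inspect which parameters remain trainable. The sketch below is not part of the commit; it assumes `Qwen/Qwen2.5-0.5B` (whose ChatML template uses `<|im_start|>` special tokens) and that the embedding module for this architecture is named `embed_tokens`.

```python
# Minimal sketch (not part of this commit): confirm that the LM head and the
# embedding layer stay trainable when listed in modules_to_save.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules="all-linear",
    # Assumption: for this architecture the embedding module is "embed_tokens".
    modules_to_save=["lm_head", "embed_tokens"],
    task_type="CAUSAL_LM",
)

peft_model = get_peft_model(model, peft_config)
peft_model.print_trainable_parameters()

# Trainable parameters that are not LoRA adapters should include the
# full-rank copies of lm_head and embed_tokens created by modules_to_save.
full_rank_trainable = sorted(
    name
    for name, param in peft_model.named_parameters()
    if param.requires_grad and "lora_" not in name
)
print(full_rank_trainable)
```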
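The context line about continuing training of an existing `PeftModel` can be sketched in the same way. This is an illustrative example rather than the documented recipe: the adapter path is a placeholder for a checkpoint produced by a run like the one above, and `is_trainable=True` keeps the loaded adapter weights unfrozen.

```python
# Minimal sketch (not part of this commit): resume training on an existing
# adapter by loading the PeftModel yourself and passing it to SFTTrainer
# without a peft_config.
from datasets import load_dataset
from peft import AutoPeftModelForCausalLM
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("trl-lib/Capybara", split="train")

# Placeholder path: an adapter checkpoint saved by a previous SFT run.
model = AutoPeftModelForCausalLM.from_pretrained(
    "Qwen2.5-0.5B-SFT",
    is_trainable=True,
)

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    args=SFTConfig(output_dir="Qwen2.5-0.5B-SFT-continued"),
    # No peft_config here: the model is already a PeftModel.
)

trainer.train()
```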
