Commit a6263a5

Merge branch 'main' into py3.14
2 parents: a5ca7d4 + 452284b

5 files changed: +148, -14 lines changed

docs/source/index.md

Lines changed: 38 additions & 0 deletions
@@ -7,6 +7,44 @@
 TRL is a full stack library where we provide a set of tools to train transformer language models with methods like Supervised Fine-Tuning (SFT), Group Relative Policy Optimization (GRPO), Direct Preference Optimization (DPO), Reward Modeling, and more.
 The library is integrated with 🤗 [transformers](https://github.com/huggingface/transformers).
 
+Below is the current list of TRL trainers, organized by method type (⚡️ = vLLM support).
+
+<div style="display: flex; justify-content: space-between; width: 100%; gap: 2rem;">
+
+<div style="flex: 1; min-width: 0;">
+
+**Online methods**
+- [`GRPOTrainer`] ⚡️
+- [`RLOOTrainer`] ⚡️
+- [`OnlineDPOTrainer`] ⚡️
+- [`NashMDTrainer`] ⚡️
+- [`XPOTrainer`] ⚡️
+- [`PPOTrainer`]
+
+**Reward modeling**
+- [`PRMTrainer`]
+- [`RewardTrainer`]
+
+</div>
+
+<div style="flex: 1; min-width: 0;">
+
+**Offline methods**
+- [`SFTTrainer`]
+- [`DPOTrainer`]
+- [`ORPOTrainer`]
+- [`BCOTrainer`]
+- [`CPOTrainer`]
+- [`KTOTrainer`]
+
+**Knowledge distillation**
+- [`GKDTrainer`]
+
+</div>
+
+</div>
+
+
 ## 🎉 What's New
 
 **✨ OpenAI GPT OSS Support**: TRL now fully supports fine-tuning the latest [OpenAI GPT OSS models](https://huggingface.co/collections/openai/gpt-oss-68911959590a1634ba11c7a4)! Check out the:
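As a quick orientation for the list added above, here is a minimal, hedged example of launching one of the offline trainers; the model and dataset names are illustrative and not part of this commit:

```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

# Illustrative choices: any chat model and any text / prompt-completion dataset works.
dataset = load_dataset("trl-lib/Capybara", split="train")

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    train_dataset=dataset,
    args=SFTConfig(output_dir="Qwen2.5-0.5B-SFT"),
)
trainer.train()
```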

docs/source/lora_without_regret.md

Lines changed: 4 additions & 4 deletions
@@ -42,7 +42,7 @@ from trl import SFTTrainer, SFTConfig
 
 dataset = load_dataset("open-thoughts/OpenThoughts-114k", split="train")
 
-peft_config = LoraConfig(lora_r=256, lora_alpha=16, lora_target_modules="all-linear")
+peft_config = LoraConfig(r=256, lora_alpha=16, target_modules="all-linear")
 
 training_args = SFTConfig(
     learning_rate=2e-4,
@@ -245,9 +245,9 @@ def strip_reasoning_accuracy_reward(completions, **kwargs):
     ...
 
 peft_config = LoraConfig(
-    lora_r=1,
+    r=1,
     lora_alpha=32,
-    lora_target_modules="all-linear"
+    target_modules="all-linear"
 )
 
 training_args = GRPOConfig(
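The renames in these two hunks line up with peft's `LoraConfig` signature, which exposes the rank as `r` and the module selector as `target_modules` (with `lora_alpha` unchanged). A minimal sketch of the corrected usage, mirroring the values from the diff:

```python
from peft import LoraConfig

# Argument names as peft defines them: `r` is the LoRA rank, `target_modules` selects
# which layers receive adapters ("all-linear" targets every linear layer).
peft_config = LoraConfig(
    r=256,
    lora_alpha=16,
    target_modules="all-linear",
)
```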
@@ -419,7 +419,7 @@ The blog post defines the ideal dataset size for LoRA to match full fine-tuning
 
 ### 3. *"FullFT and high-rank LoRAs have similar learning curves"*
 
-Counterintuitively, the blog post recommends using similar learning rates to full fine-tuning. In the TRL script, we could use `--learning_rate` to set the learning rate. The \\( \frac{1}{r} \\) scaling in LoRA makes the optimal learning rate approximately rank-independent.
+Counterintuitively, the blog post recommends using a higher learning rate than for full fine-tuning. In the table above, we used 1.0e-5 for LoRA and 1.0e-6 for full fine-tuning. In the TRL script, we could use `--learning_rate` to set the learning rate. The \\( \frac{1}{r} \\) scaling in LoRA makes the optimal learning rate approximately rank-independent.
 
 ![learning rate](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lora_without_regret/2.png)
 
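The rank-independence argument in the edited paragraph follows from the standard LoRA parameterization (general background, not something introduced by this commit):

\\[ \Delta W = \frac{\alpha}{r} \, B A, \qquad B \in \mathbb{R}^{d \times r}, \; A \in \mathbb{R}^{r \times k} \\]

Because the update is scaled by \\( \frac{\alpha}{r} \\), raising the rank while holding \\( \alpha \\) fixed shrinks each direction's contribution proportionally, so a learning rate tuned at one rank stays close to optimal at another.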

scripts/generate_tiny_models.py

Lines changed: 11 additions & 1 deletion
@@ -155,7 +155,6 @@ def init_weights_tiny_model(model):
 for model_id, config_class, model_class, suffix in [
     ("bigscience/bloomz-560m", BloomConfig, BloomForCausalLM, None),
     ("CohereForAI/aya-expanse-8b", CohereConfig, CohereForCausalLM, None),
-    ("databricks/dbrx-instruct", DbrxConfig, DbrxForCausalLM, None),
     ("deepseek-ai/DeepSeek-R1", DeepseekV3Config, DeepseekV3ForCausalLM, None),
     # It's important to have R1-0528 as it doesn't have the same chat template
     ("deepseek-ai/DeepSeek-R1-0528", DeepseekV3Config, DeepseekV3ForCausalLM, "0528"),
@@ -209,6 +208,17 @@ def init_weights_tiny_model(model):
     init_weights_tiny_model(model)
     push_to_hub(model, tokenizer, "tiny", suffix)
 
+# Special case for databricks/dbrx-instruct as it requires specific changes in the config
+model_id = "databricks/dbrx-instruct"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+config = DbrxConfig.from_pretrained(model_id, n_layers=2, n_heads=16, d_model=24)
+# transformers mistakenly ignores ffn_config keys when loading from pretrained. We need to set them manually after
+# loading the config
+config.ffn_config.ffn_hidden_size = 24
+config.ffn_config.hidden_size = 24
+model = DbrxForCausalLM(config).to(dtype=torch.bfloat16)
+init_weights_tiny_model(model)
+push_to_hub(model, tokenizer, "tiny")
 
 # Two slightly bigger models, required for vLLM testing
 tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-32B-Instruct")

tests/test_modeling_value_head.py

Lines changed: 85 additions & 0 deletions
@@ -16,6 +16,8 @@
 
 import pytest
 import torch
+import transformers
+from packaging import version
 from parameterized import parameterized
 from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, GenerationConfig
 
@@ -63,6 +65,12 @@ def test_value_head(self):
         Test if the v-head is added to the model successfully
         """
         for model_name in self.all_model_names:
+            if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
+                transformers.__version__
+            ) < version.parse("4.58.0.dev0"):
+                # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
+                continue
+
             model = self.trl_model_class.from_pretrained(model_name)
             assert hasattr(model, "v_head")
 
@@ -71,6 +79,12 @@ def test_value_head_shape(self):
         Test if the v-head has the correct shape
         """
         for model_name in self.all_model_names:
+            if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
+                transformers.__version__
+            ) < version.parse("4.58.0.dev0"):
+                # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
+                continue
+
             model = self.trl_model_class.from_pretrained(model_name)
             assert model.v_head.summary.weight.shape[0] == 1
 
@@ -80,6 +94,12 @@ def test_value_head_init_random(self):
         than zeros by default.
         """
         for model_name in self.all_model_names:
+            if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
+                transformers.__version__
+            ) < version.parse("4.58.0.dev0"):
+                # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
+                continue
+
             model = self.trl_model_class.from_pretrained(model_name)
             assert not torch.allclose(model.v_head.summary.bias, torch.zeros_like(model.v_head.summary.bias))
 
@@ -89,6 +109,12 @@ def test_value_head_not_str(self):
         `from_pretrained`.
         """
         for model_name in self.all_model_names:
+            if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
+                transformers.__version__
+            ) < version.parse("4.58.0.dev0"):
+                # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
+                continue
+
             pretrained_model = self.transformers_model_class.from_pretrained(model_name)
             model = self.trl_model_class.from_pretrained(pretrained_model)
             assert hasattr(model, "v_head")
@@ -99,6 +125,12 @@ def test_from_save_trl(self):
         additional modules (e.g. v_head)
         """
         for model_name in self.all_model_names:
+            if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
+                transformers.__version__
+            ) < version.parse("4.58.0.dev0"):
+                # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
+                continue
+
             model = self.trl_model_class.from_pretrained(model_name)
 
             model.save_pretrained(self.tmp_dir)
@@ -114,6 +146,12 @@ def test_from_save_trl_sharded(self):
         Test if the model can be saved and loaded from a directory and get the same weights - sharded case
         """
         for model_name in self.all_model_names:
+            if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
+                transformers.__version__
+            ) < version.parse("4.58.0.dev0"):
+                # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
+                continue
+
             model = self.trl_model_class.from_pretrained(model_name)
 
             model.save_pretrained(self.tmp_dir)
@@ -129,6 +167,12 @@ def test_from_save_transformers_sharded(self):
         Test if the model can be saved and loaded using transformers and get the same weights - sharded case
         """
         for model_name in self.all_model_names:
+            if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
+                transformers.__version__
+            ) < version.parse("4.58.0.dev0"):
+                # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
+                continue
+
             transformers_model = self.trl_model_class.transformers_parent_class.from_pretrained(model_name)
 
             trl_model = self.trl_model_class.from_pretrained(model_name)
@@ -150,6 +194,12 @@ def test_from_save_transformers(self):
         of the super class to check if the weights are the same.
         """
         for model_name in self.all_model_names:
+            if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
+                transformers.__version__
+            ) < version.parse("4.58.0.dev0"):
+                # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
+                continue
+
             transformers_model = self.trl_model_class.transformers_parent_class.from_pretrained(model_name)
 
             trl_model = self.trl_model_class.from_pretrained(model_name)
@@ -200,6 +250,12 @@ def test_inference(self):
         EXPECTED_OUTPUT_SIZE = 3
 
         for model_name in self.all_model_names:
+            if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
+                transformers.__version__
+            ) < version.parse("4.58.0.dev0"):
+                # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
+                continue
+
             model = self.trl_model_class.from_pretrained(model_name).to(self.device)
             input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], device=self.device)
             outputs = model(input_ids)
@@ -213,6 +269,12 @@ def test_dropout_config(self):
         Test if we instantiate a model by adding `summary_drop_prob` to the config it will be added to the v_head
         """
         for model_name in self.all_model_names:
+            if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
+                transformers.__version__
+            ) < version.parse("4.58.0.dev0"):
+                # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
+                continue
+
             pretrained_model = self.transformers_model_class.from_pretrained(model_name)
             pretrained_model.config.summary_dropout_prob = 0.5
             model = self.trl_model_class.from_pretrained(pretrained_model)
@@ -225,6 +287,11 @@ def test_dropout_kwargs(self):
         Test if we instantiate a model by adding `summary_drop_prob` to the config it will be added to the v_head
         """
         for model_name in self.all_model_names:
+            if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
+                transformers.__version__
+            ) < version.parse("4.58.0.dev0"):
+                # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
+                continue
             v_head_kwargs = {"summary_dropout_prob": 0.5}
 
             model = self.trl_model_class.from_pretrained(model_name, **v_head_kwargs)
@@ -242,6 +309,12 @@ def test_generate(self, model_name):
         r"""
         Test if `generate` works for every model
         """
+        if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
+            transformers.__version__
+        ) < version.parse("4.58.0.dev0"):
+            # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
+            pytest.xfail("DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version")
+
         generation_config = GenerationConfig(max_new_tokens=9)
         model = self.trl_model_class.from_pretrained(model_name).to(self.device)
         input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], device=self.device)
@@ -256,6 +329,12 @@ def test_transformers_bf16_kwargs(self):
         run a dummy forward pass without any issue.
         """
         for model_name in self.all_model_names:
+            if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
+                transformers.__version__
+            ) < version.parse("4.58.0.dev0"):
+                # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
+                continue
+
             trl_model = self.trl_model_class.from_pretrained(model_name, dtype=torch.bfloat16).to(self.device)
 
             lm_head_namings = ["lm_head", "embed_out", "output_layer"]
@@ -276,6 +355,12 @@ def test_transformers_bf16_kwargs(self):
     @pytest.mark.skip(reason="This test needs to be run manually due to HF token issue.")
     def test_push_to_hub(self):
         for model_name in self.all_model_names:
+            if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
+                transformers.__version__
+            ) < version.parse("4.58.0.dev0"):
+                # DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
+                continue
+
             model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)
             if "sharded" in model_name:
                 model.push_to_hub(model_name + "-ppo", use_auth_token=True, max_shard_size="1MB")
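The version gate above is copy-pasted into every affected test. A possible follow-up, not part of this commit, would be to centralize it in a small helper so each loop only needs a one-line guard:

```python
import transformers
from packaging import version

# Hypothetical helper (not in the commit): one place to keep the Dbrx/transformers version gate.
def dbrx_incompatible(model_name: str) -> bool:
    """True when the tiny Dbrx checkpoint cannot be loaded by the installed transformers."""
    return model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
        transformers.__version__
    ) < version.parse("4.58.0.dev0")
```

Each loop body could then start with `if dbrx_incompatible(model_name): continue`, and `test_generate` could call `pytest.xfail(...)` under the same condition.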

trl/trainer/sft_trainer.py

Lines changed: 10 additions & 9 deletions
@@ -938,6 +938,7 @@ def tokenize_fn(example, processing_class, dataset_text_field, assistant_only_loss
         prompt_ids = processing_class.apply_chat_template(
             example["prompt"],
             tokenize=True,
+            add_generation_prompt=True,
             tools=example.get("tools"),
             **example.get("chat_template_kwargs", {}),
         )
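Passing `add_generation_prompt=True` makes the tokenized prompt end with the assistant header, so the prompt ids form a prefix of the tokenized prompt + completion, which is what the completion-mask construction in the next hunk relies on. A minimal sketch of the idea; the model and messages are illustrative, and whether the prefix property holds depends on the tokenizer's chat template:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # illustrative chat model

prompt = [{"role": "user", "content": "What color is the sky?"}]
completion = [{"role": "assistant", "content": "It is blue."}]

# With add_generation_prompt=True the prompt tokenization stops right where the
# assistant turn begins, so it should be a prefix of the full conversation's tokens.
prompt_ids = tok.apply_chat_template(prompt, tokenize=True, add_generation_prompt=True)
prompt_completion_ids = tok.apply_chat_template(prompt + completion, tokenize=True)

print(prompt_completion_ids[: len(prompt_ids)] == prompt_ids)  # expected True for well-behaved templates
completion_mask = [0] * len(prompt_ids) + [1] * (len(prompt_completion_ids) - len(prompt_ids))
```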
@@ -975,7 +976,7 @@
                 "token handling. Verify that the tokenizer is processing text consistently."
             )
 
-        # Create a completion mask
+        # Create completion mask
         completion_mask = [0] * len(prompt_ids) + [1] * (len(prompt_completion_ids) - len(prompt_ids))
         output["input_ids"] = prompt_completion_ids
         output["completion_mask"] = completion_mask
@@ -995,17 +996,17 @@
             # Fix transformers inconsistency: for VLMs, apply_chat_template returns lists of lists
             # even for single examples, while for LLMs it returns lists of ints.
             processed = {k: v[0] if isinstance(v[0], list) else v for k, v in processed.items()}
-            if "assistant_masks" in processed and 1 not in processed["assistant_masks"]:
-                raise RuntimeError(
-                    "You're using `assistant_only_loss=True`, but at least one example has no "
-                    "assistant tokens. This usually means the tokenizer's chat template doesn't "
-                    "generate assistant masks — it may be missing the `{% generation %}` keyword. Please "
-                    "check the template and ensure it's correctly configured to support assistant "
-                    "masking."
-                )
             output = {k: processed[k] for k in ("input_ids", "assistant_masks") if k in processed}
         else:
             output = {"input_ids": processing_class(text=example[dataset_text_field])["input_ids"]}
+
+        if "assistant_masks" in output and 1 not in output["assistant_masks"]:
+            raise RuntimeError(
+                "You're using `assistant_only_loss=True`, but at least one example has no assistant "
+                "tokens. This usually means the tokenizer's chat template doesn't generate assistant "
+                "masks — it may be missing the `{% generation %}` keyword. Please check the template and "
+                "ensure it's correctly configured to support assistant masking."
+            )
         return output
 
     dataset = dataset.map(
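The relocated check fires when `assistant_only_loss=True` but the tokenizer's chat template never marks assistant tokens. For reference, the mask it inspects comes from `apply_chat_template` with `return_assistant_tokens_mask=True`, which only produces 1s if the template wraps assistant content in `{% generation %} ... {% endgeneration %}`. A minimal sketch (the model name is illustrative; if its template lacks the keyword, the mask comes back all zeros, which is exactly the situation the RuntimeError reports):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # illustrative

messages = [
    {"role": "user", "content": "Hi!"},
    {"role": "assistant", "content": "Hello, how can I help?"},
]

out = tok.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_assistant_tokens_mask=True,
)

# 1 marks assistant tokens, 0 everything else; no 1s at all means the chat template
# lacks {% generation %} markers and assistant_only_loss cannot work.
print(out["assistant_masks"])
```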
