
Commit ad554c6

Merge branch 'support-seq-cls-clone-chat' into reward-refactor
2 parents 95fe6b8 + 424d50d commit ad554c6


62 files changed (+2402 −1966 lines)

Makefile

Lines changed: 1 addition & 1 deletion

@@ -6,7 +6,7 @@ ACCELERATE_CONFIG_PATH = `pwd`/examples/accelerate_configs
 COMMAND_FILES_PATH = `pwd`/commands
 
 test:
-	pytest -n auto -m "not slow and not low-priority" -s -v --reruns 5 --reruns-delay 1 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504||not less than or equal to 0.01)' tests/
+	pytest -n auto -m "not slow and not low_priority" -s -v --reruns 5 --reruns-delay 1 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504||not less than or equal to 0.01)' tests/
 
 precommit:
 	python scripts/add_copyrights.py

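The Makefile change above renames the deselected marker from `low-priority` to `low_priority`, which makes it usable with pytest's `@pytest.mark.<name>` attribute syntax (hyphenated names cannot be applied that way). A minimal sketch of how such a marker is typically registered and applied; the registration snippet and test below are illustrative, not part of this commit:

```python
# Illustrative only: how a "low_priority" marker is usually declared and used with pytest.
import pytest

# Marker registration normally lives in pyproject.toml or pytest.ini, e.g.:
#   [tool.pytest.ini_options]
#   markers = ["slow: slow tests", "low_priority: low-priority tests"]

@pytest.mark.low_priority
def test_rarely_exercised_path():
    assert sum([1, 2, 3]) == 6

# `pytest -m "not slow and not low_priority"` then deselects any test carrying either marker.
```
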
docs/source/community_tutorials.md

Lines changed: 4 additions & 7 deletions

@@ -29,13 +29,10 @@ Community tutorials are made by active members of the Hugging Face community who
 <details>
 <summary>⚠️ Deprecated features notice for "How to fine-tune a smol-LM with Hugging Face, TRL, and the smoltalk Dataset" (click to expand)</summary>
 
-<Tip warning={true}>
-
-The tutorial uses two deprecated features:
-- `SFTTrainer(..., tokenizer=tokenizer)`: Use `SFTTrainer(..., processing_class=tokenizer)` instead, or simply omit it (it will be inferred from the model).
-- `setup_chat_format(model, tokenizer)`: Use `SFTConfig(..., chat_template_path="Qwen/Qwen3-0.6B")`, where `chat_template_path` specifies the model whose chat template you want to copy.
-
-</Tip>
+> [!WARNING]
+> The tutorial uses two deprecated features:
+> - `SFTTrainer(..., tokenizer=tokenizer)`: Use `SFTTrainer(..., processing_class=tokenizer)` instead, or simply omit it (it will be inferred from the model).
+> - `setup_chat_format(model, tokenizer)`: Use `SFTConfig(..., chat_template_path="Qwen/Qwen3-0.6B")`, where `chat_template_path` specifies the model whose chat template you want to copy.
 
 </details>
 

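To make the replacement pattern in that notice concrete, here is a hedged sketch of the non-deprecated usage. `chat_template_path` and the omission of `processing_class` come from the notice itself; the dataset, model, and output directory names are assumptions for illustration:

```python
# Sketch only: dataset, model, and output_dir below are placeholders, not from the tutorial.
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("HuggingFaceTB/smoltalk", "everyday-conversations", split="train")  # assumed dataset

training_args = SFTConfig(
    output_dir="smol-sft",                 # assumed output path
    chat_template_path="Qwen/Qwen3-0.6B",  # copies this model's chat template (replaces setup_chat_format)
)

trainer = SFTTrainer(
    model="HuggingFaceTB/SmolLM2-135M",    # assumed base model
    args=training_args,
    train_dataset=dataset,
    # processing_class is omitted; the tokenizer is inferred from the model
)
trainer.train()
```
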
docs/source/dataset_formats.md

Lines changed: 48 additions & 66 deletions

@@ -289,31 +289,28 @@ prompt_only_example = {"prompt": [{"role": "user", "content": "What color is the
 
 For examples of prompt-only datasets, refer to the [Prompt-only datasets collection](https://huggingface.co/collections/trl-lib/prompt-only-datasets-677ea25245d20252cea00368).
 
-<Tip>
-
-While both the prompt-only and language modeling types are similar, they differ in how the input is handled. In the prompt-only type, the prompt represents a partial input that expects the model to complete or continue, while in the language modeling type, the input is treated as a complete sentence or sequence. These two types are processed differently by TRL. Below is an example showing the difference in the output of the `apply_chat_template` function for each type:
-
-```python
-from transformers import AutoTokenizer
-from trl import apply_chat_template
-
-tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")
-
-# Example for prompt-only type
-prompt_only_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}]}
-apply_chat_template(prompt_only_example, tokenizer)
-# Output: {'prompt': '<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n'}
-
-# Example for language modeling type
-lm_example = {"messages": [{"role": "user", "content": "What color is the sky?"}]}
-apply_chat_template(lm_example, tokenizer)
-# Output: {'text': '<|user|>\nWhat color is the sky?<|end|>\n<|endoftext|>'}
-```
-
-- The prompt-only output includes a `'<|assistant|>\n'`, indicating the beginning of the assistant’s turn and expecting the model to generate a completion.
-- In contrast, the language modeling output treats the input as a complete sequence and terminates it with `'<|endoftext|>'`, signaling the end of the text and not expecting any additional content.
-
-</Tip>
+> [!TIP]
+> While both the prompt-only and language modeling types are similar, they differ in how the input is handled. In the prompt-only type, the prompt represents a partial input that expects the model to complete or continue, while in the language modeling type, the input is treated as a complete sentence or sequence. These two types are processed differently by TRL. Below is an example showing the difference in the output of the `apply_chat_template` function for each type:
+>
+> ```python
+> from transformers import AutoTokenizer
+> from trl import apply_chat_template
+>
+> tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")
+>
+> # Example for prompt-only type
+> prompt_only_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}]}
+> apply_chat_template(prompt_only_example, tokenizer)
+> # Output: {'prompt': '<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n'}
+>
+> # Example for language modeling type
+> lm_example = {"messages": [{"role": "user", "content": "What color is the sky?"}]}
+> apply_chat_template(lm_example, tokenizer)
+> # Output: {'text': '<|user|>\nWhat color is the sky?<|end|>\n<|endoftext|>'}
+> ```
+>
+> - The prompt-only output includes a `'<|assistant|>\n'`, indicating the beginning of the assistant’s turn and expecting the model to generate a completion.
+> - In contrast, the language modeling output treats the input as a complete sequence and terminates it with `'<|endoftext|>'`, signaling the end of the text and not expecting any additional content.
 
 #### Prompt-completion
 
@@ -408,12 +405,9 @@ Choosing the right dataset type depends on the task you are working on and the s
 | [`SFTTrainer`] | [Language modeling](#language-modeling) or [Prompt-completion](#prompt-completion) |
 | [`XPOTrainer`] | [Prompt-only](#prompt-only) |
 
-<Tip>
-
-TRL trainers only support standard dataset formats, [for now](https://github.com/huggingface/trl/issues/2071). If you have a conversational dataset, you must first convert it into a standard format.
-For more information on how to work with conversational datasets, refer to the [Working with conversational datasets in TRL](#working-with-conversational-datasets-in-trl) section.
-
-</Tip>
+> [!TIP]
+> TRL trainers only support standard dataset formats, [for now](https://github.com/huggingface/trl/issues/2071). If you have a conversational dataset, you must first convert it into a standard format.
+> For more information on how to work with conversational datasets, refer to the [Working with conversational datasets in TRL](#working-with-conversational-datasets-in-trl) section.
 
 ## Working with conversational datasets in TRL
 
@@ -465,27 +459,21 @@ dataset = dataset.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer})
 # 'completion': ['It is blue.<|end|>\n<|endoftext|>', 'In the sky.<|end|>\n<|endoftext|>']}
 ```
 
-<Tip warning={true}>
-
-We recommend using the [`apply_chat_template`] function instead of calling `tokenizer.apply_chat_template` directly. Handling chat templates for non-language modeling datasets can be tricky and may result in errors, such as mistakenly placing a system prompt in the middle of a conversation.
-For additional examples, see [#1930 (comment)](https://github.com/huggingface/trl/pull/1930#issuecomment-2292908614). The [`apply_chat_template`] is designed to handle these intricacies and ensure the correct application of chat templates for various tasks.
-
-</Tip>
-
-<Tip warning={true}>
-
-It's important to note that chat templates are model-specific. For example, if you use the chat template from [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) with the above example, you get a different output:
-
-```python
-apply_chat_template(example, AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct"))
-# Output:
-# {'prompt': '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n',
-# 'completion': 'It is blue.<|im_end|>\n'}
-```
-
-Always use the chat template associated with the model you're working with. Using the wrong template can lead to inaccurate or unexpected results.
-
-</Tip>
+> [!WARNING]
+> We recommend using the [`apply_chat_template`] function instead of calling `tokenizer.apply_chat_template` directly. Handling chat templates for non-language modeling datasets can be tricky and may result in errors, such as mistakenly placing a system prompt in the middle of a conversation.
+> For additional examples, see [#1930 (comment)](https://github.com/huggingface/trl/pull/1930#issuecomment-2292908614). The [`apply_chat_template`] is designed to handle these intricacies and ensure the correct application of chat templates for various tasks.
+
+> [!WARNING]
+> It's important to note that chat templates are model-specific. For example, if you use the chat template from [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) with the above example, you get a different output:
+>
+> ```python
+> apply_chat_template(example, AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct"))
+> # Output:
+> # {'prompt': '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n',
+> # 'completion': 'It is blue.<|im_end|>\n'}
+> ```
+>
+> Always use the chat template associated with the model you're working with. Using the wrong template can lead to inaccurate or unexpected results.
 
 ## Using any dataset with TRL: preprocessing and conversion
 
@@ -715,13 +703,10 @@ dataset = unpair_preference_dataset(dataset)
 'label': True}
 ```
 
-<Tip warning={true}>
-
-Keep in mind that the `"chosen"` and `"rejected"` completions in a preference dataset can be both good or bad.
-Before applying [`unpair_preference_dataset`], please ensure that all `"chosen"` completions can be labeled as good and all `"rejected"` completions as bad.
-This can be ensured by checking absolute rating of each completion, e.g. from a reward model.
-
-</Tip>
+> [!WARNING]
+> Keep in mind that the `"chosen"` and `"rejected"` completions in a preference dataset can be both good or bad.
+> Before applying [`unpair_preference_dataset`], please ensure that all `"chosen"` completions can be labeled as good and all `"rejected"` completions as bad.
+> This can be ensured by checking absolute rating of each completion, e.g. from a reward model.
 
 ### From preference to language modeling dataset
 
@@ -856,13 +841,10 @@ dataset = unpair_preference_dataset(dataset)
 'label': True}
 ```
 
-<Tip warning={true}>
-
-Keep in mind that the `"chosen"` and `"rejected"` completions in a preference dataset can be both good or bad.
-Before applying [`unpair_preference_dataset`], please ensure that all `"chosen"` completions can be labeled as good and all `"rejected"` completions as bad.
-This can be ensured by checking absolute rating of each completion, e.g. from a reward model.
-
-</Tip>
+> [!WARNING]
+> Keep in mind that the `"chosen"` and `"rejected"` completions in a preference dataset can be both good or bad.
+> Before applying [`unpair_preference_dataset`], please ensure that all `"chosen"` completions can be labeled as good and all `"rejected"` completions as bad.
+> This can be ensured by checking absolute rating of each completion, e.g. from a reward model.
 
 ### From unpaired preference to language modeling dataset
 

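The warnings above suggest checking an absolute rating of each completion, e.g. from a reward model, before calling `unpair_preference_dataset`. A hedged sketch of one way to do that for a standard (string-based) preference dataset; the reward model name, threshold, and helper functions are assumptions for illustration, not part of the docs being changed:

```python
# Sketch only: keep a pair only if "chosen" actually scores as good and "rejected" as bad.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import unpair_preference_dataset

reward_model_id = "trl-lib/Qwen2-0.5B-Reward"  # assumed reward model
reward_model = AutoModelForSequenceClassification.from_pretrained(reward_model_id, num_labels=1)
reward_tokenizer = AutoTokenizer.from_pretrained(reward_model_id)

def reward_score(prompt: str, completion: str) -> float:
    # Score the concatenated prompt + completion with the reward head
    inputs = reward_tokenizer(prompt + completion, return_tensors="pt", truncation=True)
    with torch.no_grad():
        return reward_model(**inputs).logits[0, 0].item()

def labels_are_consistent(example, threshold: float = 0.0) -> bool:
    return (
        reward_score(example["prompt"], example["chosen"]) >= threshold
        and reward_score(example["prompt"], example["rejected"]) < threshold
    )

# Filter out pairs whose absolute ratings contradict the chosen/rejected labels, then unpair
dataset = dataset.filter(labels_are_consistent)
dataset = unpair_preference_dataset(dataset)
```
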
docs/source/deepspeed_integration.md

Lines changed: 2 additions & 5 deletions

@@ -1,10 +1,7 @@
 # DeepSpeed Integration
 
-<Tip warning={true}>
-
-Section under construction. Feel free to contribute!
-
-</Tip>
+> [!WARNING]
+> Section under construction. Feel free to contribute!
 
 TRL supports training with DeepSpeed, a library that implements advanced training optimization techniques. These include optimizer state partitioning, offloading, gradient partitioning, and more.
 

docs/source/distributing_training.md

Lines changed: 8 additions & 15 deletions

@@ -1,8 +1,7 @@
 # Distributing Training
 
-<Tip warning={true}>
-Section under construction. Feel free to contribute!
-</Tip>
+> [!WARNING]
+> Section under construction. Feel free to contribute!
 
 ## Multi-GPU Training with TRL
 
@@ -49,11 +48,8 @@ Example, these configurations are equivalent, and should yield the same results:
 | 1 | 4 | 8 | Lower memory usage, slower training |
 | 8 | 4 | 1 | Multi-GPU to get the best of both worlds |
 
-<Tip>
-
-Having one model per GPU can lead to high memory usage, which may not be feasible for large models or low-memory GPUs. In such cases, you can leverage [DeepSpeed](https://github.com/deepspeedai/DeepSpeed), which provides optimizations like model sharding, Zero Redundancy Optimizer, mixed precision training, and offloading to CPU or NVMe. Check out our [DeepSpeed Integration](deepspeed_integration) guide for more details.
-
-</Tip>
+> [!TIP]
+> Having one model per GPU can lead to high memory usage, which may not be feasible for large models or low-memory GPUs. In such cases, you can leverage [DeepSpeed](https://github.com/deepspeedai/DeepSpeed), which provides optimizations like model sharding, Zero Redundancy Optimizer, mixed precision training, and offloading to CPU or NVMe. Check out our [DeepSpeed Integration](deepspeed_integration) guide for more details.
 
 ## Context Parallelism
 
@@ -176,13 +172,10 @@ These results show that **Context Parallelism (CP) scales effectively with more
 <img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/context_parallelism_s_it_plot.png" alt="CP seconds/iteration" width="45%"/>
 </div>
 
-<Tip>
-
-Accelerate also supports **N-Dimensional Parallelism (ND-parallelism)**, which enables you to combine different parallelization strategies to efficiently distribute model training across multiple GPUs.
-
-You can learn more and explore configuration examples in the [Accelerate ND-parallelism guide](https://github.com/huggingface/accelerate/blob/main/examples/torch_native_parallelism/README.md#nd-parallelism).
-
-</Tip>
+> [!TIP]
+> Accelerate also supports **N-Dimensional Parallelism (ND-parallelism)**, which enables you to combine different parallelization strategies to efficiently distribute model training across multiple GPUs.
+>
+> You can learn more and explore configuration examples in the [Accelerate ND-parallelism guide](https://github.com/huggingface/accelerate/blob/main/examples/torch_native_parallelism/README.md#nd-parallelism).
 
 
 **Further Reading on Context Parallelism**

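For the equivalent-configurations table in the distributed-training changes above, a hedged sketch of the two rows as trainer configs, assuming the columns are number of GPUs, `per_device_train_batch_size`, and `gradient_accumulation_steps`; the config class, output directories, and launch command are illustrative:

```python
# Sketch only: both setups target the same effective batch size of 1 * 4 * 8 = 8 * 4 * 1 = 32,
# assuming the table columns are (GPUs, per_device_train_batch_size, gradient_accumulation_steps).
from trl import SFTConfig

# Row "| 1 | 4 | 8 |": single GPU, micro-batch 4, accumulate over 8 steps
single_gpu_args = SFTConfig(
    output_dir="sft-single-gpu",   # illustrative
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
)

# Row "| 8 | 4 | 1 |": eight GPUs, micro-batch 4, no accumulation
multi_gpu_args = SFTConfig(
    output_dir="sft-multi-gpu",    # illustrative
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
)
# The multi-GPU row would be launched with something like:
#   accelerate launch --num_processes 8 train.py
```
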
docs/source/dpo_trainer.md

Lines changed: 4 additions & 0 deletions

@@ -295,3 +295,7 @@ dpo_trainer = DPOTrainer(
 ## DataCollatorForPreference
 
 [[autodoc]] trainer.dpo_trainer.DataCollatorForPreference
+
+## FDivergenceType
+
+[[autodoc]] trainer.dpo_trainer.FDivergenceType

docs/source/experimental.md

Lines changed: 56 additions & 8 deletions

@@ -2,11 +2,8 @@
 
 The `trl.experimental` namespace provides a minimal, clearly separated space for fast iteration on new ideas.
 
-<Tip warning={true}>
-
-**Stability contract:** Anything under `trl.experimental` may change or be removed in *any* release (including patch versions) without prior deprecation. Do not rely on these APIs for production workloads.
-
-</Tip>
+> [!WARNING]
+> **Stability contract:** Anything under `trl.experimental` may change or be removed in *any* release (including patch versions) without prior deprecation. Do not rely on these APIs for production workloads.
 
 ## Current Experimental Features
 
@@ -66,7 +63,7 @@ class GroupFilter:
         return group_scores
 
 training_args = GFPOConfig(
-    output_dir="Qwen3-0.6B-GFPO"
+    output_dir="Qwen3-0.6B-GFPO",
    per_device_train_batch_size=4,
     num_remains_in_group=2,
     bf16=True,
@@ -81,10 +78,61 @@ trainer = GFPOTrainer(
 trainer.train()
 ```
 
-## Usage
+### GSPO-token
+
+In the paper [Group Sequence Policy Optimization](https://huggingface.co/papers/2507.18071), the authors propose a token-level variant of GSPO, called GSPO-token. To use GSPO-token, you can use the `GRPOTrainer` class in `trl.experimental.gspo_token`.
+
+```python
+from trl.experimental.gspo_token import GRPOTrainer
+from trl import GRPOConfig
+
+training_args = GRPOConfig(
+    importance_sampling_level="sequence_token",
+    ...
+)
+```
+
+> [!WARNING]
+> To leverage GSPO-token, the user will need to provide the per-token advantage \\( \hat{A_{i,t}} \\) for each token \\( t \\) in the sequence \\( i \\) (i.e., make \\( \hat{A_{i,t}} \\) vary with \\( t \\), which isn't the case here, where \\( \hat{A_{i,t}}=\hat{A_{i}} \\)). Otherwise, the GSPO-token gradient is simply equivalent to the original GSPO implementation.
+
+### GRPO With Replay Buffer
+
+This experimental trainer trains a model with GRPO but replaces groups (and corresponding completions) that have zero reward standard deviation with high-reward, high-variance groups that have been used to train the model in prior batches.
+
+#### Usage
 
 ```python
-from trl.experimental.new_trainer import NewTrainer
+from trl.experimental.grpo_with_replay_buffer import GRPOWithReplayBufferTrainer
+from datasets import load_dataset
+
+dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+
+# Guarantee that some rewards have 0 std
+def custom_reward_func(completions, **kwargs):
+    if torch.rand(1).item() < 0.25:
+        return [0] * len(completions)  # simulate some None rewards
+    else:
+        return torch.rand(len(completions)).tolist()
+
+training_args = GRPOWithReplayBufferConfig(
+    output_dir=self.tmp_dir,
+    learning_rate=1e-4,
+    per_device_train_batch_size=4,
+    num_generations=4,
+    max_completion_length=8,
+    replay_buffer_size=8,
+    report_to="none",
+)
+trainer = GRPOTrainer(
+    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+    reward_funcs=[custom_reward_func],
+    args=training_args,
+    train_dataset=dataset,
+)
+
+previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+trainer.train()
 ```
 
 To silence the runtime notice:

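As committed, the replay-buffer usage snippet above is not self-contained: it references `self.tmp_dir`, uses `torch` without importing it, and builds a plain `GRPOTrainer` even though it imports `GRPOWithReplayBufferTrainer`. A hedged, runnable variant might look like the sketch below; the output directory is an assumption, the config import path assumes `GRPOWithReplayBufferConfig` lives alongside the trainer in `trl.experimental.grpo_with_replay_buffer`, and the trainer is swapped to the imported `GRPOWithReplayBufferTrainer`:

```python
# Self-contained variant of the committed replay-buffer example; output_dir and the config
# import path are assumptions, everything else mirrors the snippet added in this commit.
import torch
from datasets import load_dataset
from trl.experimental.grpo_with_replay_buffer import (
    GRPOWithReplayBufferConfig,
    GRPOWithReplayBufferTrainer,
)

dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

def custom_reward_func(completions, **kwargs):
    # Guarantee that some groups have zero reward standard deviation
    if torch.rand(1).item() < 0.25:
        return [0] * len(completions)
    return torch.rand(len(completions)).tolist()

training_args = GRPOWithReplayBufferConfig(
    output_dir="grpo-replay-buffer",  # assumption: any writable directory
    learning_rate=1e-4,
    per_device_train_batch_size=4,
    num_generations=4,
    max_completion_length=8,
    replay_buffer_size=8,
    report_to="none",
)
trainer = GRPOWithReplayBufferTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    reward_funcs=[custom_reward_func],
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```
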