Skip to content
Open
24 changes: 24 additions & 0 deletions docs/source/grpo_trainer.md
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,8 @@ Reward functions can be either synchronous Python callables or asynchronous `asy
- `completions` (contains the generated completions),
- `completion_ids` (contains the tokenized completions),
- `trainer_state` ([`~transformers.TrainerState`]): The current state of the trainer. This can be used to implement dynamic reward functions, such as curriculum learning, where the reward is adjusted based on the training progress.
- `log_extra`: a callable `log_extra(column: str, values: list)` to add extra columns to the completions table. See Example 6.
- `log_metric`: a callable `log_metric(name: str, value: float)` to log scalar metrics as plots alongside `kl`, `entropy`, etc. See Example 6.
- All column names (except `prompt`) that the dataset may have. For example, if the dataset contains a column named `ground_truth`, the function will be called with `ground_truth` as a keyword argument.

The easiest way to comply with this requirement is to use `**kwargs` in the function signature.
Expand Down Expand Up @@ -558,6 +560,28 @@ async def async_reward_func(prompts, completions, **kwargs):
return [1.0 if completion else 0.0 for completion in completions]
```

#### Example 6: Logging extra columns and metrics

Below is an example of a reward function that logs extra columns to the completions table and scalar metrics as plots.

```python
import re

def reward_func(completions, ground_truth, log_extra=None, log_metric=None, **kwargs):
    # Pull the boxed answer (if any) out of every completion.
    matches = (re.search(r"\\boxed\{(.*?)\}", completion) for completion in completions)
    extracted = [match.group(1) if match else None for match in matches]
    # Reward 1.0 for an exact match with the reference answer, 0.0 otherwise.
    rewards = [float(answer == truth) for answer, truth in zip(extracted, ground_truth)]

    if log_extra:
        # Attach two extra columns to the completions table.
        log_extra("golden_answer", list(ground_truth))
        log_extra("extracted_answer", [answer or "[none]" for answer in extracted])

    if log_metric:
        # Report batch accuracy as a scalar metric.
        log_metric("accuracy", sum(rewards) / len(rewards))

    return rewards
```

#### Passing the reward function to the trainer

To use your custom reward function, pass it to the [`GRPOTrainer`] as follows:
Expand Down
2 changes: 2 additions & 0 deletions docs/source/rloo_trainer.md
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,8 @@ Reward functions can be either synchronous Python callables or asynchronous `asy
- `completions` (contains the generated completions),
- `completion_ids` (contains the tokenized completions),
- `trainer_state` ([`~transformers.TrainerState`]): The current state of the trainer. This can be used to implement dynamic reward functions, such as curriculum learning, where the reward is adjusted based on the training progress.
- `log_extra`: a callable `log_extra(column: str, values: list)` to add extra columns to the completions table.
- `log_metric`: a callable `log_metric(name: str, value: float)` to log scalar metrics as plots alongside `kl`, `entropy`, etc.
- All column names (except `prompt`) that the dataset may have. For example, if the dataset contains a column named `ground_truth`, the function will be called with `ground_truth` as a keyword argument.

The easiest way to comply with this requirement is to use `**kwargs` in the function signature.
Expand Down
53 changes: 53 additions & 0 deletions tests/test_grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1750,6 +1750,59 @@ def reward_func(completions, **kwargs):
)
trainer.train()

def test_training_reward_func_with_log_extra(self):
    # A reward function that records an extra per-completion column through the
    # `log_extra` callable injected by the trainer.
    dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

    def reward_with_extra_column(completions, **kwargs):
        log_extra = kwargs.get("log_extra")
        assert log_extra is not None
        log_extra("test_column", [c[:5] for c in completions])
        return [float(len(c)) for c in completions]

    training_args = GRPOConfig(
        output_dir=self.tmp_dir,
        per_device_train_batch_size=2,
        num_generations=2,
        max_completion_length=8,
        report_to="none",
        log_completions=True,  # needed so the completions table (and extra columns) is populated
    )
    trainer = GRPOTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        reward_funcs=reward_with_extra_column,
        args=training_args,
        train_dataset=dataset,
    )
    trainer.train()
    # The extra column must have been collected into the trainer's log buffer.
    assert "test_column" in trainer._logs["extra"]

def test_training_reward_func_with_log_metric(self):
    # A reward function that emits a scalar metric through the `log_metric`
    # callable injected by the trainer.
    dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

    def reward_with_metric(completions, **kwargs):
        log_metric = kwargs.get("log_metric")
        assert log_metric is not None
        log_metric("custom_accuracy", 0.75)
        return [float(len(c)) for c in completions]

    training_args = GRPOConfig(
        output_dir=self.tmp_dir,
        per_device_train_batch_size=2,
        num_generations=2,
        max_completion_length=8,
        report_to="none",
    )
    trainer = GRPOTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        reward_funcs=reward_with_metric,
        args=training_args,
        train_dataset=dataset,
    )
    trainer.train()
    # log_metric appends to _metrics, which gets averaged and merged into log_history.
    logged_keys = {key for entry in trainer.state.log_history for key in entry}
    assert "custom_accuracy" in logged_keys

def test_prepare_input_called_with_correct_data(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
training_args = GRPOConfig(
Expand Down
52 changes: 52 additions & 0 deletions tests/test_rloo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1169,6 +1169,58 @@ def reward_func(completions, **kwargs):
)
trainer.train()

def test_training_reward_func_with_log_extra(self):
    # A reward function that records an extra per-completion column through the
    # `log_extra` callable injected by the trainer.
    dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

    def reward_with_extra_column(completions, **kwargs):
        log_extra = kwargs.get("log_extra")
        assert log_extra is not None
        log_extra("test_column", [c[:5] for c in completions])
        return [float(len(c)) for c in completions]

    training_args = RLOOConfig(
        output_dir=self.tmp_dir,
        per_device_train_batch_size=2,
        num_generations=2,
        max_completion_length=8,
        report_to="none",
        log_completions=True,  # needed so the completions table (and extra columns) is populated
    )
    trainer = RLOOTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        reward_funcs=reward_with_extra_column,
        args=training_args,
        train_dataset=dataset,
    )
    trainer.train()
    # The extra column must have been collected into the trainer's log buffer.
    assert "test_column" in trainer._logs["extra"]

def test_training_reward_func_with_log_metric(self):
    # A reward function that emits a scalar metric through the `log_metric`
    # callable injected by the trainer.
    dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

    def reward_with_metric(completions, **kwargs):
        log_metric = kwargs.get("log_metric")
        assert log_metric is not None
        log_metric("custom_accuracy", 0.75)
        return [float(len(c)) for c in completions]

    training_args = RLOOConfig(
        output_dir=self.tmp_dir,
        per_device_train_batch_size=2,
        num_generations=2,
        max_completion_length=8,
        report_to="none",
    )
    trainer = RLOOTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        reward_funcs=reward_with_metric,
        args=training_args,
        train_dataset=dataset,
    )
    trainer.train()
    # Metrics recorded via log_metric are averaged and merged into log_history.
    logged_keys = {key for entry in trainer.state.log_history for key in entry}
    assert "custom_accuracy" in logged_keys

def test_prepare_input_called_with_correct_data(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
training_args = RLOOConfig(
Expand Down
26 changes: 26 additions & 0 deletions trl/trainer/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,32 @@ class _BaseTrainer(Trainer):
_paper = {}
_template_file = None

def log_completion_extra(self, column: str, values: list):
    """
    Append values for an extra column of the completions table.

    Reward functions receive this method through the `log_extra` keyword argument and can call it to attach
    per-sample data (one entry per completion) to the logged completions table.

    Args:
        column (`str`):
            Name of the column to add.
        values (`list`):
            Column values, one per sample in the current batch.
    """
    extra_columns = self._logs["extra"]
    extra_columns[column].extend(values)

def log_metric(self, name: str, value: float):
    """
    Record a scalar metric emitted by a reward function.

    Exposed to reward functions through the `log_metric` keyword argument. Recorded values are averaged over
    each logging step and reported alongside built-in metrics such as `kl` and `entropy`.

    Args:
        name (`str`):
            Name of the metric.
        value (`float`):
            Scalar value for the current batch.
    """
    # Route the value to the train or eval bucket depending on the model's mode.
    if self.model.training:
        self._metrics["train"][name].append(value)
    else:
        self._metrics["eval"][name].append(value)

def create_model_card(
self,
model_name: str | None = None,
Expand Down
10 changes: 9 additions & 1 deletion trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,7 @@ def cast_outputs_to_original_dtype(module, args, output):
"completion": deque(maxlen=args.generation_batch_size),
"rewards": defaultdict(lambda: deque(maxlen=args.generation_batch_size)),
"advantages": deque(maxlen=args.generation_batch_size),
"extra": defaultdict(lambda: deque(maxlen=args.generation_batch_size)),
}

# Ensure each process receives a unique seed to prevent duplicate completions when generating with
Expand Down Expand Up @@ -1140,6 +1141,12 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list):
# This allows for dynamic reward shaping based on training progress.
reward_kwargs["trainer_state"] = self.state

# Allow reward functions to log extra columns to the completions table.
reward_kwargs["log_extra"] = self.log_completion_extra

# Allow reward functions to log additional scalar metrics.
reward_kwargs["log_metric"] = self.log_metric

async_funcs_info = [] # async custom functions for asyncio.gather

for i, (reward_func, reward_processing_class, reward_func_name) in enumerate(
Expand Down Expand Up @@ -1198,7 +1205,7 @@ async def _run_async_funcs():
if torch.isnan(rewards_per_func).all(dim=1).any():
nan_row_idx = torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0]
row_reward_kwargs = {
key: value[nan_row_idx] for key, value in reward_kwargs.items() if key != "trainer_state"
key: value[nan_row_idx] for key, value in reward_kwargs.items() if key not in ("trainer_state", "log_extra", "log_metric")
}
row_reward_kwargs["prompt"] = prompts[nan_row_idx]
row_reward_kwargs["completion"] = completions[nan_row_idx]
Expand Down Expand Up @@ -2317,6 +2324,7 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
"prompt": self._logs["prompt"],
"completion": self._logs["completion"],
**self._logs["rewards"],
**self._logs["extra"],
"advantage": self._logs["advantages"],
}

Expand Down
8 changes: 8 additions & 0 deletions trl/trainer/rloo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,7 @@ def __init__(
"completion": deque(maxlen=args.generation_batch_size),
"rewards": defaultdict(lambda: deque(maxlen=args.generation_batch_size)),
"advantages": deque(maxlen=args.generation_batch_size),
"extra": defaultdict(lambda: deque(maxlen=args.generation_batch_size)),
}

# Ensure each process receives a unique seed to prevent duplicate completions when generating with
Expand Down Expand Up @@ -814,6 +815,12 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list):
# This allows for dynamic reward shaping based on training progress.
reward_kwargs["trainer_state"] = self.state

# Allow reward functions to log extra columns to the completions table.
reward_kwargs["log_extra"] = self.log_completion_extra

# Allow reward functions to log additional scalar metrics.
reward_kwargs["log_metric"] = self.log_metric

async_funcs_info = [] # async custom functions for asyncio.gather

for i, (reward_func, reward_processing_class, reward_func_name) in enumerate(
Expand Down Expand Up @@ -1414,6 +1421,7 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
"prompt": self._logs["prompt"],
"completion": self._logs["completion"],
**self._logs["rewards"],
**self._logs["extra"],
"advantage": self._logs["advantages"],
}

Expand Down