24 changes: 24 additions & 0 deletions docs/source/grpo_trainer.md
@@ -376,6 +376,8 @@ Reward functions can be either synchronous Python callables or asynchronous `asy
- `completions` (contains the generated completions),
- `completion_ids` (contains the tokenized completions),
- `trainer_state` ([`~transformers.TrainerState`]): The current state of the trainer. This can be used to implement dynamic reward functions, such as curriculum learning, where the reward is adjusted based on the training progress.
- `log_extra`: a callable `log_extra(column: str, values: list)` to add extra columns to the completions table. See Example 6.
- `log_metric`: a callable `log_metric(name: str, value: float)` to log scalar metrics as plots alongside `kl`, `entropy`, etc. See Example 6.
- All column names (except `prompt`) that the dataset may have. For example, if the dataset contains a column named `ground_truth`, the function will be called with `ground_truth` as a keyword argument.

The easiest way to comply with this requirement is to use `**kwargs` in the function signature.
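A minimal sketch of such a signature (the function and values below are illustrative, not part of TRL):

```python
# Illustrative sketch: a reward function that takes only what it needs and
# absorbs every other documented kwarg (trainer_state, log_extra, log_metric,
# dataset columns, ...) via **kwargs.
def reward_len(completions, **kwargs):
    # Reward shorter completions; all unused kwargs are silently ignored.
    return [-float(len(c)) for c in completions]

# Stand-ins for the keyword arguments the trainer would pass.
rewards = reward_len(
    completions=["short", "a much longer completion"],
    trainer_state=None,
    log_extra=None,
    log_metric=None,
)
print(rewards)
```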
@@ -558,6 +560,28 @@ async def async_reward_func(prompts, completions, **kwargs):
return [1.0 if completion else 0.0 for completion in completions]
```

#### Example 6: Logging extra columns and metrics

Below is an example of a reward function that logs extra columns to the completions table and scalar metrics as plots.

```python
import re

def reward_func(completions, ground_truth, log_extra=None, log_metric=None, **kwargs):
extracted = [re.search(r"\\boxed\{(.*?)\}", c) for c in completions]
extracted = [m.group(1) if m else None for m in extracted]
rewards = [1.0 if e == gt else 0.0 for e, gt in zip(extracted, ground_truth)]

if log_extra:
log_extra("golden_answer", list(ground_truth))
log_extra("extracted_answer", [e or "[none]" for e in extracted])

if log_metric:
log_metric("accuracy", sum(rewards) / len(rewards))

return rewards
```
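Outside the trainer, a function like this can be exercised by stubbing the two callbacks. The sketch below restates Example 6 and uses ad-hoc recorder lambdas in place of the trainer-provided callables (the recorders are not part of TRL):

```python
import re

# Restatement of Example 6's reward function so this sketch is self-contained.
def reward_func(completions, ground_truth, log_extra=None, log_metric=None, **kwargs):
    extracted = [re.search(r"\\boxed\{(.*?)\}", c) for c in completions]
    extracted = [m.group(1) if m else None for m in extracted]
    rewards = [1.0 if e == gt else 0.0 for e, gt in zip(extracted, ground_truth)]

    if log_extra:
        log_extra("golden_answer", list(ground_truth))
        log_extra("extracted_answer", [e or "[none]" for e in extracted])

    if log_metric:
        log_metric("accuracy", sum(rewards) / len(rewards))

    return rewards

# Simple recorders standing in for the callables the trainer would pass.
columns, metrics = {}, {}
rewards = reward_func(
    completions=[r"The answer is \boxed{42}.", "no boxed answer here"],
    ground_truth=["42", "7"],
    log_extra=lambda col, vals: columns.setdefault(col, []).extend(vals),
    log_metric=lambda name, val: metrics.setdefault(name, []).append(val),
)
```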

#### Passing the reward function to the trainer

To use your custom reward function, pass it to the [`GRPOTrainer`] as follows:
24 changes: 24 additions & 0 deletions docs/source/rloo_trainer.md
@@ -300,6 +300,8 @@ Reward functions can be either synchronous Python callables or asynchronous `asy
- `completions` (contains the generated completions),
- `completion_ids` (contains the tokenized completions),
- `trainer_state` ([`~transformers.TrainerState`]): The current state of the trainer. This can be used to implement dynamic reward functions, such as curriculum learning, where the reward is adjusted based on the training progress.
- `log_extra`: a callable `log_extra(column: str, values: list)` to add extra columns to the completions table. See Example 6.
- `log_metric`: a callable `log_metric(name: str, value: float)` to log scalar metrics as plots alongside `kl`, `entropy`, etc. See Example 6.
- All column names (except `prompt`) that the dataset may have. For example, if the dataset contains a column named `ground_truth`, the function will be called with `ground_truth` as a keyword argument.

The easiest way to comply with this requirement is to use `**kwargs` in the function signature.
@@ -482,6 +484,28 @@ async def async_reward_func(prompts, completions, **kwargs):
return [1.0 if completion else 0.0 for completion in completions]
```

#### Example 6: Logging extra columns and metrics

Below is an example of a reward function that logs extra columns to the completions table and scalar metrics as plots.

```python
import re

def reward_func(completions, ground_truth, log_extra=None, log_metric=None, **kwargs):
extracted = [re.search(r"\\boxed\{(.*?)\}", c) for c in completions]
extracted = [m.group(1) if m else None for m in extracted]
rewards = [1.0 if e == gt else 0.0 for e, gt in zip(extracted, ground_truth)]

if log_extra:
log_extra("golden_answer", list(ground_truth))
log_extra("extracted_answer", [e or "[none]" for e in extracted])

if log_metric:
log_metric("accuracy", sum(rewards) / len(rewards))

return rewards
```

#### Passing the reward function to the trainer

To use your custom reward function, pass it to the [`RLOOTrainer`] as follows:
59 changes: 56 additions & 3 deletions tests/test_grpo_trainer.py
@@ -1763,9 +1763,59 @@ def reward_func(completions, **kwargs):

training_args = GRPOConfig(
output_dir=self.tmp_dir,
per_device_train_batch_size=2,
num_generations=2,
max_completion_length=8,
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=8, # reduce the completion length to reduce memory usage
report_to="none",
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
reward_funcs=reward_func,
args=training_args,
train_dataset=dataset,
)
trainer.train()

def test_training_reward_func_with_log_extra(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

def reward_func(completions, **kwargs):
log_extra = kwargs.get("log_extra")
assert log_extra is not None
log_extra("test_column", [completion[:5] for completion in completions])
return [float(len(completion)) for completion in completions]

training_args = GRPOConfig(
output_dir=self.tmp_dir,
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=8, # reduce the completion length to reduce memory usage
report_to="none",
log_completions=True,
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
reward_funcs=reward_func,
args=training_args,
train_dataset=dataset,
)
trainer.train()
assert "test_column" in trainer._logs["extra"]

def test_training_reward_func_with_log_metric(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

def reward_func(completions, **kwargs):
log_metric = kwargs.get("log_metric")
assert log_metric is not None
log_metric("custom_accuracy", 0.75)
return [float(len(completion)) for completion in completions]

training_args = GRPOConfig(
output_dir=self.tmp_dir,
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=8, # reduce the completion length to reduce memory usage
report_to="none",
)
trainer = GRPOTrainer(
@@ -1775,6 +1825,9 @@ def reward_func(completions, **kwargs):
train_dataset=dataset,
)
trainer.train()
# log_metric appends to _metrics, which gets averaged and merged into log_history
logged_keys = {k for entry in trainer.state.log_history for k in entry}
assert "custom_accuracy" in logged_keys

def test_prepare_input_called_with_correct_data(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
59 changes: 56 additions & 3 deletions tests/test_rloo_trainer.py
@@ -1156,9 +1156,59 @@ def reward_func(completions, **kwargs):

training_args = RLOOConfig(
output_dir=self.tmp_dir,
per_device_train_batch_size=2,
num_generations=2,
max_completion_length=8,
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=8, # reduce the completion length to reduce memory usage
report_to="none",
)
trainer = RLOOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
reward_funcs=reward_func,
args=training_args,
train_dataset=dataset,
)
trainer.train()

def test_training_reward_func_with_log_extra(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

def reward_func(completions, **kwargs):
log_extra = kwargs.get("log_extra")
assert log_extra is not None
log_extra("test_column", [completion[:5] for completion in completions])
return [float(len(completion)) for completion in completions]

training_args = RLOOConfig(
output_dir=self.tmp_dir,
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=8, # reduce the completion length to reduce memory usage
report_to="none",
log_completions=True,
)
trainer = RLOOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
reward_funcs=reward_func,
args=training_args,
train_dataset=dataset,
)
trainer.train()
assert "test_column" in trainer._logs["extra"]

def test_training_reward_func_with_log_metric(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

def reward_func(completions, **kwargs):
log_metric = kwargs.get("log_metric")
assert log_metric is not None
log_metric("custom_accuracy", 0.75)
return [float(len(completion)) for completion in completions]

training_args = RLOOConfig(
output_dir=self.tmp_dir,
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=8, # reduce the completion length to reduce memory usage
report_to="none",
)
trainer = RLOOTrainer(
@@ -1168,6 +1218,9 @@ def reward_func(completions, **kwargs):
train_dataset=dataset,
)
trainer.train()
# log_metric appends to _metrics, which gets averaged and merged into log_history
logged_keys = {k for entry in trainer.state.log_history for k in entry}
assert "custom_accuracy" in logged_keys

def test_prepare_input_called_with_correct_data(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
52 changes: 51 additions & 1 deletion trl/trainer/grpo_trainer.py
@@ -700,7 +700,11 @@ def cast_outputs_to_original_dtype(module, args, output):
"completion": deque(maxlen=args.generation_batch_size),
"rewards": defaultdict(lambda: deque(maxlen=args.generation_batch_size)),
"advantages": deque(maxlen=args.generation_batch_size),
"extra": defaultdict(lambda: deque(maxlen=args.generation_batch_size)),
}
# Buffers for user-logged data from reward functions, flushed after gathering
self._pending_extra_logs = defaultdict(list)
self._pending_metrics = defaultdict(list)

# Ensure each process receives a unique seed to prevent duplicate completions when generating with
# transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
@@ -1130,6 +1134,31 @@ def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | Any]) -> di
inputs = self._generate_and_score_completions(generation_batch)
return inputs

def _log_completion_extra(self, column: str, values: list):
"""
Log extra columns to the completions table. Called from reward functions via the `log_extra` kwarg.

Args:
column (`str`):
Name of the column to add.
values (`list`):
Values for the column, one per sample in the batch.
"""
self._pending_extra_logs[column].extend(values)

def _log_metric(self, name: str, value: float):
"""
Log a scalar metric from a reward function. Called via the `log_metric` kwarg. Values are averaged
over each logging step and reported alongside built-in metrics like `kl` and `entropy`.

Args:
name (`str`):
Name of the metric.
value (`float`):
Scalar value for this batch.
"""
self._pending_metrics[name].append(value)

@profiling_decorator
def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list):
device = self.accelerator.device
@@ -1142,6 +1171,12 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list):
# This allows for dynamic reward shaping based on training progress.
reward_kwargs["trainer_state"] = self.state

# Allow reward functions to log extra columns to the completions table.
reward_kwargs["log_extra"] = self._log_completion_extra

# Allow reward functions to log additional scalar metrics.
reward_kwargs["log_metric"] = self._log_metric

async_funcs_info = [] # async custom functions for asyncio.gather

for i, (reward_func, reward_processing_class, reward_func_name) in enumerate(
@@ -1200,7 +1235,9 @@ async def _run_async_funcs():
if torch.isnan(rewards_per_func).all(dim=1).any():
nan_row_idx = torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0]
row_reward_kwargs = {
key: value[nan_row_idx] for key, value in reward_kwargs.items() if key != "trainer_state"
key: value[nan_row_idx]
for key, value in reward_kwargs.items()
if key not in ("trainer_state", "log_extra", "log_metric")
}
row_reward_kwargs["prompt"] = prompts[nan_row_idx]
row_reward_kwargs["completion"] = completions[nan_row_idx]
@@ -1939,6 +1976,18 @@ def _generate_and_score_completions(
self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist())
self._logs["advantages"].extend(all_process_advantages.tolist())

# Flush user-logged extra columns (from log_extra), gathering across processes
for column, values in self._pending_extra_logs.items():
self._logs["extra"][column].extend(gather_object(values))
self._pending_extra_logs.clear()

# Flush user-logged metrics (from log_metric), averaging across processes
for name, values in self._pending_metrics.items():
local_mean = sum(values) / len(values)
global_mean = self.accelerator.gather(torch.tensor(local_mean, device=device)).mean().item()
self._metrics[mode][name].append(global_mean)
> **Review comment — User metrics can silently corrupt built-in metrics (Medium severity):** `log_metric` writes into `self._metrics[mode][name]`, which shares a namespace with built-in metrics such as `"reward"`, `"reward_std"`, `"kl"`, `"entropy"`, and `"frac_reward_zero_std"`. If a user calls `log_metric("reward", 0.75)`, that value is appended to the same list where the built-in reward mean is stored. At logging time, all values in the list are averaged together via `sum(val) / len(val)`, silently corrupting the built-in metric with no warning or validation.
self._pending_metrics.clear()
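One possible mitigation for the reviewer's metric-collision concern would be to validate names before they reach `self._metrics`. A sketch under that assumption — `_RESERVED_METRICS` and `validate_metric_name` are hypothetical names, not TRL code:

```python
# Hypothetical guard, not part of TRL: reject user metric names that collide
# with built-in GRPO metrics before they are appended to self._metrics.
_RESERVED_METRICS = {"reward", "reward_std", "kl", "entropy", "frac_reward_zero_std"}

def validate_metric_name(name: str) -> str:
    """Return `name` unchanged, or raise if it shadows a built-in metric."""
    if name in _RESERVED_METRICS:
        raise ValueError(
            f"log_metric name {name!r} collides with a built-in metric; "
            "choose a distinct name (e.g. prefix it with 'custom/')."
        )
    return name
```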

if images is not None:
self._logs["images"].extend(gather_object(images))

@@ -2319,6 +2368,7 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
"prompt": self._logs["prompt"],
"completion": self._logs["completion"],
**self._logs["rewards"],
**self._logs["extra"],
> **Review comment — Extra columns can overwrite core table columns (Medium severity):** `**self._logs["extra"]` is spread into the completions-table dict after `"step"`, `"prompt"`, `"completion"`, and `**self._logs["rewards"]`, so a user calling `log_extra("prompt", values)` or `log_extra("step", values)` silently overwrites those core columns. Conversely, `log_extra("advantage", values)` is silently discarded because the `"advantage"` key is defined afterward. Nothing in `_log_completion_extra` validates against reserved column names.
"advantage": self._logs["advantages"],
}
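The column-shadowing comment above could be addressed the same way; a sketch with an illustrative `_CORE_COLUMNS` constant (TRL defines no such name):

```python
# Hypothetical guard, not part of TRL: reserved names for the completions table.
_CORE_COLUMNS = {"step", "prompt", "completion", "advantage"}

def validate_extra_column(column: str) -> str:
    """Return `column` unchanged, or raise if it would shadow a core column."""
    if column in _CORE_COLUMNS:
        raise ValueError(
            f"log_extra column {column!r} shadows a core completions-table column."
        )
    return column
```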
