Allow reward functions to log extra columns and scalar metrics#5233

Open
manueldeprada wants to merge 12 commits into huggingface:main from manueldeprada:extra_logs

Conversation


@manueldeprada manueldeprada commented Mar 6, 2026

TLDR

Adds log_extra and log_metric hooks to reward functions in GRPO and RLOO so custom reward functions can log extra completion columns and scalar metrics without accessing private trainer state.

(screenshots: the completions table with the extra columns, and the new scalar-metric plots)

So:

While working with custom reward functions (e.g. for calibration training), I found myself using log_completions=True, which logs samples to trackio (super useful!).

However, I also needed to log additional columns to the completions table (like extracted answers and gold labels) and scalar metrics (like accuracy and ECE) directly from within the reward function. Currently there’s no clean way to do this in GRPO without accessing private state, to the best of my knowledge.

This PR adds two public methods to _BaseTrainer and passes them to reward functions via reward_kwargs:

  • log_extra(column, values): adds extra columns to the completions table (parquet + wandb/trackio)
  • log_metric(name, value): logs scalar metrics through the existing _metrics mechanism, so they show up as plots alongside KL, entropy, etc.

Both are fully backwards compatible: existing reward functions simply absorb them into **kwargs. Passing functions this way is not entirely orthodox, but it is simple and adds very little code. I also added both hooks to RLOO for completeness, since it is structured similarly.

Example usage

```python
def my_reward_fn(completions, answer, log_extra=None, log_metric=None, **kwargs):
    # Extract the model's answer from each completion.
    extracted = [extract_answer(c) for c in completions]
    # Binary reward: 1.0 for an exact match with the gold answer.
    rewards = [1.0 if e == a else 0.0 for e, a in zip(extracted, answer)]

    if log_extra:
        # Extra columns that show up in the logged completions table.
        log_extra("golden_answer", list(answer))
        log_extra("extracted_answer", extracted)

    if log_metric:
        # Scalar metric plotted alongside KL, entropy, etc.
        log_metric("accuracy", sum(rewards) / len(rewards))

    return rewards
```

Docs, tests, and signatures were drafted with Claude, so let me know if you'd be interested in merging this and I can put more effort into polishing the PR.

Shoutout to the berner guys, Andy, @lewtun, Joel, and @lvwerra. Hope everything is going well; it's also fun to find myself using TRL from the other side 🙂


Note

Medium Risk
Touches core training/logging paths for GRPO and RLOO, including distributed gather ordering and metric aggregation; errors could misattribute logged values or break logging, but training behavior is otherwise unchanged.

Overview
Adds two new optional kwargs passed into GRPO/RLOO custom reward functions: log_extra(column, values) for emitting additional per-completion columns into the logged completions table, and log_metric(name, value) for emitting scalar metrics that get averaged and reported alongside built-in training curves.

Implements buffered per-step collection and distributed-safe flushing (sorted keys + gather/mean) in GRPOTrainer and RLOOTrainer, and includes the new extra columns when writing/parquet-logging completions. Updates docs with a new example and extends unit tests to cover both hooks (plus tweaks test configs to reduce memory usage).

Written by Cursor Bugbot for commit 3d5676c. This will update automatically on new commits.

@qgallouedec
Copy link
Member

Thanks @manueldeprada!

I understand the need, there's currently no simple way to log extra metrics from within a reward function without modifying the codebase. The closest workaround today is to use separate reward functions with reward_weights=0 for pure logging, but that doesn't help when the metric is a byproduct of the reward computation and you want to avoid recomputing it.

That said, I'm not fully sold on the approach. A reward function should ideally be passive: it takes inputs, returns a scalar, and has no side effects. Passing a log_metric callback blurs that boundary.

I don't have a better alternative in mind though, and from the user's perspective, this is probably the most intuitive and simple solution. So I'd consider merging it. Open to thoughts from others. @albertvillanova @AmineDiro.
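The workaround mentioned above looks roughly like this. A hypothetical sketch: `reward_weights` is a real `GRPOConfig` parameter, but the function names and extraction logic here are illustrative.

```python
# A logging-only "reward" function: its output is recorded like any other
# reward, but a weight of 0 keeps it from influencing training.
def accuracy_metric(completions, answer, **kwargs):
    # The per-sample value is really just an accuracy indicator.
    extracted = [c.strip() for c in completions]
    return [1.0 if e == a else 0.0 for e, a in zip(extracted, answer)]

# Wiring sketch (illustrative values):
#   args = GRPOConfig(..., reward_weights=[1.0, 0.0])
#   GRPOTrainer(reward_funcs=[real_reward_fn, accuracy_metric], args=args, ...)
```

The limitation the comment points out: if the metric is a byproduct of the real reward computation, this forces you to recompute it in a second function.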

@qgallouedec
Copy link
Member

@codex review

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@chatgpt-codex-connector (bot) left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 71270089a6

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

@cursor (bot) left a comment

Cursor Bugbot has reviewed your changes and found 2 potential issues.

Bugbot Autofix is OFF. To automatically fix reported issues with cloud agents, enable autofix in the Cursor dashboard.

```python
values = self._pending_metrics[name]
local_mean = sum(values) / len(values)
global_mean = self.accelerator.gather(torch.tensor(local_mean, device=device)).mean().item()
self._metrics[mode][name].append(global_mean)
```

User metrics can silently corrupt built-in metrics

Medium Severity

log_metric writes into self._metrics[mode][name], which is the same namespace as built-in metrics like "reward", "reward_std", "kl", "entropy", "frac_reward_zero_std", etc. If a user calls log_metric("reward", 0.75), that value is appended to the same list where the built-in reward mean is stored. At logging time, all values in the list are averaged together via sum(val) / len(val), silently corrupting the built-in metric with no warning or validation.

Additional Locations (1)

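A minimal guard could reject reserved names before buffering. This is a sketch: the reserved set below is an assumption drawn from the metric names listed in the finding, not trl's authoritative list.

```python
# Assumed reserved set, taken from the built-in metrics named above;
# trl's actual list may differ.
RESERVED_METRICS = {"reward", "reward_std", "kl", "entropy", "frac_reward_zero_std"}


def validate_metric_name(name: str) -> str:
    """Raise if a user-supplied metric name would shadow a built-in metric."""
    if name in RESERVED_METRICS:
        raise ValueError(
            f"log_metric name {name!r} collides with a built-in metric; "
            "pick a distinct name (e.g. with a 'custom/' prefix)."
        )
    return name
```

Calling this at the top of `log_metric` would turn the silent corruption into an immediate, actionable error.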

```python
"prompt": self._logs["prompt"],
"completion": self._logs["completion"],
**self._logs["rewards"],
**self._logs["extra"],
```

Extra columns can overwrite core table columns

Medium Severity

**self._logs["extra"] is spread into the completions table dict after "step", "prompt", "completion", and **self._logs["rewards"]. A user calling log_extra("prompt", values) or log_extra("step", values) silently overwrites those core columns. Conversely, log_extra("advantage", values) is silently discarded because the "advantage" key is defined afterward. No validation in _log_completion_extra guards against reserved column names.

Additional Locations (1)

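Similarly, `log_extra` could validate against the table's core columns. A sketch, with the column set assumed from the keys quoted in the finding rather than taken from trl itself:

```python
# Assumed core completions-table columns, per the dict keys quoted above.
RESERVED_COLUMNS = {"step", "prompt", "completion", "advantage"}


def validate_extra_column(column: str) -> str:
    """Refuse extra columns that would overwrite (or be silently dropped by)
    core completions-table columns."""
    if column in RESERVED_COLUMNS:
        raise ValueError(
            f"log_extra column {column!r} is reserved by the completions table"
        )
    return column
```

Reward-function column names (the `**self._logs["rewards"]` keys) would need the same treatment, since they are also spread into the table before the extras.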
