-
Notifications
You must be signed in to change notification settings - Fork 2.5k
Allow reward functions to log extra columns and scalar metrics #5233
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 10 commits
f5d3975
c8742d5
2e03c46
149447f
382b2b9
44c4c1a
9e262ac
d74a6bd
2d150ed
7127008
188045d
3d5676c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -700,7 +700,11 @@ def cast_outputs_to_original_dtype(module, args, output): | |
| "completion": deque(maxlen=args.generation_batch_size), | ||
| "rewards": defaultdict(lambda: deque(maxlen=args.generation_batch_size)), | ||
| "advantages": deque(maxlen=args.generation_batch_size), | ||
| "extra": defaultdict(lambda: deque(maxlen=args.generation_batch_size)), | ||
| } | ||
| # Buffers for user-logged data from reward functions, flushed after gathering | ||
| self._pending_extra_logs = defaultdict(list) | ||
| self._pending_metrics = defaultdict(list) | ||
|
|
||
| # Ensure each process receives a unique seed to prevent duplicate completions when generating with | ||
| # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but | ||
|
|
@@ -1130,6 +1134,31 @@ def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | Any]) -> di | |
| inputs = self._generate_and_score_completions(generation_batch) | ||
| return inputs | ||
|
|
||
| def _log_completion_extra(self, column: str, values: list): | ||
| """ | ||
| Log extra columns to the completions table. Called from reward functions via the `log_extra` kwarg. | ||
|
|
||
| Args: | ||
| column (`str`): | ||
| Name of the column to add. | ||
| values (`list`): | ||
| Values for the column, one per sample in the batch. | ||
| """ | ||
| self._pending_extra_logs[column].extend(values) | ||
|
|
||
| def _log_metric(self, name: str, value: float): | ||
| """ | ||
| Log a scalar metric from a reward function. Called via the `log_metric` kwarg. Values are averaged | ||
| over each logging step and reported alongside built-in metrics like `kl` and `entropy`. | ||
|
|
||
| Args: | ||
| name (`str`): | ||
| Name of the metric. | ||
| value (`float`): | ||
| Scalar value for this batch. | ||
| """ | ||
| self._pending_metrics[name].append(value) | ||
|
|
||
| @profiling_decorator | ||
| def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): | ||
| device = self.accelerator.device | ||
|
|
@@ -1142,6 +1171,12 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): | |
| # This allows for dynamic reward shaping based on training progress. | ||
| reward_kwargs["trainer_state"] = self.state | ||
|
|
||
| # Allow reward functions to log extra columns to the completions table. | ||
| reward_kwargs["log_extra"] = self._log_completion_extra | ||
|
|
||
| # Allow reward functions to log additional scalar metrics. | ||
| reward_kwargs["log_metric"] = self._log_metric | ||
|
|
||
| async_funcs_info = [] # async custom functions for asyncio.gather | ||
|
|
||
| for i, (reward_func, reward_processing_class, reward_func_name) in enumerate( | ||
|
|
@@ -1200,7 +1235,9 @@ async def _run_async_funcs(): | |
| if torch.isnan(rewards_per_func).all(dim=1).any(): | ||
| nan_row_idx = torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0] | ||
| row_reward_kwargs = { | ||
| key: value[nan_row_idx] for key, value in reward_kwargs.items() if key != "trainer_state" | ||
| key: value[nan_row_idx] | ||
| for key, value in reward_kwargs.items() | ||
| if key not in ("trainer_state", "log_extra", "log_metric") | ||
| } | ||
| row_reward_kwargs["prompt"] = prompts[nan_row_idx] | ||
| row_reward_kwargs["completion"] = completions[nan_row_idx] | ||
|
|
@@ -1939,6 +1976,18 @@ def _generate_and_score_completions( | |
| self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist()) | ||
| self._logs["advantages"].extend(all_process_advantages.tolist()) | ||
|
|
||
| # Flush user-logged extra columns (from log_extra), gathering across processes | ||
| for column, values in self._pending_extra_logs.items(): | ||
| self._logs["extra"][column].extend(gather_object(values)) | ||
qgallouedec marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| self._pending_extra_logs.clear() | ||
|
|
||
| # Flush user-logged metrics (from log_metric), averaging across processes | ||
| for name, values in self._pending_metrics.items(): | ||
| local_mean = sum(values) / len(values) | ||
| global_mean = self.accelerator.gather(torch.tensor(local_mean, device=device)).mean().item() | ||
| self._metrics[mode][name].append(global_mean) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. User metrics can silently corrupt built-in metrics — Medium Severity
Additional Locations (1) |
||
| self._pending_metrics.clear() | ||
|
|
||
| if images is not None: | ||
| self._logs["images"].extend(gather_object(images)) | ||
|
|
||
|
|
@@ -2319,6 +2368,7 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: | |
| "prompt": self._logs["prompt"], | ||
| "completion": self._logs["completion"], | ||
| **self._logs["rewards"], | ||
| **self._logs["extra"], | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Extra columns can overwrite core table columns — Medium Severity
Additional Locations (1) |
||
| "advantage": self._logs["advantages"], | ||
| } | ||
|
|
||
|
|
||


Uh oh!
There was an error while loading. Please reload this page.