Logging how vllm importance ratios are truncated/masked in GRPOTrainer #5231

@muupan

Feature request

I think logging more values related to the vLLM importance ratios would be helpful, e.g. the max/min/mean of the raw importance ratios (before truncation/masking) and the fraction of importance ratios that get truncated/masked.

Motivation

As of now, all logged values related to the vLLM importance sampling correction are computed from the already truncated/masked ratios, so it is hard to see what the raw values look like and how many importance ratios are truncated/masked.

Where the importance ratios are truncated/masked:

    if self.vllm_importance_sampling_mode in ["sequence_truncate", "token_truncate"]:
        vllm_importance_sampling_ratio = torch.clamp(
            vllm_importance_sampling_ratio, max=self.vllm_importance_sampling_cap
        )
    elif self.vllm_importance_sampling_mode in ["sequence_mask", "token_mask"]:
        vllm_importance_sampling_ratio = vllm_importance_sampling_ratio.masked_fill(
            vllm_importance_sampling_ratio > self.vllm_importance_sampling_cap, value=0.0
        )
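
To illustrate why the post-processed values hide what happened, here is a minimal sketch of the two operations on a toy tensor. The ratio values and the cap of 2.0 are made up for illustration; `cap` stands in for `self.vllm_importance_sampling_cap`.

```python
import torch

# Hypothetical raw importance ratios, chosen only for illustration.
ratio = torch.tensor([0.5, 1.0, 2.5, 4.0])
cap = 2.0  # stands in for self.vllm_importance_sampling_cap

# "*_truncate" modes: values above the cap are clipped down to the cap.
truncated = torch.clamp(ratio, max=cap)

# "*_mask" modes: values above the cap are replaced with 0.0.
masked = ratio.masked_fill(ratio > cap, value=0.0)

print(truncated.tolist())  # the raw max is no longer visible
print(masked.tolist())     # masked entries pull the min and mean toward 0
```

In both cases the max/min/mean computed afterwards no longer reflect the raw ratios, and neither result says how many entries exceeded the cap.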

Where the values are logged:

    if sequence_level_is:
        flat_is_ratio = vllm_importance_sampling_ratio.flatten()
    else:
        flat_is_ratio = vllm_importance_sampling_ratio[mask]
    min_importance_sampling_ratio = (
        torch.min(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device)
    )
    mean_importance_sampling_ratio = (
        torch.mean(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device)
    )
    max_importance_sampling_ratio = (
        torch.max(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device)
    )
    self._metrics[mode]["sampling/importance_sampling_ratio/min"].append(
        nanmin(self.accelerator.gather(min_importance_sampling_ratio)).item()
    )
    self._metrics[mode]["sampling/importance_sampling_ratio/mean"].append(
        self.accelerator.gather(mean_importance_sampling_ratio).nanmean().item()
    )
    self._metrics[mode]["sampling/importance_sampling_ratio/max"].append(
        nanmax(self.accelerator.gather(max_importance_sampling_ratio)).item()
    )
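
A rough sketch of the kind of extra statistics I have in mind, computed from the ratios before truncation/masking. The helper name `raw_is_ratio_stats`, its arguments, and the dictionary keys are hypothetical and not part of the trainer; gathering across processes is omitted for brevity.

```python
import torch


def raw_is_ratio_stats(raw_ratio, cap, mask=None):
    """Summarize importance ratios before truncation/masking (hypothetical helper)."""
    # Token-level: keep only entries selected by the boolean mask.
    # Sequence-level (mask=None): use every entry.
    flat = raw_ratio[mask] if mask is not None else raw_ratio.flatten()
    if flat.numel() == 0:
        # Same fallback as the existing logging code: report zeros when empty.
        return {"min": 0.0, "mean": 0.0, "max": 0.0, "frac_above_cap": 0.0}
    return {
        "min": flat.min().item(),
        "mean": flat.mean().item(),
        "max": flat.max().item(),
        # Fraction of ratios that would be truncated or masked at this cap.
        "frac_above_cap": (flat > cap).float().mean().item(),
    }


# Toy example with made-up values; the last token of the second row is padding.
raw = torch.tensor([[0.5, 1.0, 4.0], [2.5, 1.5, 9.0]])
mask = torch.tensor([[True, True, True], [True, True, False]])
stats = raw_is_ratio_stats(raw, cap=2.0, mask=mask)
print(stats)
```

The same comparison against the cap applies to both the truncate and the mask modes, so a single fraction metric could cover both.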

Your contribution

I can send a PR to add metrics using importance ratios before truncation/masking.
I made a PR: #5243
