Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 30 additions & 4 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1798,19 +1798,19 @@ def _generate_and_score_completions(
else:
logps_diff = per_token_logps_diff

vllm_importance_sampling_ratio = torch.exp(logps_diff)
raw_vllm_importance_sampling_ratio = torch.exp(logps_diff)

# vllm_importance_sampling_ratio.shape:
# token_* modes: (B, T) (per-token ratio)
# sequence_* modes: (B, 1) (per-sequence ratio)

if self.vllm_importance_sampling_mode in ["sequence_truncate", "token_truncate"]:
vllm_importance_sampling_ratio = torch.clamp(
vllm_importance_sampling_ratio, max=self.vllm_importance_sampling_cap
raw_vllm_importance_sampling_ratio, max=self.vllm_importance_sampling_cap
)
elif self.vllm_importance_sampling_mode in ["sequence_mask", "token_mask"]:
vllm_importance_sampling_ratio = vllm_importance_sampling_ratio.masked_fill(
vllm_importance_sampling_ratio > self.vllm_importance_sampling_cap, value=0.0
vllm_importance_sampling_ratio = raw_vllm_importance_sampling_ratio.masked_fill(
raw_vllm_importance_sampling_ratio > self.vllm_importance_sampling_cap, value=0.0
)
else:
raise ValueError(
Expand Down Expand Up @@ -1954,9 +1954,12 @@ def _generate_and_score_completions(
)
if sequence_level_is:
flat_is_ratio = vllm_importance_sampling_ratio.flatten()
raw_flat_is_ratio = raw_vllm_importance_sampling_ratio.flatten()
else:
flat_is_ratio = vllm_importance_sampling_ratio[mask]
raw_flat_is_ratio = raw_vllm_importance_sampling_ratio[mask]

# Stats related to importance sampling ratio after masking/truncation
min_importance_sampling_ratio = (
torch.min(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device)
)
Expand All @@ -1976,6 +1979,29 @@ def _generate_and_score_completions(
nanmax(self.accelerator.gather(max_importance_sampling_ratio)).item()
)

# Stats related to importance sampling ratio before masking/truncation
min_raw_importance_sampling_ratio = (
torch.min(raw_flat_is_ratio) if raw_flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device)
)
mean_raw_importance_sampling_ratio = (
torch.mean(raw_flat_is_ratio) if raw_flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device)
)
max_raw_importance_sampling_ratio = (
torch.max(raw_flat_is_ratio) if raw_flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device)
)
self._metrics[mode]["sampling/raw_importance_sampling_ratio/min"].append(
nanmin(self.accelerator.gather(min_raw_importance_sampling_ratio)).item()
)
self._metrics[mode]["sampling/raw_importance_sampling_ratio/mean"].append(
self.accelerator.gather(mean_raw_importance_sampling_ratio).nanmean().item()
)
self._metrics[mode]["sampling/raw_importance_sampling_ratio/max"].append(
nanmax(self.accelerator.gather(max_raw_importance_sampling_ratio)).item()
)
# Fraction of importance-sampling ratios that were changed by masking/truncation.
# The gathered tensor is reduced with nanmean before `.item()`, like the sibling
# mean metric: `.item()` only works on a single-element tensor, and the gather
# yields one entry per process. An empty local selection contributes 0.0 (same
# fallback as the min/mean/max stats) instead of the NaN that mean() of an empty
# tensor would produce.
frac_modified_importance_sampling_ratio = (
    torch.ne(flat_is_ratio, raw_flat_is_ratio).float().mean()
    if flat_is_ratio.numel() > 0
    else torch.tensor(0.0, device=device)
)
self._metrics[mode]["sampling/frac_modified_importance_sampling_ratio"].append(
    self.accelerator.gather(frac_modified_importance_sampling_ratio).nanmean().item()
)

output = {
"prompt_ids": prompt_ids,
"prompt_mask": prompt_mask,
Expand Down