Description
Feature request
I think logging more values related to the vLLM importance sampling ratios would be helpful, for example the max/min/mean of the raw importance ratios (before truncation/masking) and the fraction of importance ratios that get truncated/masked.
Motivation
Currently, all logged values related to the vLLM importance sampling correction are computed from the already truncated/masked ratios, so it is hard to see what the raw values look like or how many importance ratios are truncated/masked.
Where the importance ratios are truncated/masked:
trl/trl/trainer/grpo_trainer.py
Lines 1807 to 1814 in 1850da5

    if self.vllm_importance_sampling_mode in ["sequence_truncate", "token_truncate"]:
        vllm_importance_sampling_ratio = torch.clamp(
            vllm_importance_sampling_ratio, max=self.vllm_importance_sampling_cap
        )
    elif self.vllm_importance_sampling_mode in ["sequence_mask", "token_mask"]:
        vllm_importance_sampling_ratio = vllm_importance_sampling_ratio.masked_fill(
            vllm_importance_sampling_ratio > self.vllm_importance_sampling_cap, value=0.0
        )

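To illustrate why the post-processed values hide the raw distribution, here is a small standalone sketch that applies the same two operations (`torch.clamp` and `masked_fill`) to a toy tensor. The tensor values and the `cap` variable are made up for illustration; `cap` stands in for `self.vllm_importance_sampling_cap`.

```python
import torch

# Toy values, assumed for illustration only.
cap = 2.0  # stands in for self.vllm_importance_sampling_cap
raw = torch.tensor([0.5, 1.0, 3.0, 8.0])

# "*_truncate" modes: ratios above the cap are clamped to the cap.
truncated = torch.clamp(raw, max=cap)

# "*_mask" modes: ratios above the cap are replaced with 0.0.
masked = raw.masked_fill(raw > cap, value=0.0)

print(truncated.tolist())  # [0.5, 1.0, 2.0, 2.0]
print(masked.tolist())     # [0.5, 1.0, 0.0, 0.0]
```

In both cases the raw maximum (8.0 here) is no longer recoverable from the result: after truncation the maximum can never exceed the cap, and after masking the out-of-range ratios show up as zeros.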
Where the values are logged:
trl/trl/trainer/grpo_trainer.py
Lines 1955 to 1977 in 1850da5

    if sequence_level_is:
        flat_is_ratio = vllm_importance_sampling_ratio.flatten()
    else:
        flat_is_ratio = vllm_importance_sampling_ratio[mask]
    min_importance_sampling_ratio = (
        torch.min(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device)
    )
    mean_importance_sampling_ratio = (
        torch.mean(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device)
    )
    max_importance_sampling_ratio = (
        torch.max(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device)
    )
    self._metrics[mode]["sampling/importance_sampling_ratio/min"].append(
        nanmin(self.accelerator.gather(min_importance_sampling_ratio)).item()
    )
    self._metrics[mode]["sampling/importance_sampling_ratio/mean"].append(
        self.accelerator.gather(mean_importance_sampling_ratio).nanmean().item()
    )
    self._metrics[mode]["sampling/importance_sampling_ratio/max"].append(
        nanmax(self.accelerator.gather(max_importance_sampling_ratio)).item()
    )

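A rough sketch of the kind of metrics I have in mind, computed on the ratios before truncation/masking. The helper name `raw_is_ratio_stats`, its arguments, and the `clipped_frac` key are hypothetical and not part of TRL; the sketch also assumes `mask` is a boolean tensor and leaves out the cross-process gathering done in the trainer.

```python
import torch


def raw_is_ratio_stats(raw_ratio, cap, mask=None):
    """Hypothetical helper: stats of the importance ratios BEFORE truncation/masking.

    raw_ratio: ratios before clamp/masked_fill is applied.
    cap:       stands in for self.vllm_importance_sampling_cap.
    mask:      optional boolean tensor selecting valid tokens (token-level case);
               when None, all entries are used (sequence-level case).
    """
    flat = raw_ratio.flatten() if mask is None else raw_ratio[mask]
    if flat.numel() == 0:
        # Mirror the existing fallback of reporting 0.0 for empty inputs.
        zero = torch.tensor(0.0, device=raw_ratio.device)
        return {"min": zero, "mean": zero, "max": zero, "clipped_frac": zero}
    return {
        "min": flat.min(),
        "mean": flat.mean(),
        "max": flat.max(),
        # Fraction of ratios above the cap, i.e. the ones that would be
        # truncated or masked by the current code.
        "clipped_frac": (flat > cap).float().mean(),
    }


# Toy token-level example; the last token of the second sequence is padding.
raw = torch.tensor([[0.5, 1.0, 3.0], [8.0, 1.5, 0.0]])
mask = torch.tensor([[True, True, True], [True, True, False]])
stats = raw_is_ratio_stats(raw, cap=2.0, mask=mask)
print({k: round(v.item(), 4) for k, v in stats.items()})
```

In this example two of the five valid ratios exceed the cap, so `clipped_frac` comes out at 0.4 while the raw maximum stays visible at 8.0. In the trainer these values would have to be computed before the truncate/mask step and gathered across processes in the same way as the existing metrics.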
Your contribution
I can send a PR to add metrics using importance ratios before truncation/masking.
I made a PR: #5243