
feat: log raw importance ratios and fraction of truncation/masking in vLLM importance sampling correction #5243

Open
muupan wants to merge 4 commits into huggingface:main from muupan:feature/log-raw-importance-sampling-ratio

Conversation


muupan (Contributor) commented Mar 8, 2026

What does this PR do?

Resolves #5231

This PR adds new logged metrics:

  • sampling/raw_importance_sampling_ratio/{min,max,mean}
    • These are the same as the existing sampling/importance_sampling_ratio/{min,max,mean}, except that they are computed from the ratio values before truncation or masking is applied.
  • sampling/frac_modified_importance_sampling_ratio
    • This is the fraction of importance sampling ratio values that are modified by truncation or masking.
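To make the relationship between the metrics concrete, here is a hedged sketch (made-up log-probabilities and a hypothetical token-level truncation; this is an illustration of the idea, not the actual TRL implementation):

```python
import torch

cap = 1.2  # plays the role of vllm_importance_sampling_cap

# Assumed inputs: per-token log-probs under the training policy and under vLLM
logps_train = torch.tensor([-1.0, -2.0, -0.5, -3.0])
logps_vllm = torch.tensor([-1.5, -2.0, -1.0, -2.0])

# Raw importance sampling ratios, before any truncation or masking
raw_ratio = torch.exp(logps_train - logps_vllm)

# Token-level truncation: clamp each ratio at the cap
capped_ratio = torch.clamp(raw_ratio, max=cap)

# Stats for sampling/raw_importance_sampling_ratio/{min,mean,max}
raw_min, raw_mean, raw_max = raw_ratio.min(), raw_ratio.mean(), raw_ratio.max()

# sampling/frac_modified_importance_sampling_ratio: fraction of ratios
# that the correction actually changed
frac_modified = (capped_ratio != raw_ratio).float().mean()
```

With these numbers, `raw_ratio.max()` exceeds the cap while `capped_ratio.max()` equals it, and half of the ratios are modified, which is exactly the kind of gap the new metrics surface.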

I ran the following code to verify the output:

import argparse

from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer
from trl.rewards import accuracy_reward


parser = argparse.ArgumentParser()
parser.add_argument("--vllm_importance_sampling_mode", type=str, default="sequence_mask")
cli_args = parser.parse_args()  # named cli_args to avoid shadowing by the GRPOConfig below

dataset = load_dataset("trl-lib/DeepMath-103K", split="train[:5]")

args = GRPOConfig(
    output_dir=f"outputs/{cli_args.vllm_importance_sampling_mode}",
    vllm_importance_sampling_correction=True,
    vllm_importance_sampling_mode=cli_args.vllm_importance_sampling_mode,
    vllm_importance_sampling_cap=1.2,
    use_vllm=True,
    vllm_mode="colocate",
    max_steps=5,
    num_train_epochs=1,
    logging_steps=1,
    save_strategy="no",
    report_to="tensorboard",
)

trainer = GRPOTrainer(
    args=args,
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=accuracy_reward,
    train_dataset=dataset,
)
trainer.train()
I then launched one run per correction mode:

uv run --no-sync accelerate launch --num_processes 1 train_grpo_example.py --vllm_importance_sampling_mode token_truncate
uv run --no-sync accelerate launch --num_processes 1 train_grpo_example.py --vllm_importance_sampling_mode token_mask
uv run --no-sync accelerate launch --num_processes 1 train_grpo_example.py --vllm_importance_sampling_mode sequence_truncate
uv run --no-sync accelerate launch --num_processes 1 train_grpo_example.py --vllm_importance_sampling_mode sequence_mask

From the tensorboard records, you can see:

  • The new metrics are all recorded.
  • sampling/raw_importance_sampling_ratio/max exceeds the cap value of 1.2, while sampling/importance_sampling_ratio/max is bounded above by it. This is expected, as only the latter is affected by truncation or masking. Masking leads to lower values than truncation, which is also expected.
  • sampling/frac_modified_importance_sampling_ratio is higher with sequence-level IS than with token-level IS.
[Two screenshots of the TensorBoard records, Mar 9, 2026, showing the logged metrics]
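The observation that sequence-level IS modifies a larger fraction of ratios can be illustrated with a small sketch (made-up numbers and assumed cap semantics, not TRL's actual code): capping a sequence's single ratio changes every token in that sequence, while token-level capping only changes the individual offending tokens.

```python
import torch

cap = 1.2
# Two sequences of 3 tokens each; per-token raw ratios (made-up values)
raw = torch.tensor([[1.1, 1.3, 0.9],
                    [1.0, 1.05, 0.95]])

# Token-level truncation: only individual tokens above the cap are modified
token_capped = torch.clamp(raw, max=cap)
token_frac = (token_capped != raw).float().mean()  # 1 of 6 tokens

# Sequence-level truncation: one ratio per sequence (here, the product over
# tokens); if it exceeds the cap, the whole sequence is affected
seq_ratio = raw.prod(dim=1, keepdim=True)
seq_capped = seq_ratio.clamp(max=cap)

# Counted per token, every token of the capped sequence counts as modified
seq_token_frac = (seq_capped != seq_ratio).expand_as(raw).float().mean()  # 3 of 6 tokens
```

Here only one of six tokens is modified at token level, but three of six at sequence level, matching the trend seen in sampling/frac_modified_importance_sampling_ratio.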

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.


Note

Low risk: changes are limited to additional metric logging around vLLM importance sampling ratios and do not affect the actual loss computation beyond variable renaming.

Overview
Adds new GRPOTrainer metrics to better diagnose vLLM importance-sampling correction by logging pre-cap (“raw”) importance sampling ratio stats (sampling/raw_importance_sampling_ratio/{min,mean,max}) alongside the existing post-cap stats.

Also logs sampling/frac_modified_importance_sampling_ratio, the fraction of ratios that were changed by truncation/masking, and refactors the computation to keep both raw_vllm_importance_sampling_ratio and the capped/masked vllm_importance_sampling_ratio for logging.

Written by Cursor Bugbot for commit 7f8fee3.

cursor bot left a comment

Cursor Bugbot has reviewed your changes and found 1 potential issue.




Development

Successfully merging this pull request may close these issues.

Logging how vllm importance ratios are truncated/masked in GRPOTrainer
