Skip to content

Conversation

@erictang000
Copy link
Collaborator

@erictang000 erictang000 commented Jan 7, 2026

Overview

  • Marks trainer.algorithm.use_tis and trainer.algorithm.tis_imp_ratio_cap for deprecation
  • Introduces new trainer.algorithm.off_policy_correction config (see new config below)
  • Updates loss functions to return a LossMetrics TypedDict containing loss metrics (previously returned just loss, clip_ratio)
  • Updates workers to all reduce mean/max/min appropriately, and to propagate loss metrics back up to the trainer.

Off Policy Correction Config

# To be deprecated in favor of off_policy_correction.tis_ratio_type = "token"
# and "token_tis_ratio_clip_high"
tis_imp_ratio_cap: -1.0
use_tis: false

off_policy_correction:
      # type of importance sampling ratio to use for ppo loss correction
      # here importance sampling ratio refers to exp(logprobs_{policy_old} - logprobs_{rollout_policy})
      tis_ratio_type: null # null, "token", "sequence"

      # used if tis_ratio_type = "token", 1.5-5.0 is recommended for "token" tis_ratio_type
      token_tis_ratio_clip_high: 2.0
      # used if tis_ratio_type = "sequence", 2.0-10.0 is recommended for "sequence" tis_ratio_type
      sequence_tis_ratio_clip_high: 5.0

      # method of masking out sequences with cumulative importance sampling ratios outside the cap
      # "product" masks out sequences with product of importance ratios outside the cap
      # "geometric" masks out sequences with geometric mean of importance ratios outside the cap
      sequence_mask_metric: null # null, "product", "geometric"

      # used if sequence_mask_metric = "geometric"
      # values around 0.99-1.01 are recommended for "geometric" sequence_mask_metric - MoE models may need larger allowed ranges due to higher mismatch
      geo_mask_high: 1.01
      geo_mask_low: 0.99

      # used if sequence_mask_metric = "product"
      # values around 0.5-2.0 are recommended for "sequence" sequence_mask_metric
      product_mask_high: 2.0
      product_mask_low: 0.5

      # separate from sequence_mask_metric and tis_ratio_type 
      # if any off_policy_correction is enabled, masks out sequences with any token having importance ratio
      # far outside an acceptable range (low and high thresholds)
      outlier_token_is_threshold_low: 1e-4
      outlier_token_is_threshold_high: 100

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the Truncated Importance Sampling (TIS) configuration into a more comprehensive rollout_correction system, which is a great improvement for structure and extensibility. The new implementation adds flexible rollout correction mechanisms, including different TIS ratio types and rejection masks. The changes are well-documented and handle the deprecation of old parameters gracefully. I've identified a bug in a conditional check that could cause a crash, and an opportunity to refactor for better efficiency and code clarity. My detailed feedback is in the comments below.

@erictang000 erictang000 changed the title [skyrl-train] Refactor TIS to use more comprehensive rollout correction config [skyrl-train] Refactor TIS to use more comprehensive off policy correction config Jan 8, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant