Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
f033e65
x
erictang000 Jan 7, 2026
0b236fe
Merge branch 'main' of https://github.com/erictang000/SkyRL into roll…
erictang000 Jan 7, 2026
3f3b759
x
erictang000 Jan 7, 2026
29efd6f
x
erictang000 Jan 7, 2026
1520157
x
erictang000 Jan 7, 2026
45a59c2
x
erictang000 Jan 7, 2026
ce01bb2
fix tests and add rollout correction to other loss types
erictang000 Jan 8, 2026
abac800
add metrics
erictang000 Jan 8, 2026
2dc7364
propagate metrics up and refactor how we do metric reductions for max…
erictang000 Jan 8, 2026
349369d
make default null and propagate megatron metrics
erictang000 Jan 8, 2026
f3f7054
x:
erictang000 Jan 8, 2026
c45c130
Merge branch 'rollout_correction' of https://github.com/erictang000/S…
erictang000 Jan 8, 2026
63d38c5
big cleanup - remove clip_ratio return (fix custom algorithms stuff),…
erictang000 Jan 8, 2026
7e83c10
x
erictang000 Jan 8, 2026
cf042fc
renaming
erictang000 Jan 8, 2026
cef7121
x
erictang000 Jan 8, 2026
9485bdd
x
erictang000 Jan 8, 2026
9e11eda
Merge branch 'main' of https://github.com/erictang000/SkyRL into roll…
erictang000 Jan 12, 2026
c06747c
x
erictang000 Jan 12, 2026
0b5ebfd
x
erictang000 Jan 12, 2026
0697957
x
erictang000 Jan 13, 2026
6b9e1e4
add docs
erictang000 Jan 13, 2026
46b6fe5
x
erictang000 Jan 13, 2026
d72d9c6
x
erictang000 Jan 13, 2026
db76d01
gemini:
erictang000 Jan 13, 2026
ac0659c
x
erictang000 Jan 13, 2026
7ddb85f
Merge branch 'main' of https://github.com/erictang000/SkyRL into roll…
erictang000 Jan 21, 2026
6c8d084
Merge branch 'main' of https://github.com/erictang000/SkyRL into roll…
erictang000 Jan 21, 2026
08c3625
x
erictang000 Jan 21, 2026
2bca41f
x
erictang000 Jan 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion skyrl-train/skyrl_train/config/ppo_base_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,40 @@ trainer:
eps_clip_high: 0.2
# dual clip parameters
clip_ratio_c: 3.0
# Truncated Importance Sampling as proposed in https://fengyao.notion.site/off-policy-rl

# mark for deprecation
tis_imp_ratio_cap: -1.0
use_tis: false

# reference: https://github.com/szrlee/verl/blob/yingru/rollout_correction/docs/advance/rollout_corr_math.md
rollout_correction:
# type of importance ratio to use for ppo loss correction
# here importance ratio refers to logprobs_{rollout} - logprobs_{policy_old}
tis_ratio_type: "null" # "null", "token", "sequence"

# cap for the importance ratio
# 1.5-5.0 is recommended for "token" tis_ratio_type
token_tis_ratio_cap_high: 2.0
# 2.0-10.0 is recommended for "sequence" tis_ratio_type
sequence_tis_ratio_cap_high: 5.0

# "sequence" mask masks out sequences with product of importance ratios outside the cap
# "geometric" mask masks out sequences with geometric mean of importance ratios outside the cap
rejection_mask_type: "null" # "null", "sequence", "geometric"

# cap for the rejection mask ratio
# values around 0.99-1.01 are recommended for "geometric" rejection_mask_type - MoE models may need larger allowed ranges due to higher mismatch
geo_rejection_mask_ratio_cap_high: 1.001
geo_rejection_mask_ratio_cap_low: 0.999

# sequence level rejection mask ratio cap
sequence_rejection_mask_ratio_cap_high: 2.0
sequence_rejection_mask_ratio_cap_low: 0.5

# masks out sequences with any token having importance ratio below the cap
sequence_mask_low: 1e-4
sequence_mask_high: 100

# SAPO parameters (only used when policy_loss_type: "sapo") (https://arxiv.org/pdf/2511.20347)
sapo:
tau_pos: 1.0
Expand Down
13 changes: 12 additions & 1 deletion skyrl-train/skyrl_train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,12 +579,23 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis
loss_masks,
logprobs,
)
# sanity check for tis
# sanity check for tis (legacy)
if self.cfg.trainer.algorithm.use_tis:
assert (
rollout_logprobs_tensor is not None
), "expected non-null rollout logprobs tensor with `trainer.algorithm.use_tis` as `True`"
assert rollout_logprobs_tensor.shape == loss_masks_tensor.shape, "Logprobs should look like responses"

# sanity check for rollout_correction
rollout_corr = self.cfg.trainer.algorithm.rollout_correction
tis_ratio_type = rollout_corr.tis_ratio_type
rejection_mask_type = rollout_corr.rejection_mask_type
if tis_ratio_type is not None or rejection_mask_type is not None:
assert (
rollout_logprobs_tensor is not None
), "expected non-null rollout logprobs tensor when rollout_correction is enabled"
assert rollout_logprobs_tensor.shape == loss_masks_tensor.shape, "Logprobs should look like responses"

training_input = TrainingInputBatch(
{
"sequences": sequences_tensor, # Full trajectories (padded and concatenated prompts and responses)
Expand Down
152 changes: 152 additions & 0 deletions skyrl-train/skyrl_train/utils/ppo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,152 @@ def _safe_exp_delta(delta: torch.Tensor, clip: float = 20.0, out_dtype=None) ->
return y.to(out_dtype or delta.dtype)


def compute_tis_ratio(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are great! Can we put them into a separate file? Our ppo_utils.py is 1.4k LOCs now.

In long term we could break ppo_utils.py down, but for now let's create a file of off_policy_correction_utils.py (or some other name you see fit) with all these methods you added. We can keep the rest in where they currently are and come back later if we'd want to further clean up.

old_log_probs: torch.Tensor,
rollout_logprobs: torch.Tensor,
loss_mask: torch.Tensor,
tis_ratio_type: str,
rollout_corr: DictConfig,
) -> torch.Tensor:
"""
Compute truncated importance sampling (TIS) ratio for rollout correction.

Args:
old_log_probs: Log probabilities from the old policy (before update).
rollout_logprobs: Log probabilities from the rollout policy.
loss_mask: Mask indicating valid tokens.
tis_ratio_type: Type of TIS ratio ("token" or "sequence").
rollout_corr: Rollout correction config containing cap values.

Returns:
TIS ratio tensor to multiply with the loss.

Reference: https://github.com/szrlee/verl/blob/yingru/rollout_correction/docs/advance/rollout_corr_math.md
"""
# Compute token-level importance ratio: pi_old / pi_rollout
# In log space: old_log_probs - rollout_logprobs
token_tis_log_ratio = old_log_probs - rollout_logprobs
token_tis_ratio = _safe_exp_delta(token_tis_log_ratio, clip=20.0, out_dtype=old_log_probs.dtype)

if tis_ratio_type == "token":
token_tis_ratio_cap = rollout_corr.token_tis_ratio_cap_high
return torch.clamp(token_tis_ratio, max=token_tis_ratio_cap)
elif tis_ratio_type == "sequence":
# Compute sequence-level importance ratio as product of token ratios (sum of log ratios)
seq_tis_log_ratio = (token_tis_log_ratio * loss_mask).sum(dim=-1, keepdim=True)
seq_tis_ratio = _safe_exp_delta(seq_tis_log_ratio, clip=20.0, out_dtype=old_log_probs.dtype)
seq_tis_ratio_cap = rollout_corr.sequence_tis_ratio_cap_high
return torch.clamp(seq_tis_ratio, max=seq_tis_ratio_cap)
else:
raise ValueError(f"Unknown tis_ratio_type: {tis_ratio_type}")


def compute_rejection_mask(
old_log_probs: torch.Tensor,
rollout_logprobs: torch.Tensor,
loss_mask: torch.Tensor,
rejection_mask_type: str,
rollout_corr: DictConfig,
) -> torch.Tensor:
"""
Compute rejection mask for rollout correction.

This masks out sequences with importance ratios that fall outside acceptable bounds,
helping to filter out off-policy samples that may destabilize training.

Args:
old_log_probs: Log probabilities from the old policy (before update).
rollout_logprobs: Log probabilities from the rollout policy.
loss_mask: Mask indicating valid tokens.
rejection_mask_type: Type of rejection mask ("geometric" or "sequence").
rollout_corr: Rollout correction config containing cap values.

Returns:
Rejection mask tensor (float) to multiply with the loss.

Reference: https://github.com/szrlee/verl/blob/yingru/rollout_correction/docs/advance/rollout_corr_math.md
"""
# Compute token-level importance ratio
token_tis_log_ratio = old_log_probs - rollout_logprobs
token_tis_ratio = _safe_exp_delta(token_tis_log_ratio, clip=20.0, out_dtype=old_log_probs.dtype)

if rejection_mask_type == "geometric":
# Compute geometric mean of importance ratios per sequence
num_tokens = loss_mask.sum(dim=-1, keepdim=True).clamp(min=1.0)
seq_tis_log_ratio = (token_tis_log_ratio * loss_mask).sum(dim=-1, keepdim=True)
geo_mean_ratio = _safe_exp_delta(seq_tis_log_ratio / num_tokens, clip=20.0, out_dtype=old_log_probs.dtype)
geo_cap_high = rollout_corr.geo_rejection_mask_ratio_cap_high
geo_cap_low = rollout_corr.geo_rejection_mask_ratio_cap_low
geo_rejection_mask = (geo_mean_ratio >= geo_cap_low) & (geo_mean_ratio <= geo_cap_high)
return geo_rejection_mask.float()
elif rejection_mask_type == "sequence":
# Mask out sequences with product of importance ratios outside the cap
seq_tis_log_ratio = (token_tis_log_ratio * loss_mask).sum(dim=-1, keepdim=True)
seq_tis_ratio = _safe_exp_delta(seq_tis_log_ratio, clip=20.0, out_dtype=old_log_probs.dtype)
seq_cap_high = rollout_corr.sequence_rejection_mask_ratio_cap_high
seq_cap_low = rollout_corr.sequence_rejection_mask_ratio_cap_low
# Also check per-token bounds
token_mask_low = rollout_corr.sequence_mask_low
token_mask_high = rollout_corr.sequence_mask_high
token_in_bounds = (token_tis_ratio >= token_mask_low) & (token_tis_ratio <= token_mask_high)
# A sequence is valid if all tokens are in bounds (considering only masked positions)
all_tokens_valid = (token_in_bounds | (loss_mask == 0)).all(dim=-1, keepdim=True)
seq_in_bounds = (seq_tis_ratio >= seq_cap_low) & (seq_tis_ratio <= seq_cap_high)
seq_rejection_mask = seq_in_bounds & all_tokens_valid
return seq_rejection_mask.float()
else:
raise ValueError(f"Unknown rejection_mask_type: {rejection_mask_type}")


def apply_rollout_correction(
loss: torch.Tensor,
old_log_probs: torch.Tensor,
rollout_logprobs: torch.Tensor,
loss_mask: torch.Tensor,
rollout_corr: DictConfig,
) -> torch.Tensor:
"""
Apply rollout correction to the loss using TIS ratio and rejection mask.

This is a convenience function that combines compute_tis_ratio and compute_rejection_mask.

Args:
loss: The loss tensor to correct.
old_log_probs: Log probabilities from the old policy (before update).
rollout_logprobs: Log probabilities from the rollout policy.
loss_mask: Mask indicating valid tokens.
rollout_corr: Rollout correction config.

Returns:
Corrected loss tensor.
"""
tis_ratio_type = rollout_corr.tis_ratio_type
rejection_mask_type = rollout_corr.rejection_mask_type

# Check if TIS ratio correction is enabled
apply_tis = tis_ratio_type is not None and tis_ratio_type != "null"
# Check if rejection mask is enabled
apply_rejection = rejection_mask_type is not None and rejection_mask_type != "null"

# Early return if no correction needed
if not apply_tis and not apply_rejection:
return loss

# Apply TIS ratio if enabled
if apply_tis:
tis_ratio = compute_tis_ratio(old_log_probs, rollout_logprobs, loss_mask, tis_ratio_type, rollout_corr)
loss = loss * tis_ratio

# Apply rejection mask if enabled
if apply_rejection:
rejection_mask = compute_rejection_mask(
old_log_probs, rollout_logprobs, loss_mask, rejection_mask_type, rollout_corr
)
loss = loss * rejection_mask

return loss


@register_policy_loss(PolicyLossType.REGULAR)
@register_policy_loss(PolicyLossType.DUAL_CLIP)
def ppo_policy_loss(
Copy link
Collaborator

@CharlieFRuan CharlieFRuan Jan 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The return is typed as Tuple[torch.Tensor, float], which isn't correct right, due to it currently returning loss_metrics. Depending on what we do with LossMetrics as noted in the other comment on LossMetrics, we could make it dict[str, float]

Expand Down Expand Up @@ -581,6 +727,7 @@ def ppo_policy_loss(
clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
loss = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)

# Legacy TIS support (deprecated)
if config.use_tis:
from loguru import logger as logger_

Expand All @@ -590,6 +737,11 @@ def ppo_policy_loss(
tis_imp_ratio = torch.clamp(tis_imp_ratio, max=config.tis_imp_ratio_cap)
loss = loss * tis_imp_ratio

# New rollout correction support
rollout_corr = config.get("rollout_correction", None)
if rollout_corr is not None and rollout_logprobs is not None:
loss = apply_rollout_correction(loss, old_log_probs, rollout_logprobs, loss_mask, rollout_corr)

loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len)
return loss, clip_ratio

Expand Down
45 changes: 45 additions & 0 deletions skyrl-train/skyrl_train/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,11 @@ def validate_cfg(cfg: DictConfig):
algorithm_config.kl_estimator_type = "k3"
cfg.trainer.algorithm = algorithm_config

# Legacy TIS validation (deprecated)
if cfg.trainer.algorithm.use_tis:
logger.warning(
"`trainer.algorithm.use_tis` is deprecated. Please use `trainer.algorithm.rollout_correction` instead."
)
if cfg.trainer.algorithm.tis_imp_ratio_cap <= 0:
raise ValueError(
f"If `trainer.algorithm.use_tis` is `True` then `cfg.trainer.algorithm.tis_imp_ratio_cap` "
Expand All @@ -302,6 +306,47 @@ def validate_cfg(cfg: DictConfig):
"dual_clip",
], "TIS is only implemented for regular and dual_clip policy loss types"

# New rollout_correction validation
rollout_corr = cfg.trainer.algorithm.get("rollout_correction", None)
if rollout_corr is not None:
tis_ratio_type = rollout_corr.get("tis_ratio_type", "null")
rejection_mask_type = rollout_corr.get("rejection_mask_type", "null")

uses_rollout_correction = tis_ratio_type != "null" or rejection_mask_type != "null"

if uses_rollout_correction:
# Validate tis_ratio_type
assert tis_ratio_type in [
"null",
"token",
"sequence",
], f"`tis_ratio_type` must be 'null', 'token', or 'sequence', got {tis_ratio_type}"

# Validate rejection_mask_type
assert rejection_mask_type in [
"null",
"sequence",
"geometric",
], f"`rejection_mask_type` must be 'null', 'sequence', or 'geometric', got {rejection_mask_type}"

# Ensure logprobs are enabled for rollout correction
if cfg.generator.sampling_params.logprobs is None:
logger.warning(
"`generator.sampling_params.logprobs` is `None` but rollout_correction is enabled."
" Setting `logprobs` to `True`."
)
cfg.generator.sampling_params.logprobs = 0

if cfg.generator.backend == "sglang":
raise NotImplementedError(
"`trainer.algorithm.rollout_correction` doesn't support Sglang backend, please use vLLM"
)

assert cfg.trainer.algorithm.policy_loss_type in [
"regular",
"dual_clip",
], "rollout_correction is only implemented for regular and dual_clip policy loss types"

if cfg.trainer.policy.model.lora.rank > 0:
# LoRA enabled
# Right now: assert generator backend must be vllm, training backend must be fsdp/fsdp2
Expand Down