diff --git a/skyrl-train/docs/algorithms/custom_algorithms.rst b/skyrl-train/docs/algorithms/custom_algorithms.rst
index 09030ef79..174cf3964 100644
--- a/skyrl-train/docs/algorithms/custom_algorithms.rst
+++ b/skyrl-train/docs/algorithms/custom_algorithms.rst
@@ -48,14 +48,14 @@ Similarly, you can register custom policy loss functions:
.. code-block:: python
- from skyrl_train.utils.ppo_utils import register_policy_loss, PolicyLossRegistry
+ from skyrl_train.utils.ppo_utils import register_policy_loss, PolicyLossRegistry, LossMetrics
@register_policy_loss("reinforce")
def compute_reinforce_policy_loss(log_probs, old_log_probs, advantages, config, loss_mask=None, rollout_log_probs=None):
# Your custom policy loss implementation (like REINFORCE)
loss = (-log_probs * advantages).mean()
- # return loss and clip ratio
- return loss, 0.0
+ # return loss and loss metrics
+ return loss, LossMetrics(clip_ratio=0.0)
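+
+The second element of the return value is a ``LossMetrics`` dict. Every key it contains is logged by the workers
+under a ``loss_metrics/`` prefix, so custom losses can report additional statistics alongside ``clip_ratio``.
+A minimal sketch of how the return value is consumed (simplified from ``skyrl_train/workers/worker.py``; variable
+names here are illustrative):
+
+.. code-block:: python
+
+ # the registered loss function returns (loss, LossMetrics)
+ policy_loss, loss_metrics = policy_loss_fn(log_probs, old_log_probs, advantages, config, loss_mask)
+
+ status = {"policy_loss": policy_loss.item()}
+ for k, v in loss_metrics.items():
+     # e.g. "loss_metrics/clip_ratio"
+     status["loss_metrics/" + k] = v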
Registry Ray Distribution
~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/skyrl-train/docs/configuration/config.rst b/skyrl-train/docs/configuration/config.rst
index ab12315fe..740be7bb2 100644
--- a/skyrl-train/docs/configuration/config.rst
+++ b/skyrl-train/docs/configuration/config.rst
@@ -389,6 +389,45 @@ Algorithm Configuration
# dual clip parameters
clip_ratio_c: 3.0
+ # To be deprecated in favor of off_policy_correction.tis_ratio_type = "token"
+ # and "token_tis_ratio_clip_high"
+ use_tis: false
+ tis_imp_ratio_cap: -1.0
+
+ # references
+ # - https://github.com/szrlee/verl/blob/yingru/rollout_correction/docs/advance/rollout_corr_math.md
+ # - https://fengyao.notion.site/off-policy-rl
+ off_policy_correction:
+ # type of importance sampling ratio to use for ppo loss correction
+ # here importance sampling ratio refers to exp(logprobs_{policy_old} - logprobs_{rollout_policy})
+ tis_ratio_type: null # null, "token", "sequence"
+
+ # used if tis_ratio_type = "token", 1.5-5.0 is recommended for "token" tis_ratio_type
+ token_tis_ratio_clip_high: 2.0
+ # used if tis_ratio_type = "sequence", 2.0-10.0 is recommended for "sequence" tis_ratio_type
+ sequence_tis_ratio_clip_high: 5.0
+
+ # method of masking out sequences with cumulative importance sampling ratios outside the cap
+ # "product" masks out sequences with product of importance ratios outside the cap
+ # "geometric" masks out sequences with geometric mean of importance ratios outside the cap
+ sequence_mask_metric: null # null, "product", "geometric"
+
+ # used if sequence_mask_metric = "geometric"
+ # values around 0.99-1.01 are recommended for "geometric" sequence_mask_metric - MoE models may need larger allowed ranges due to higher mismatch
+ geo_mask_high: 1.01
+ geo_mask_low: 0.99
+
+ # used if sequence_mask_metric = "product"
+ # values around 0.5-2.0 are recommended for "product" sequence_mask_metric
+ product_mask_high: 2.0
+ product_mask_low: 0.5
+
+ # separate from sequence_mask_metric and tis_ratio_type
+ # if any off_policy_correction is enabled, masks out sequences with any token having importance ratio
+ # far outside an acceptable range (low and high thresholds)
+ outlier_token_is_threshold_low: 1e-4
+ outlier_token_is_threshold_high: 100
+
# clip-cov parameters (only used when policy_loss_type: "clip_cov")
clip_cov:
clip_ratio: 0.0002 # fraction of tokens to clip based on covariance
@@ -413,10 +452,6 @@ Algorithm Configuration
type: null # filter (DAPO), replace (POLARIS/WebSailor), or null
max_sample_batches: 30 # sample at most this many batches before stopping, -1 to sample forever
min_replace_ratio: 0.3 # minimum proportion of good samples with which to replace bad samples (for replace strategy only)
-
- # Truncated Importance Sampling as proposed in https://fengyao.notion.site/off-policy-rl
- use_tis: false
- tis_imp_ratio_cap: -1.0
# SAPO parameters (only used when policy_loss_type: "sapo") (https://arxiv.org/pdf/2511.20347)
sapo:
@@ -466,8 +501,8 @@ Algorithm Configuration
- ``algorithm.dynamic_sampling.type``: Type of dynamic sampling to use. Currently, we support ``filter`` (`DAPO `_), ``replace`` (`POLARIS `_ / `WebSailor `_), or ``null`` for no dynamic sampling.
- ``algorithm.dynamic_sampling.max_sample_batches``: Maximum number of batches to sample before stopping. Set to ``-1`` to sample forever.
- ``algorithm.dynamic_sampling.min_replace_ratio``: Minimum proportion of good samples with which to replace bad samples for ``replace`` strategy.
-- ``algorithm.use_tis``: Whether to use Truncated Importance Sampling (TIS) as proposed in `this blog `_.
-- ``algorithm.tis_imp_ratio_cap``: Cap parameter for the importance ratio in TIS.
+- ``algorithm.use_tis``: Whether to use Truncated Importance Sampling (TIS) as proposed in `this blog <https://fengyao.notion.site/off-policy-rl>`_. This flag will be deprecated; use ``off_policy_correction.tis_ratio_type = "token"`` instead.
+- ``algorithm.tis_imp_ratio_cap``: Cap parameter for the importance ratio in TIS. This flag will be deprecated; use ``off_policy_correction.token_tis_ratio_clip_high`` instead.
- ``algorithm.clip_cov``: Clip-Cov parameters (only used when ``policy_loss_type`` is ``clip_cov``):
- ``clip_ratio``: Fraction of tokens to clip based on covariance values.
@@ -489,6 +524,35 @@ Algorithm Configuration
- ``tau_pos``: Temperature for gating function for tokens with positive advantages.
- ``tau_neg``: Temperature for gating function for tokens with negative (or zero) advantages.
+Off Policy Correction Configuration
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+- ``algorithm.off_policy_correction``: Off-policy correction configuration. See the full configuration below.
+
+.. code-block:: yaml
+
+ off_policy_correction:
+ tis_ratio_type: null # null, "token", "sequence"
+ token_tis_ratio_clip_high: 2.0
+ sequence_tis_ratio_clip_high: 5.0
+ sequence_mask_metric: null # null, "product", "geometric"
+ geo_mask_high: 1.01
+ geo_mask_low: 0.99
+ product_mask_high: 2.0
+ product_mask_low: 0.5
+ outlier_token_is_threshold_low: 1e-4
+ outlier_token_is_threshold_high: 100
+
+- ``algorithm.off_policy_correction.tis_ratio_type``: Type of importance sampling ratio to use for PPO loss correction. Options: ``null``, ``token``, ``sequence``.
+- ``algorithm.off_policy_correction.token_tis_ratio_clip_high``: Upper clip value for the token-level importance ratio (used when ``tis_ratio_type`` is ``token``).
+- ``algorithm.off_policy_correction.sequence_tis_ratio_clip_high``: Upper clip value for the sequence-level importance ratio (used when ``tis_ratio_type`` is ``sequence``).
+- ``algorithm.off_policy_correction.sequence_mask_metric``: Metric used to mask out sequences whose cumulative importance ratio falls outside the configured bounds. Options: ``null``, ``product``, ``geometric``.
+- ``algorithm.off_policy_correction.geo_mask_high``: Upper bound on the geometric mean of token importance ratios (used when ``sequence_mask_metric`` is ``geometric``).
+- ``algorithm.off_policy_correction.geo_mask_low``: Lower bound on the geometric mean of token importance ratios (used when ``sequence_mask_metric`` is ``geometric``).
+- ``algorithm.off_policy_correction.product_mask_high``: Upper bound on the product of token importance ratios (used when ``sequence_mask_metric`` is ``product``).
+- ``algorithm.off_policy_correction.product_mask_low``: Lower bound on the product of token importance ratios (used when ``sequence_mask_metric`` is ``product``).
+- ``algorithm.off_policy_correction.outlier_token_is_threshold_low``: Lower bound for the outlier token mask, which is applied whenever any off-policy correction is enabled; a sequence is masked out if any of its tokens has an importance ratio below this threshold.
+- ``algorithm.off_policy_correction.outlier_token_is_threshold_high``: Upper bound for the outlier token mask; a sequence is masked out if any of its tokens has an importance ratio above this threshold.
+
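+All of these options operate on the token-level importance ratio ``exp(old_log_probs - rollout_log_probs)`` between
+the old policy and the rollout policy. A minimal sketch of how the main options combine, simplified from
+``compute_off_policy_correction`` in ``skyrl_train/utils/ppo_utils.py`` (inputs are ``[batch, seq_len]`` tensors, and
+``opc`` is shorthand for ``trainer.algorithm.off_policy_correction``):
+
+.. code-block:: python
+
+ log_ratio = old_log_probs - rollout_log_probs      # token-level log importance ratio
+ ratio = log_ratio.exp()
+
+ # outlier token mask: drop whole sequences that contain any token with an extreme ratio
+ in_bounds = (ratio >= opc.outlier_token_is_threshold_low) & (ratio <= opc.outlier_token_is_threshold_high)
+ loss_mask = loss_mask * (in_bounds | (loss_mask == 0)).all(dim=-1, keepdim=True).float()
+
+ # tis_ratio_type = "token": clamp the per-token ratio and multiply it into the policy loss
+ tis_ratio = ratio.clamp(max=opc.token_tis_ratio_clip_high)
+
+ # sequence_mask_metric = "geometric": mask sequences whose geometric-mean ratio drifts too far from 1
+ num_tokens = loss_mask.sum(dim=-1, keepdim=True).clamp(min=1.0)
+ geo_mean = ((log_ratio * loss_mask).sum(dim=-1, keepdim=True) / num_tokens).exp()
+ loss_mask = loss_mask * ((geo_mean >= opc.geo_mask_low) & (geo_mean <= opc.geo_mask_high)).float()
+
+ # the "sequence" TIS ratio and the "product" mask are built analogously from the summed log ratio
+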
Policy Loss Formulation
~~~~~~~~~~~~~~~~~~~~~~~
@@ -502,7 +566,7 @@ It can be helpful to understand the final loss formulation to see how the differ
advantages: torch.Tensor,
config: DictConfig, # trainer.algorithm config
loss_mask: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
+ ) -> Tuple[torch.Tensor, LossMetrics]:
ratio = (log_probs - old_log_probs).exp()
surr1 = ratio * advantages
@@ -515,7 +579,7 @@ It can be helpful to understand the final loss formulation to see how the differ
clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
loss = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
loss = reduce_loss(loss, loss_mask, config.loss_reduction)
- return loss, clip_ratio
+ return loss, LossMetrics(clip_ratio=clip_ratio)
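+
+When rollout log-probabilities are available (``generator.sampling_params.logprobs`` is set) and the loss function
+receives them as ``rollout_log_probs``, the built-in losses (except ``clip_cov`` and ``kl_cov``) also apply the
+off-policy correction described above before reducing the loss. A simplified sketch of that step (the actual code
+lives in ``ppo_policy_loss`` in ``skyrl_train/utils/ppo_utils.py``):
+
+.. code-block:: python
+
+ loss_metrics = LossMetrics(clip_ratio=clip_ratio)
+ if rollout_log_probs is not None:
+     tis_ratio, opc_metrics, loss_mask = compute_off_policy_correction(
+         old_log_probs, rollout_log_probs, loss_mask, config.off_policy_correction
+     )
+     if tis_ratio is not None:        # tis_ratio_type is "token" or "sequence"
+         loss = loss * tis_ratio      # truncated importance weighting
+     loss_metrics.update(opc_metrics) # clipping/masking statistics for logging
+ loss = reduce_loss(loss, loss_mask, config.loss_reduction)
+ return loss, loss_metrics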
Generator Configuration
diff --git a/skyrl-train/docs/examples/flash_rl.rst b/skyrl-train/docs/examples/flash_rl.rst
index a34c18ef3..210859f5f 100644
--- a/skyrl-train/docs/examples/flash_rl.rst
+++ b/skyrl-train/docs/examples/flash_rl.rst
@@ -60,13 +60,13 @@ We highlight some important training parameters configured for FlashRL from our
DATA_DIR="$HOME/data/dapo"
# TIS parameters
- USE_TIS=true
+ TIS_TYPE=token
TIS_IMP_RATIO_CAP=8.0
uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.int8 -m examples.flash_rl.main_dapo_flashrl \
...
- trainer.algorithm.use_tis=$USE_TIS \
- trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \
+ trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \
+ trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
generator.sampling_params.logprobs=0 \
...
diff --git a/skyrl-train/examples/algorithms/custom_policy_loss/main_custom_policy_loss.py b/skyrl-train/examples/algorithms/custom_policy_loss/main_custom_policy_loss.py
index 6b9be8b13..b4a172034 100644
--- a/skyrl-train/examples/algorithms/custom_policy_loss/main_custom_policy_loss.py
+++ b/skyrl-train/examples/algorithms/custom_policy_loss/main_custom_policy_loss.py
@@ -9,7 +9,7 @@
from omegaconf import DictConfig
from skyrl_train.utils import initialize_ray
from skyrl_train.entrypoints.main_base import BasePPOExp, config_dir, validate_cfg
-from skyrl_train.utils.ppo_utils import PolicyLossRegistry
+from skyrl_train.utils.ppo_utils import PolicyLossRegistry, LossMetrics
# Example of custom policy loss: "reinforce"
@@ -27,7 +27,7 @@ def compute_reinforce_policy_loss(
loss = (-log_probs * advantages).mean()
# Return loss and dummy clip_ratio (no clipping in REINFORCE)
- return loss, 0.0
+ return loss, LossMetrics(clip_ratio=0.0)
# Register the custom policy loss
diff --git a/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_fp8.sh b/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_fp8.sh
index d0aff21ab..9d16cc709 100644
--- a/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_fp8.sh
+++ b/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_fp8.sh
@@ -33,6 +33,8 @@ MAX_RESPONSE_LENGTH=1024
CKPT_PATH="$HOME/ckpts/gsm8k_0.5B_ckpt"
+TIS_TYPE=token
+TIS_IMP_RATIO_CAP=2.0
uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.fp8 --with vllm@https://github.com/NovaSky-AI/SkyRL/releases/download/skyrl_train-v0.1.0/vllm-0.1.dev7509+gcc487699a.d20250821-cp312-cp312-linux_x86_64.whl --with transformers==4.53.3 -- python -m examples.flash_rl.main_dapo_flashrl \
data.train_data="['$DATA_DIR/train.parquet']" \
@@ -53,8 +55,8 @@ uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.fp8 --with v
generator.eval_sampling_params.top_p=$EVAL_TOP_P \
trainer.algorithm.use_kl_loss=$USE_KL_LOSS \
trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \
- trainer.algorithm.use_tis=true \
- trainer.algorithm.tis_imp_ratio_cap=2.0 \
+ trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \
+ trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
trainer.policy.model.path="Qwen/Qwen2.5-0.5B-Instruct" \
trainer.placement.colocate_all=true \
trainer.strategy=fsdp2 \
diff --git a/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_int8.sh b/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_int8.sh
index f547788f5..5c31645bc 100644
--- a/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_int8.sh
+++ b/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_int8.sh
@@ -33,6 +33,9 @@ MAX_RESPONSE_LENGTH=1024
CKPT_PATH="$HOME/ckpts/gsm8k_0.5B_ckpt"
+TIS_TYPE=token
+TIS_IMP_RATIO_CAP=2.0
+
uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.int8 --with vllm@https://github.com/NovaSky-AI/SkyRL/releases/download/skyrl_train-v0.1.0/vllm-0.1.dev7509+gcc487699a.d20250821-cp312-cp312-linux_x86_64.whl --with transformers==4.53.3 -- python -m examples.flash_rl.main_dapo_flashrl \
data.train_data="['$DATA_DIR/train.parquet']" \
data.val_data="['$DATA_DIR/validation.parquet']" \
@@ -52,8 +55,8 @@ uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.int8 --with
trainer.algorithm.dynamic_sampling.max_sample_batches=$DYNAMIC_SAMPLING_MAX_SAMPLE_BATCHES \
trainer.algorithm.use_kl_loss=$USE_KL_LOSS \
trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \
- trainer.algorithm.use_tis=true \
- trainer.algorithm.tis_imp_ratio_cap=2.0 \
+ trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \
+ trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
trainer.policy.model.path="Qwen/Qwen2.5-0.5B-Instruct" \
trainer.placement.colocate_all=true \
trainer.strategy=fsdp2 \
diff --git a/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_32b_int8.sh b/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_32b_int8.sh
index 7c6c7fb4a..383f87ba8 100644
--- a/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_32b_int8.sh
+++ b/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_32b_int8.sh
@@ -33,6 +33,9 @@ MAX_RESPONSE_LENGTH=1024
CKPT_PATH="$HOME/ckpts/gsm8k_32B_ckpt"
+TIS_TYPE=token
+TIS_IMP_RATIO_CAP=2.0
+
uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.int8 --with vllm@https://github.com/NovaSky-AI/SkyRL/releases/download/skyrl_train-v0.1.0/vllm-0.1.dev7509+gcc487699a.d20250821-cp312-cp312-linux_x86_64.whl --with transformers==4.53.3 -m examples.flash_rl.main_dapo_flashrl \
data.train_data="['$DATA_DIR/train.parquet']" \
data.val_data="['$DATA_DIR/validation.parquet']" \
@@ -52,8 +55,8 @@ uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.int8 --with
generator.eval_sampling_params.top_p=$EVAL_TOP_P \
trainer.algorithm.use_kl_loss=$USE_KL_LOSS \
trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \
- trainer.algorithm.use_tis=true \
- trainer.algorithm.tis_imp_ratio_cap=2.0 \
+ trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \
+ trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
trainer.policy.model.path="Qwen/Qwen2.5-32B" \
trainer.placement.colocate_all=true \
trainer.strategy=fsdp2 \
diff --git a/skyrl-train/examples/flash_rl/run_dapo_repro_flashrl_0.5b_int8.sh b/skyrl-train/examples/flash_rl/run_dapo_repro_flashrl_0.5b_int8.sh
index 97ca62fde..70db7e13f 100644
--- a/skyrl-train/examples/flash_rl/run_dapo_repro_flashrl_0.5b_int8.sh
+++ b/skyrl-train/examples/flash_rl/run_dapo_repro_flashrl_0.5b_int8.sh
@@ -31,6 +31,8 @@ EVAL_TOP_P=0.7
CLIP_RATIO_C=10.0
MAX_RESPONSE_LENGTH=4096
+TIS_TYPE=token
+TIS_IMP_RATIO_CAP=2.0
uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.0.5b_int8 --with vllm@https://github.com/NovaSky-AI/SkyRL/releases/download/skyrl_train-v0.1.0/vllm-0.1.dev7509+gcc487699a.d20250821-cp312-cp312-linux_x86_64.whl --with transformers==4.53.3 -- python -m examples.flash_rl.main_dapo_flashrl \
data.train_data="['$DATA_DIR/dapo-math-17k-cleaned.parquet']" \
data.val_data="['$DATA_DIR/aime-2024-cleaned.parquet']" \
@@ -50,8 +52,8 @@ uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.0.5b_int8 --
trainer.algorithm.dynamic_sampling.max_sample_batches=$DYNAMIC_SAMPLING_MAX_SAMPLE_BATCHES \
trainer.algorithm.use_kl_loss=$USE_KL_LOSS \
trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \
- trainer.algorithm.use_tis=true \
- trainer.algorithm.tis_imp_ratio_cap=2.0 \
+ trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \
+ trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
trainer.policy.model.path="Qwen/Qwen2.5-0.5B-Instruct" \
trainer.placement.colocate_all=true \
trainer.strategy=fsdp2 \
diff --git a/skyrl-train/examples/flash_rl/run_dapo_repro_flashrl_32b_int8.sh b/skyrl-train/examples/flash_rl/run_dapo_repro_flashrl_32b_int8.sh
index bf4a033f4..83842dfc4 100644
--- a/skyrl-train/examples/flash_rl/run_dapo_repro_flashrl_32b_int8.sh
+++ b/skyrl-train/examples/flash_rl/run_dapo_repro_flashrl_32b_int8.sh
@@ -33,7 +33,7 @@ MAX_RESPONSE_LENGTH=20480
MAX_PROMPT_LENGTH=2048
TIS_IMP_RATIO_CAP=8.0
-USE_TIS=true
+TIS_TYPE=token
LOGPROBS=0
CKPT_PATH="$HOME/ckpts/dapo_32b_ckpt"
@@ -57,8 +57,8 @@ uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.int8 --with
generator.eval_sampling_params.top_p=$EVAL_TOP_P \
trainer.algorithm.use_kl_loss=$USE_KL_LOSS \
trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \
- trainer.algorithm.use_tis=$USE_TIS \
- trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \
+ trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \
+ trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
trainer.policy.model.path="Qwen/Qwen2.5-32B" \
trainer.placement.colocate_all=true \
trainer.strategy=fsdp2 \
diff --git a/skyrl-train/examples/fully_async/async_run_gsm8k.sh b/skyrl-train/examples/fully_async/async_run_gsm8k.sh
index 5a4404b3b..2b2412875 100644
--- a/skyrl-train/examples/fully_async/async_run_gsm8k.sh
+++ b/skyrl-train/examples/fully_async/async_run_gsm8k.sh
@@ -28,10 +28,10 @@ set -x
: "${MAX_STALENESS_STEPS:=4}"
: "${NUM_PARALLEL_GENERATION_WORKERS:=$(( MINI_BATCH_SIZE * (MAX_STALENESS_STEPS + 1) ))}"
-USE_TIS=true
+TIS_TYPE=token
TIS_IMP_RATIO_CAP=2.0
-RUN_NAME=gsm8k-async-qwen2.5_1.5B-useTIS_${USE_TIS}-maxStale${MAX_STALENESS_STEPS}-numCon${NUM_PARALLEL_GENERATION_WORKERS}-${NUM_POLICY_GPUS}train${NUM_INFERENCE_GPUS}gen
+RUN_NAME=gsm8k-async-qwen2.5_1.5B-TIS_TYPE_${TIS_TYPE}-maxStale${MAX_STALENESS_STEPS}-numCon${NUM_PARALLEL_GENERATION_WORKERS}-${NUM_POLICY_GPUS}train${NUM_INFERENCE_GPUS}gen
uv run --isolated --extra $INFERENCE_BACKEND -m examples.fully_async.main_async \
data.train_data="['$DATA_DIR/train.parquet']" \
@@ -39,8 +39,8 @@ uv run --isolated --extra $INFERENCE_BACKEND -m examples.fully_async.main_async
trainer.fully_async.max_staleness_steps=${MAX_STALENESS_STEPS} \
trainer.fully_async.num_parallel_generation_workers=${NUM_PARALLEL_GENERATION_WORKERS} \
trainer.algorithm.advantage_estimator="grpo" \
- trainer.algorithm.use_tis=$USE_TIS \
- trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \
+ trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \
+ trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
trainer.policy.model.path="Qwen/Qwen2.5-1.5B-Instruct" \
trainer.placement.colocate_all=false \
trainer.strategy=fsdp2 \
diff --git a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_30b_a3b.sh b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_30b_a3b.sh
index d2af35b5b..3c10e4f18 100644
--- a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_30b_a3b.sh
+++ b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_30b_a3b.sh
@@ -53,7 +53,7 @@ MEGATRON_ETP=1
# TIS parameters
TIS_IMP_RATIO_CAP=2.0
-USE_TIS=true
+TIS_TYPE=token
uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \
data.train_data="['$TRAIN_FILE']" \
@@ -83,8 +83,8 @@ uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \
trainer.policy.megatron_config.context_parallel_size=$MEGATRON_CP \
trainer.policy.megatron_config.expert_model_parallel_size=$MEGATRON_EP \
trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \
- trainer.algorithm.use_tis=$USE_TIS \
- trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \
+ trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \
+ trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
trainer.epochs=20 \
trainer.algorithm.eps_clip_low=$CLIP_RATIO_LOW \
trainer.algorithm.eps_clip_high=$CLIP_RATIO_HIGH \
diff --git a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_30b_a3b_lora.sh b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_30b_a3b_lora.sh
index 7ae14bc4d..d9d9c90e3 100644
--- a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_30b_a3b_lora.sh
+++ b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_30b_a3b_lora.sh
@@ -55,8 +55,13 @@ LORA_RANK=32
LORA_ALPHA=64
# TIS parameters
-TIS_IMP_RATIO_CAP=2.0
-USE_TIS=true
+TIS_IMP_RATIO_CAP=3.0
+
+# rollout correction parameters
+TIS_RATIO_TYPE="sequence"
+sequence_mask_metric="geometric"
+geo_mask_high=1.05
+geo_mask_low=0.95
uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \
data.train_data="['$TRAIN_FILE']" \
@@ -88,8 +93,11 @@ uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \
trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \
trainer.policy.model.lora.rank=$LORA_RANK \
trainer.policy.model.lora.alpha=$LORA_ALPHA \
- trainer.algorithm.use_tis=$USE_TIS \
- trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \
+ trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_RATIO_TYPE \
+ trainer.algorithm.off_policy_correction.sequence_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
+ trainer.algorithm.off_policy_correction.sequence_mask_metric=$sequence_mask_metric \
+ trainer.algorithm.off_policy_correction.geo_mask_high=$geo_mask_high \
+ trainer.algorithm.off_policy_correction.geo_mask_low=$geo_mask_low \
trainer.epochs=20 \
trainer.algorithm.eps_clip_low=$CLIP_RATIO_LOW \
trainer.algorithm.eps_clip_high=$CLIP_RATIO_HIGH \
diff --git a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_4b.sh b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_4b.sh
index 014ee567f..86a6a7ec6 100644
--- a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_4b.sh
+++ b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_4b.sh
@@ -50,7 +50,7 @@ MEGATRON_ETP=null
# TIS parameters
TIS_IMP_RATIO_CAP=2.0
-USE_TIS=true
+TIS_TYPE=token
uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \
data.train_data="['$TRAIN_FILE']" \
@@ -80,8 +80,8 @@ uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \
trainer.policy.megatron_config.context_parallel_size=$MEGATRON_CP \
trainer.policy.megatron_config.expert_model_parallel_size=$MEGATRON_EP \
trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \
- trainer.algorithm.use_tis=$USE_TIS \
- trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \
+ trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \
+ trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
trainer.epochs=20 \
trainer.algorithm.eps_clip_low=$CLIP_RATIO_LOW \
trainer.algorithm.eps_clip_high=$CLIP_RATIO_HIGH \
diff --git a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_4b_lora.sh b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_4b_lora.sh
index 25dc8e6d3..801c69ab0 100644
--- a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_4b_lora.sh
+++ b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_4b_lora.sh
@@ -55,7 +55,7 @@ LORA_A_INIT_METHOD="kaiming"
# TIS parameters
TIS_IMP_RATIO_CAP=2.0
-USE_TIS=true
+TIS_TYPE=token
uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \
data.train_data="['$TRAIN_FILE']" \
@@ -85,8 +85,8 @@ uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \
trainer.policy.megatron_config.context_parallel_size=$MEGATRON_CP \
trainer.policy.megatron_config.expert_model_parallel_size=$MEGATRON_EP \
trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \
- trainer.algorithm.use_tis=$USE_TIS \
- trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \
+ trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \
+ trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
trainer.policy.model.lora.rank=$LORA_RANK \
trainer.policy.model.lora.alpha=$LORA_ALPHA \
trainer.policy.model.lora.init_method=$LORA_A_INIT_METHOD \
diff --git a/skyrl-train/examples/megatron/run_megatron_lora_qwen3-30b-a3b.sh b/skyrl-train/examples/megatron/run_megatron_lora_qwen3-30b-a3b.sh
index 6c2f3a899..efe6eaf4c 100644
--- a/skyrl-train/examples/megatron/run_megatron_lora_qwen3-30b-a3b.sh
+++ b/skyrl-train/examples/megatron/run_megatron_lora_qwen3-30b-a3b.sh
@@ -38,7 +38,7 @@ LORA_A_INIT_METHOD="kaiming"
# TIS parameters
TIS_IMP_RATIO_CAP=2.0
-USE_TIS=true
+TIS_TYPE=token
uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_base \
data.train_data="['$DATA_DIR/train.parquet']" \
@@ -63,8 +63,8 @@ uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_base \
trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \
trainer.ref.megatron_config.expert_model_parallel_size=$MEGATRON_EP \
trainer.ref.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \
- trainer.algorithm.use_tis=$USE_TIS \
- trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \
+ trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \
+ trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
trainer.policy.model.lora.rank=$LORA_RANK \
trainer.policy.model.lora.alpha=$LORA_ALPHA \
trainer.policy.model.lora.init_method=$LORA_A_INIT_METHOD \
diff --git a/skyrl-train/examples/on_policy_distillation/main_on_policy_distill.py b/skyrl-train/examples/on_policy_distillation/main_on_policy_distill.py
index c7030cbcc..4afa470af 100644
--- a/skyrl-train/examples/on_policy_distillation/main_on_policy_distill.py
+++ b/skyrl-train/examples/on_policy_distillation/main_on_policy_distill.py
@@ -10,6 +10,7 @@
register_advantage_estimator,
register_policy_loss,
reduce_loss,
+ LossMetrics,
)
from skyrl_train.training_batch import TrainingInputBatch
@@ -53,7 +54,7 @@ def compute_importance_sampling_policy_loss(
loss = reduce_loss(loss, loss_mask, "seq_mean_token_sum_norm", config.max_seq_len)
# return loss and a dummy clip ratio value as we aren't clipping here
- return loss, 0.0
+ return loss, LossMetrics(clip_ratio=0.0)
class OnPolicyDistillationExp(BasePPOExp):
diff --git a/skyrl-train/examples/search/run_search.sh b/skyrl-train/examples/search/run_search.sh
index 1703e455e..7166f6283 100755
--- a/skyrl-train/examples/search/run_search.sh
+++ b/skyrl-train/examples/search/run_search.sh
@@ -11,7 +11,7 @@ DATA_DIR="$HOME/data/searchR1"
RUN_NAME="skyrl-search_4turns_maxgeneratelen_500-multiturn-sync-TIS_2.0"
-USE_TIS=true
+TIS_TYPE=token
TIS_IMP_RATIO_CAP=2.0
uv run --isolated --frozen --extra vllm -m skyrl_train.entrypoints.main_base \
@@ -23,8 +23,8 @@ uv run --isolated --frozen --extra vllm -m skyrl_train.entrypoints.main_base \
trainer.policy.optimizer_config.num_warmup_steps=94 \
trainer.algorithm.use_kl_loss=true \
trainer.algorithm.kl_loss_coef=0.001 \
- trainer.algorithm.use_tis=$USE_TIS \
- trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \
+ trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \
+ trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
trainer.policy.model.path="Qwen/Qwen2.5-3B-Instruct" \
trainer.placement.colocate_all=true \
trainer.strategy=fsdp2 \
diff --git a/skyrl-train/examples/search/run_search_conversation_format.sh b/skyrl-train/examples/search/run_search_conversation_format.sh
index 9346efc13..285bafd29 100755
--- a/skyrl-train/examples/search/run_search_conversation_format.sh
+++ b/skyrl-train/examples/search/run_search_conversation_format.sh
@@ -18,7 +18,7 @@ DATA_DIR="$HOME/data/searchR1"
RUN_NAME="skyrl-search_4turns_maxgeneratelen_500"
-USE_TIS=true
+TIS_TYPE=token
TIS_IMP_RATIO_CAP=2.0
uv run --isolated --frozen --extra vllm -m skyrl_train.entrypoints.main_base \
@@ -30,8 +30,8 @@ uv run --isolated --frozen --extra vllm -m skyrl_train.entrypoints.main_base \
trainer.policy.optimizer_config.num_warmup_steps=94 \
trainer.algorithm.use_kl_loss=true \
trainer.algorithm.kl_loss_coef=0.001 \
- trainer.algorithm.use_tis=$USE_TIS \
- trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \
+ trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \
+ trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
trainer.policy.model.path="Qwen/Qwen2.5-3B-Instruct" \
trainer.placement.colocate_all=true \
trainer.strategy=fsdp2 \
diff --git a/skyrl-train/examples/text_to_sql/run_skyrl_sql_megatron_lora.sh b/skyrl-train/examples/text_to_sql/run_skyrl_sql_megatron_lora.sh
index f33cb7e65..47b87d85c 100644
--- a/skyrl-train/examples/text_to_sql/run_skyrl_sql_megatron_lora.sh
+++ b/skyrl-train/examples/text_to_sql/run_skyrl_sql_megatron_lora.sh
@@ -32,7 +32,7 @@ MEGATRON_ETP=null
# TIS parameters
TIS_IMP_RATIO_CAP=2.0
-USE_TIS=true
+TIS_TYPE=token
uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_base \
trainer.algorithm.advantage_estimator="grpo" \
@@ -63,8 +63,8 @@ uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_base \
trainer.policy.optimizer_config.lr=3.0e-5 \
trainer.policy_mini_batch_size=256 \
trainer.algorithm.use_kl_loss=false \
- trainer.algorithm.use_tis=$USE_TIS \
- trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \
+ trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \
+ trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
trainer.ckpt_interval=60 \
trainer.hf_save_interval=30 \
trainer.dump_data_batch=true \
diff --git a/skyrl-train/examples/tis_correction/run_dapo_tis.sh b/skyrl-train/examples/tis_correction/run_dapo_tis.sh
index a0995dd3c..674d0aaa3 100644
--- a/skyrl-train/examples/tis_correction/run_dapo_tis.sh
+++ b/skyrl-train/examples/tis_correction/run_dapo_tis.sh
@@ -12,7 +12,7 @@ LOGGER="wandb" # change to "console" to print to stdout
# TIS parameters
TIS_IMP_RATIO_CAP=2.0
-USE_TIS=true
+TIS_TYPE=token
# returns rollout logprobs for the generated tokens; required for TIS
LOGPROBS=0
@@ -55,8 +55,8 @@ uv run --isolated --extra vllm -m examples.tis_correction.main_tis_dapo \
generator.eval_sampling_params.top_p=$EVAL_TOP_P \
trainer.algorithm.use_kl_loss=$USE_KL_LOSS \
trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \
- trainer.algorithm.use_tis=$USE_TIS \
- trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \
+ trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \
+ trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
trainer.policy.model.path="Qwen/Qwen2.5-1.5B-Instruct" \
trainer.placement.colocate_all=true \
trainer.strategy=fsdp2 \
diff --git a/skyrl-train/skyrl_train/config/ppo_base_config.yaml b/skyrl-train/skyrl_train/config/ppo_base_config.yaml
index 1f8c1b43b..0e8b6385c 100644
--- a/skyrl-train/skyrl_train/config/ppo_base_config.yaml
+++ b/skyrl-train/skyrl_train/config/ppo_base_config.yaml
@@ -120,9 +120,46 @@ trainer:
eps_clip_high: 0.2
# dual clip parameters
clip_ratio_c: 3.0
- # Truncated Importance Sampling as proposed in https://fengyao.notion.site/off-policy-rl
+
+ # To be deprecated in favor of off_policy_correction.tis_ratio_type = "token"
+ # and "token_tis_ratio_clip_high"
tis_imp_ratio_cap: -1.0
use_tis: false
+
+ # references
+ # - https://github.com/szrlee/verl/blob/yingru/rollout_correction/docs/advance/rollout_corr_math.md
+ # - https://fengyao.notion.site/off-policy-rl
+ off_policy_correction:
+ # type of importance sampling ratio to use for ppo loss correction
+ # here importance sampling ratio refers to exp(logprobs_{policy_old} - logprobs_{rollout_policy})
+ tis_ratio_type: null # null, "token", "sequence"
+
+ # used if tis_ratio_type = "token", 1.5-5.0 is recommended for "token" tis_ratio_type
+ token_tis_ratio_clip_high: 2.0
+ # used if tis_ratio_type = "sequence", 2.0-10.0 is recommended for "sequence" tis_ratio_type
+ sequence_tis_ratio_clip_high: 5.0
+
+ # method of masking out sequences with cumulative importance sampling ratios outside the cap
+ # "product" masks out sequences with product of importance ratios outside the cap
+ # "geometric" masks out sequences with geometric mean of importance ratios outside the cap
+ sequence_mask_metric: null # null, "product", "geometric"
+
+ # used if sequence_mask_metric = "geometric"
+ # values around 0.99-1.01 are recommended for "geometric" sequence_mask_metric - MoE models may need larger allowed ranges due to higher mismatch
+ geo_mask_high: 1.01
+ geo_mask_low: 0.99
+
+ # used if sequence_mask_metric = "product"
+ # values around 0.5-2.0 are recommended for "product" sequence_mask_metric
+ product_mask_high: 2.0
+ product_mask_low: 0.5
+
+ # separate from sequence_mask_metric and tis_ratio_type
+ # if any off_policy_correction is enabled, masks out sequences with any token having importance ratio
+ # far outside an acceptable range (low and high thresholds)
+ outlier_token_is_threshold_low: 1e-4
+ outlier_token_is_threshold_high: 100
+
# SAPO parameters (only used when policy_loss_type: "sapo") (https://arxiv.org/pdf/2511.20347)
sapo:
tau_pos: 1.0
diff --git a/skyrl-train/skyrl_train/distributed/strategy.py b/skyrl-train/skyrl_train/distributed/strategy.py
index acceccb45..fb8269e97 100644
--- a/skyrl-train/skyrl_train/distributed/strategy.py
+++ b/skyrl-train/skyrl_train/distributed/strategy.py
@@ -69,12 +69,9 @@ def get_rank(self) -> int:
def all_reduce(self, data: DataT, op="mean") -> DataT:
"""Perform all_reduce across all processes"""
- assert op in ("mean", "max", "sum")
+ assert op in ("mean", "max", "sum", "min")
if isinstance(data, dict):
- ret = {}
- for k, v in data.items():
- ret[k] = self.all_reduce(v, op)
- return ret
+ return {k: self.all_reduce(v, op) for k, v in data.items()}
else:
is_tensor = True
if not isinstance(data, torch.Tensor):
@@ -86,7 +83,13 @@ def all_reduce(self, data: DataT, op="mean") -> DataT:
data = data.to(torch.cuda.current_device())
if op == "mean":
data /= self.world_size
- dist.all_reduce(data, op=dist.ReduceOp.MAX if op == "max" else dist.ReduceOp.SUM)
+ dist.all_reduce(data, op=dist.ReduceOp.SUM)
+ elif op == "max":
+ dist.all_reduce(data, op=dist.ReduceOp.MAX)
+ elif op == "min":
+ dist.all_reduce(data, op=dist.ReduceOp.MIN)
+ elif op == "sum":
+ dist.all_reduce(data, op=dist.ReduceOp.SUM)
if is_cpu_tensor:
data = data.cpu()
return data.item() if not is_tensor else data
diff --git a/skyrl-train/skyrl_train/trainer.py b/skyrl-train/skyrl_train/trainer.py
index faa578f65..9b27112c4 100644
--- a/skyrl-train/skyrl_train/trainer.py
+++ b/skyrl-train/skyrl_train/trainer.py
@@ -579,12 +579,17 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis
loss_masks,
logprobs,
)
- # sanity check for tis
- if self.cfg.trainer.algorithm.use_tis:
+
+ # sanity check for off_policy_correction
+ off_policy_correction = self.cfg.trainer.algorithm.off_policy_correction
+ tis_ratio_type = off_policy_correction.tis_ratio_type
+ sequence_mask_metric = off_policy_correction.sequence_mask_metric
+ if tis_ratio_type is not None or sequence_mask_metric is not None:
assert (
rollout_logprobs_tensor is not None
- ), "expected non-null rollout logprobs tensor with `trainer.algorithm.use_tis` as `True`"
+ ), "expected non-null rollout logprobs tensor when off_policy_correction is enabled"
assert rollout_logprobs_tensor.shape == loss_masks_tensor.shape, "Logprobs should look like responses"
+
training_input = TrainingInputBatch(
{
"sequences": sequences_tensor, # Full trajectories (padded and concatenated prompts and responses)
diff --git a/skyrl-train/skyrl_train/utils/ppo_utils.py b/skyrl-train/skyrl_train/utils/ppo_utils.py
index d814aedd7..c192e8e4e 100644
--- a/skyrl-train/skyrl_train/utils/ppo_utils.py
+++ b/skyrl-train/skyrl_train/utils/ppo_utils.py
@@ -19,7 +19,7 @@
from collections import defaultdict
from enum import StrEnum
from functools import wraps
-from typing import Callable, List, Literal, Optional, Tuple, Union
+from typing import Callable, List, Literal, Optional, Tuple, Union, TypedDict, NotRequired
import numpy as np
import ray
@@ -196,6 +196,10 @@ def ppo_critic_loss(
return 0.5 * loss, clipfrac
+class LossMetrics(TypedDict, total=False):
+ clip_ratio: NotRequired[float]
+
+
# Shared registry actor class for both policy loss and advantage estimator registries
@ray.remote
class RegistryActor:
@@ -552,6 +556,252 @@ def _safe_exp_delta(delta: torch.Tensor, clip: float = 20.0, out_dtype=None) ->
return y.to(out_dtype or delta.dtype)
+def compute_tis_ratio(
+ old_log_probs: torch.Tensor,
+ rollout_logprobs: torch.Tensor,
+ loss_mask: torch.Tensor,
+ tis_ratio_type: str,
+ off_policy_correction: DictConfig,
+) -> Tuple[torch.Tensor, dict]:
+ """
+ Compute truncated importance sampling (TIS) ratio for off policy correction.
+
+ Args:
+ old_log_probs: Log probabilities from the old policy (before update).
+ rollout_logprobs: Log probabilities from the rollout policy.
+ loss_mask: Mask indicating valid tokens.
+ tis_ratio_type: Type of TIS ratio ("token" or "sequence").
+ off_policy_correction: Off-policy correction config containing cap values.
+
+ Returns:
+ Tuple of (tis_ratio, metrics):
+ - tis_ratio: Tensor (float) to multiply with the loss
+ - metrics: Dict with masking statistics
+
+ Reference: https://github.com/szrlee/verl/blob/yingru/rollout_correction/docs/advance/rollout_corr_math.md
+ """
+ # Compute token-level importance ratio: pi_old / pi_rollout
+ # In log space: old_log_probs - rollout_logprobs
+ token_tis_log_ratio = old_log_probs - rollout_logprobs
+ token_tis_ratio = _safe_exp_delta(token_tis_log_ratio, clip=20.0, out_dtype=old_log_probs.dtype)
+
+ metrics = {}
+ if tis_ratio_type == "token":
+ token_tis_ratio_cap = off_policy_correction.token_tis_ratio_clip_high
+ # Compute proportion of tokens capped
+ tokens_capped = (token_tis_ratio > token_tis_ratio_cap) & (loss_mask > 0)
+ total_tokens = (loss_mask > 0).sum()
+ metrics["tis_token_clip_high_ratio"] = (tokens_capped.sum() / total_tokens.clamp(min=1)).detach().item()
+ return torch.clamp(token_tis_ratio, max=token_tis_ratio_cap).detach(), metrics
+ elif tis_ratio_type == "sequence":
+ # Compute sequence-level importance ratio as product of token ratios (sum of log ratios)
+ seq_tis_log_ratio = (token_tis_log_ratio * loss_mask).sum(dim=-1, keepdim=True)
+ seq_tis_ratio = _safe_exp_delta(seq_tis_log_ratio, clip=20.0, out_dtype=old_log_probs.dtype)
+ seq_tis_ratio_cap = off_policy_correction.sequence_tis_ratio_clip_high
+ # Compute proportion of sequences capped
+ num_sequences = seq_tis_ratio.shape[0]
+ seqs_capped = (seq_tis_ratio > seq_tis_ratio_cap).sum()
+ metrics["tis_seq_clip_high_ratio"] = (seqs_capped / num_sequences).detach().item()
+ return torch.clamp(seq_tis_ratio, max=seq_tis_ratio_cap).detach(), metrics
+ else:
+ raise ValueError(f"Unknown tis_ratio_type: {tis_ratio_type}")
+
+
+def compute_outlier_token_mask(
+ old_log_probs: torch.Tensor,
+ rollout_logprobs: torch.Tensor,
+ loss_mask: torch.Tensor,
+ off_policy_correction: DictConfig,
+) -> Tuple[torch.Tensor, dict]:
+ """
+ Compute outlier token mask that masks out sequences with any token having
+ importance ratio outside acceptable bounds.
+
+ This is applied independently of TIS ratio type or sequence mask type,
+ whenever off policy correction is enabled.
+
+ Args:
+ old_log_probs: Log probabilities from the old policy (before update).
+ rollout_logprobs: Log probabilities from the rollout policy.
+ loss_mask: Mask indicating valid tokens.
+ off_policy_correction: Off-policy correction config containing threshold values.
+
+ Returns:
+ Tuple of (outlier_mask, metrics):
+ - outlier_mask: Tensor (bool) to mask out sequences with any token having importance ratio outside acceptable bounds
+ - metrics: Dict with masking statistics
+ """
+ metrics = {}
+ # Compute token-level importance ratio
+ token_tis_log_ratio = old_log_probs - rollout_logprobs
+ token_tis_ratio = _safe_exp_delta(token_tis_log_ratio, clip=20.0, out_dtype=old_log_probs.dtype)
+
+ # Check per-token bounds
+ token_mask_low = off_policy_correction.outlier_token_is_threshold_low
+ token_mask_high = off_policy_correction.outlier_token_is_threshold_high
+ token_over_high = (token_tis_ratio > token_mask_high) & (loss_mask > 0)
+ token_under_low = (token_tis_ratio < token_mask_low) & (loss_mask > 0)
+ token_in_bounds = ~token_over_high & ~token_under_low
+
+ # A sequence is valid if all tokens are in bounds (considering only masked positions)
+ all_tokens_valid = (token_in_bounds | (loss_mask == 0)).all(dim=-1, keepdim=True)
+
+ # Compute metrics
+ num_sequences = float(all_tokens_valid.shape[0])
+ # Sequence has any token over high threshold
+ seq_has_over_high = token_over_high.any(dim=-1)
+ # Sequence has any token under low threshold
+ seq_has_under_low = token_under_low.any(dim=-1)
+
+ metrics["outlier_seq_masked_ratio"] = ((~all_tokens_valid.squeeze(-1)).sum() / num_sequences).detach().item()
+ metrics["outlier_seq_over_high_ratio"] = (seq_has_over_high.sum() / num_sequences).detach().item()
+ metrics["outlier_seq_under_low_ratio"] = (seq_has_under_low.sum() / num_sequences).detach().item()
+
+ return all_tokens_valid.float(), metrics
+
+
+def compute_sequence_mask(
+ old_log_probs: torch.Tensor,
+ rollout_logprobs: torch.Tensor,
+ loss_mask: torch.Tensor,
+ sequence_mask_metric: str,
+ off_policy_correction: DictConfig,
+) -> Tuple[torch.Tensor, dict]:
+ """
+ Compute sequence mask for off policy correction.
+
+ This masks out sequences with importance ratios that fall outside acceptable bounds,
+ helping to filter out off-policy samples that may destabilize training.
+
+ Args:
+ old_log_probs: Log probabilities from the old policy (before update).
+ rollout_logprobs: Log probabilities from the rollout policy.
+ loss_mask: Mask indicating valid tokens.
+ sequence_mask_metric: Metric to use for sequence masking ("geometric" or "product").
+ off_policy_correction: Off-policy correction config containing cap values.
+
+ Returns:
+ Tuple of (sequence_mask, metrics):
+ - sequence_mask: Tensor (float) to multiply with the loss
+ - metrics: Dict with masking statistics
+ """
+ # Compute token-level importance ratio
+ token_tis_log_ratio = old_log_probs - rollout_logprobs
+ metrics = {}
+
+ if sequence_mask_metric == "geometric":
+ # Compute geometric mean of importance ratios per sequence
+ num_tokens = loss_mask.sum(dim=-1, keepdim=True).clamp(min=1.0)
+ seq_tis_log_ratio = (token_tis_log_ratio * loss_mask).sum(dim=-1, keepdim=True)
+ geo_mean_ratio = _safe_exp_delta(seq_tis_log_ratio / num_tokens, clip=20.0, out_dtype=old_log_probs.dtype)
+ geo_cap_high = off_policy_correction.geo_mask_high
+ geo_cap_low = off_policy_correction.geo_mask_low
+ seq_over_high = geo_mean_ratio > geo_cap_high
+ seq_under_low = geo_mean_ratio < geo_cap_low
+ geo_sequence_mask = ~seq_over_high & ~seq_under_low
+
+ num_sequences = float(geo_mean_ratio.shape[0])
+ metrics["geo_sequence_mask_masked_ratio"] = ((~geo_sequence_mask).sum() / num_sequences).detach().item()
+ metrics["geo_sequence_mask_over_high_ratio"] = (seq_over_high.sum() / num_sequences).detach().item()
+ metrics["geo_sequence_mask_under_low_ratio"] = (seq_under_low.sum() / num_sequences).detach().item()
+
+ return geo_sequence_mask.float(), metrics
+ elif sequence_mask_metric == "product":
+ # Mask out sequences with product of importance ratios outside the cap
+ seq_tis_log_ratio = (token_tis_log_ratio * loss_mask).sum(dim=-1, keepdim=True)
+ seq_tis_ratio = _safe_exp_delta(seq_tis_log_ratio, clip=20.0, out_dtype=old_log_probs.dtype)
+ seq_cap_high = off_policy_correction.product_mask_high
+ seq_cap_low = off_policy_correction.product_mask_low
+ seq_over_high = seq_tis_ratio > seq_cap_high
+ seq_under_low = seq_tis_ratio < seq_cap_low
+ seq_in_bounds = ~seq_over_high & ~seq_under_low
+
+ num_sequences = float(seq_tis_ratio.shape[0])
+ metrics["product_sequence_mask_masked_ratio"] = ((~seq_in_bounds).sum() / num_sequences).detach().item()
+ metrics["product_sequence_mask_over_high_ratio"] = (seq_over_high.sum() / num_sequences).detach().item()
+ metrics["product_sequence_mask_under_low_ratio"] = (seq_under_low.sum() / num_sequences).detach().item()
+
+ return seq_in_bounds.float(), metrics
+ else:
+ raise ValueError(f"Unknown sequence_mask_metric: {sequence_mask_metric}")
+
+
+def compute_off_policy_correction(
+ old_log_probs: torch.Tensor,
+ rollout_logprobs: torch.Tensor,
+ loss_mask: torch.Tensor,
+ off_policy_correction: DictConfig,
+) -> Tuple[Optional[torch.Tensor], dict, torch.Tensor]:
+ """
+ Compute TIS ratio, sequence mask, and outlier token mask for off policy correction.
+
+ This is a convenience function that combines compute_tis_ratio, compute_sequence_mask,
+ and compute_outlier_token_mask.
+
+ Args:
+ old_log_probs: Log probabilities from the old policy (before update).
+ rollout_logprobs: Log probabilities from the rollout policy.
+ loss_mask: Mask indicating valid tokens.
+ off_policy_correction: Off-policy correction config.
+
+ Returns:
+ Tuple of (tis_ratio, metrics, loss_mask):
+ - tis_ratio: Tensor (float) to multiply with the loss
+ - metrics: Dict with masking statistics
+ - loss_mask: Mask indicating valid tokens after applying off policy correction
+
+ References:
+ - https://github.com/szrlee/verl/blob/yingru/rollout_correction/docs/advance/rollout_corr_math.md
+ - https://fengyao.notion.site/off-policy-rl
+ """
+ tis_ratio_type = off_policy_correction.tis_ratio_type
+ sequence_mask_metric = off_policy_correction.sequence_mask_metric
+
+ # Check if TIS ratio correction is enabled
+ apply_tis = tis_ratio_type is not None
+ # Check if sequence mask is enabled
+ apply_sequence_mask = sequence_mask_metric is not None
+
+ # Early return if no correction needed
+ if not apply_tis and not apply_sequence_mask:
+ return None, {}, loss_mask
+
+ is_ratio = _safe_exp_delta(old_log_probs - rollout_logprobs, clip=20.0, out_dtype=old_log_probs.dtype)
+ metrics = {}
+ metrics["is_ratio_mean"] = masked_mean(is_ratio, loss_mask).mean().detach().item()
+ metrics["is_ratio_std"] = (is_ratio * loss_mask).std().detach().item()
+ metrics["is_ratio_max"] = (is_ratio * loss_mask).max().detach().item()
+ metrics["is_ratio_min"] = (is_ratio * loss_mask).min().detach().item()
+
+ # Apply outlier token mask whenever off policy correction is enabled
+ # This rejects sequences with any token having importance ratio outside acceptable bounds
+ outlier_mask, outlier_metrics = compute_outlier_token_mask(
+ old_log_probs, rollout_logprobs, loss_mask, off_policy_correction
+ )
+ loss_mask = loss_mask * outlier_mask
+ metrics.update(outlier_metrics)
+
+ # Initialize tis_ratio to None (only set if TIS is enabled)
+ tis_ratio = None
+
+ # Apply TIS ratio if enabled
+ if apply_tis:
+ tis_ratio, tis_metrics = compute_tis_ratio(
+ old_log_probs, rollout_logprobs, loss_mask, tis_ratio_type, off_policy_correction
+ )
+ metrics.update(tis_metrics)
+
+ # Apply sequence mask if enabled
+ if apply_sequence_mask:
+ sequence_mask, sequence_mask_metrics = compute_sequence_mask(
+ old_log_probs, rollout_logprobs, loss_mask, sequence_mask_metric, off_policy_correction
+ )
+ loss_mask = loss_mask * sequence_mask
+ metrics.update(sequence_mask_metrics)
+
+ return tis_ratio, metrics, loss_mask
+
+
@register_policy_loss(PolicyLossType.REGULAR)
@register_policy_loss(PolicyLossType.DUAL_CLIP)
def ppo_policy_loss(
@@ -581,17 +831,20 @@ def ppo_policy_loss(
clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
loss = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
- if config.use_tis:
- from loguru import logger as logger_
+ loss_metrics = LossMetrics(clip_ratio=clip_ratio)
- logger_.debug(f"Using TIS with dtype: {rollout_logprobs.dtype}")
- # Apply truncated importance sampling -> https://fengyao.notion.site/off-policy-rl
- tis_imp_ratio = _safe_exp_delta(old_log_probs - rollout_logprobs, clip=20.0, out_dtype=log_probs.dtype)
- tis_imp_ratio = torch.clamp(tis_imp_ratio, max=config.tis_imp_ratio_cap)
- loss = loss * tis_imp_ratio
+ # apply off policy correction
+ off_policy_correction = config.off_policy_correction
+ if rollout_logprobs is not None:
+ tis_ratio, off_policy_correction_metrics, loss_mask = compute_off_policy_correction(
+ old_log_probs, rollout_logprobs, loss_mask, off_policy_correction
+ )
+ if tis_ratio is not None:
+ loss = loss * tis_ratio
+ loss_metrics.update(off_policy_correction_metrics)
loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len)
- return loss, clip_ratio
+ return loss, loss_metrics
@register_policy_loss(PolicyLossType.SAPO)
@@ -653,13 +906,20 @@ def gate_function(x, tau):
# compute policy gradient loss
loss = -gates * advantages
+ # apply off policy correction
+ off_policy_correction = config.off_policy_correction
+ loss_metrics = LossMetrics(clip_ratio=0.0)
+ if rollout_logprobs is not None:
+ tis_ratio, off_policy_correction_metrics, loss_mask = compute_off_policy_correction(
+ old_log_probs, rollout_logprobs, loss_mask, off_policy_correction
+ )
+ if tis_ratio is not None:
+ loss = loss * tis_ratio
+ loss_metrics.update(off_policy_correction_metrics)
# for SAPO, we need to aggregate the loss at the sequence level (seq-mean-token-mean)
loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len)
- # SAPO does not use clipping, so we set clip_ratio to 0.0 for compatibility
- clip_ratio = 0.0
-
- return loss, clip_ratio
+ return loss, loss_metrics
@register_policy_loss(PolicyLossType.GSPO)
@@ -717,9 +977,20 @@ def gspo_policy_loss(
# Compute clipping ratio for monitoring
clip_ratio = masked_mean((-surr2 > -surr1).float(), loss_mask).mean().detach().item()
+ # apply off policy correction
+ loss_metrics = LossMetrics(clip_ratio=clip_ratio)
+ off_policy_correction = config.off_policy_correction
+ if rollout_logprobs is not None:
+ tis_ratio, off_policy_correction_metrics, loss_mask = compute_off_policy_correction(
+ old_log_probs, rollout_logprobs, loss_mask, off_policy_correction
+ )
+ if tis_ratio is not None:
+ loss = loss * tis_ratio
+ loss_metrics.update(off_policy_correction_metrics)
+
loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len)
- return loss, clip_ratio
+ return loss, loss_metrics
@register_policy_loss(PolicyLossType.CISPO)
@@ -747,8 +1018,19 @@ def compute_policy_loss_cispo(
is_clipped = (ratio < 1 - config.cispo.cispo_eps_clip_low) | (ratio > 1 + config.cispo.cispo_eps_clip_high)
clip_ratio = masked_mean(is_clipped.float(), loss_mask).mean().detach().item()
+ # apply off policy correction
+ off_policy_correction = config.off_policy_correction
+ loss_metrics = LossMetrics(clip_ratio=clip_ratio)
+ if rollout_logprobs is not None:
+ tis_ratio, off_policy_correction_metrics, loss_mask = compute_off_policy_correction(
+ old_log_probs, rollout_logprobs, loss_mask, off_policy_correction
+ )
+ if tis_ratio is not None:
+ loss = loss * tis_ratio
+ loss_metrics.update(off_policy_correction_metrics)
+
loss = reduce_loss(loss, loss_mask, config.loss_reduction, config.max_seq_len)
- return loss, clip_ratio
+ return loss, loss_metrics
@register_policy_loss(PolicyLossType.CLIP_COV)
@@ -808,6 +1090,7 @@ def compute_policy_loss_clip_cov(
# Apply correction mask to losses
pg_losses = torch.maximum(pg_losses1, pg_losses2) * corr
+
pg_loss = reduce_loss(
loss=pg_losses,
loss_mask=loss_mask,
@@ -815,7 +1098,7 @@ def compute_policy_loss_clip_cov(
max_seq_len=config.max_seq_len,
)
- return pg_loss, clip_frac.item()
+ return pg_loss, LossMetrics(clip_ratio=clip_frac.item())
@register_policy_loss(PolicyLossType.KL_COV)
@@ -875,7 +1158,7 @@ def compute_policy_loss_kl_cov(
)
# NOTE (sumanthrh): Since the pg clip ratio is not applicable for KL-COV so we just use 0.0
- return pg_loss, 0.0
+ return pg_loss, LossMetrics(clip_ratio=0.0)
def reduce_loss(
diff --git a/skyrl-train/skyrl_train/utils/utils.py b/skyrl-train/skyrl_train/utils/utils.py
index 349d5778a..8c3039c54 100644
--- a/skyrl-train/skyrl_train/utils/utils.py
+++ b/skyrl-train/skyrl_train/utils/utils.py
@@ -281,26 +281,53 @@ def validate_cfg(cfg: DictConfig):
algorithm_config.kl_estimator_type = "k3"
cfg.trainer.algorithm = algorithm_config
+ # TODO (erictang000): remove this after deprecation period
if cfg.trainer.algorithm.use_tis:
- if cfg.trainer.algorithm.tis_imp_ratio_cap <= 0:
- raise ValueError(
- f"If `trainer.algorithm.use_tis` is `True` then `cfg.trainer.algorithm.tis_imp_ratio_cap` "
- f"should be > 0, got {cfg.trainer.algorithm.tis_imp_ratio_cap }"
- )
+ logger.warning(
+ f"`trainer.algorithm.use_tis` is deprecated. Setting `trainer.algorithm.off_policy_correction` to `token` instead."
+ f"with `token_tis_ratio_clip_high`={cfg.trainer.algorithm.tis_imp_ratio_cap}"
+ )
+ cfg.trainer.algorithm.off_policy_correction.tis_ratio_type = "token"
+ cfg.trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high = cfg.trainer.algorithm.tis_imp_ratio_cap
+
+ # off_policy_correction config validation
+ off_policy_correction = cfg.trainer.algorithm.off_policy_correction
+ tis_ratio_type = off_policy_correction.tis_ratio_type
+ sequence_mask_metric = off_policy_correction.sequence_mask_metric
+
+ uses_off_policy_correction = tis_ratio_type is not None or sequence_mask_metric is not None
+
+ if uses_off_policy_correction:
+ # Validate tis_ratio_type
+ if tis_ratio_type:
+ assert tis_ratio_type in [
+ "token",
+ "sequence",
+ ], f"`tis_ratio_type` must be 'None', 'token', or 'sequence', got {tis_ratio_type}"
+
+ # Validate sequence_mask_metric
+ if sequence_mask_metric:
+ assert sequence_mask_metric in [
+ "product",
+ "geometric",
+ ], f"`sequence_mask_metric` must be 'product', or 'geometric', got {sequence_mask_metric}"
+
+ # Ensure logprobs are enabled for rollout correction
if cfg.generator.sampling_params.logprobs is None:
logger.warning(
- "`generator.sampling_params.logprobs` is `None` but `trainer.algorithm.use_tis` is `True`."
+ "`generator.sampling_params.logprobs` is `None` but off_policy_correction is enabled."
" Setting `logprobs` to `True`."
)
- # just set to 0 for better user exp
cfg.generator.sampling_params.logprobs = 0
if cfg.generator.backend == "sglang":
- raise NotImplementedError("`trainer.algorithm.use_tis` doesn't support Sglang backend, please use vLLM")
- assert cfg.trainer.algorithm.policy_loss_type in [
- "regular",
- "dual_clip",
- ], "TIS is only implemented for regular and dual_clip policy loss types"
+ raise NotImplementedError(
+ "`trainer.algorithm.off_policy_correction` doesn't support Sglang backend, please use vLLM"
+ )
+ if cfg.trainer.algorithm.policy_loss_type in ["clip_cov", "kl_cov"]:
+ raise NotImplementedError(
+ "`trainer.algorithm.off_policy_correction` doesn't support clip_cov or kl_cov policy loss types"
+ )
if cfg.trainer.policy.model.lora.rank > 0:
# LoRA enabled
diff --git a/skyrl-train/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl-train/skyrl_train/workers/megatron/megatron_model_wrapper.py
index d84673e66..b3d14c65f 100644
--- a/skyrl-train/skyrl_train/workers/megatron/megatron_model_wrapper.py
+++ b/skyrl-train/skyrl_train/workers/megatron/megatron_model_wrapper.py
@@ -219,7 +219,7 @@ def loss_func(logits, data):
action_log_probs = token_logprobs[:, -num_actions:]
# policy loss should be calculated based on the selected token logprobs
- policy_loss, clip_ratio = self.policy_loss_fn(
+ policy_loss, loss_metrics = self.policy_loss_fn(
action_log_probs,
old_action_log_probs,
advantages,
@@ -256,9 +256,10 @@ def loss_func(logits, data):
"final_loss": loss.detach().item(),
"policy_loss": policy_loss.detach().item(),
"policy_entropy": entropy.detach().item(),
- "ppo_clip_ratio": clip_ratio,
"policy_kl": kl_loss.detach().item(),
}
+ for k, v in loss_metrics.items():
+ metrics["loss_metrics/" + k] = v
return loss, metrics
def forward_step(batch_iter, model):
diff --git a/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py b/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py
index 080b9bb42..e308b4526 100644
--- a/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py
+++ b/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py
@@ -30,7 +30,7 @@
from skyrl_train.utils.utils import update_model_config, str_to_torch_dtype
from skyrl_train.utils.constants import SKYRL_WORKER_NCCL_TIMEOUT_IN_S
from skyrl_train.training_batch import TrainingOutputBatch
-from skyrl_train.workers.worker_utils import BatchIterator, reduce_metrics
+from skyrl_train.workers.worker_utils import BatchIterator, reduce_metrics, all_reduce_metrics
from skyrl_train.workers.worker import (
PolicyWorkerBase,
RefWorkerBase,
@@ -583,13 +583,11 @@ def ppo_train(self, train_data) -> "TrainingOutputBatch":
# within a DP group, metrics are already the same across all workers - we then just all reduce across
# the whole world size to get the metrics for the global micro batch
for i, metrics in enumerate(metrics_list):
- status = {
- "final_loss": metrics["final_loss"],
- "policy_loss": metrics["policy_loss"],
- "policy_lr": self.optimizer.param_groups[0]["lr"],
- "ppo_clip_ratio": metrics["ppo_clip_ratio"],
- "policy_entropy": metrics["policy_entropy"],
- }
+ status = {}
+ for k, v in metrics.items():
+ status[k] = v
+
+ status["policy_lr"] = self.optimizer.param_groups[0]["lr"]
if self.cfg.trainer.algorithm.use_kl_loss:
status["policy_kl"] = metrics["policy_kl"]
@@ -600,7 +598,8 @@ def ppo_train(self, train_data) -> "TrainingOutputBatch":
# attach response_length
status["response_length"] = micro_buffer[i]["num_actions"]
- status = self.strategy.all_reduce(status)
+ status = all_reduce_metrics(status, self.strategy)
+
status_list.append(status)
for k, v in status.items():
all_metrics[k].append(v)
diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py
index 361a0777e..f062f9b20 100644
--- a/skyrl-train/skyrl_train/workers/worker.py
+++ b/skyrl-train/skyrl_train/workers/worker.py
@@ -32,7 +32,7 @@
from loguru import logger
from skyrl_train.distributed.ulysses import set_ulysses_sequence_parallel_group, apply_monkey_patch
from skyrl_train.utils.ppo_utils import PolicyLossRegistry, ppo_critic_loss, compute_approx_kl
-from skyrl_train.workers.worker_utils import BatchIterator, reduce_metrics
+from skyrl_train.workers.worker_utils import BatchIterator, reduce_metrics, all_reduce_metrics
from skyrl_train.dataset.replay_buffer import Experience
from skyrl_train.training_batch import TrainingInputBatch, TrainingOutputBatch
from skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient
@@ -663,7 +663,7 @@ def forward_backward(self, experience: Experience, microbatch_weight: float) ->
)
# loss function
# TODO: recompute advantages
- policy_loss, clip_ratio = self.policy_loss_fn(
+ policy_loss, loss_metrics = self.policy_loss_fn(
action_log_probs,
old_action_log_probs,
advantages,
@@ -704,10 +704,11 @@ def forward_backward(self, experience: Experience, microbatch_weight: float) ->
status = {
"final_loss": loss.item(),
"policy_loss": policy_loss.item(),
- "ppo_clip_ratio": clip_ratio,
"policy_entropy": entropy.item(),
"response_length": num_actions,
}
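+ # Namespace the per-loss metrics under "loss_metrics/" so they are reported alongside the core stats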
+ for k, v in loss_metrics.items():
+ status["loss_metrics/" + k] = v
if self.cfg.trainer.algorithm.use_kl_loss:
status["policy_kl"] = kl_loss.item()
@@ -741,7 +742,7 @@ def record_status(status: Dict[str, float]):
# for DP
# TODO (sumanthrh): this assumes all workers are data parallel.
# We assume that outputs are replicated within tp or sp group, otherwise this is not correct.
- status = self.strategy.all_reduce(status)
+ status = all_reduce_metrics(status, self.strategy)
# weighted mean for kl
# TODO (sumanthrh): this weighted mean is no longer correct since we use the max response length in the batch.
@@ -817,7 +818,6 @@ def record_status(status: Dict[str, float]):
torch.distributed.barrier()
# not needed beyond status logging
all_metrics.pop("response_length", None)
-
status_mean = reduce_metrics(all_metrics)
status_mean["policy_update_steps"] = num_minibatches * self.cfg.trainer.update_epochs_per_batch
diff --git a/skyrl-train/skyrl_train/workers/worker_utils.py b/skyrl-train/skyrl_train/workers/worker_utils.py
index 897d032ea..ebc25f746 100644
--- a/skyrl-train/skyrl_train/workers/worker_utils.py
+++ b/skyrl-train/skyrl_train/workers/worker_utils.py
@@ -2,6 +2,7 @@
from skyrl_train.dataset.replay_buffer import Experience
from typing import List, Dict
from skyrl_train.training_batch import TrainingInputBatch
+from skyrl_train.distributed.strategy import DistributedStrategy
def reduce_metrics(metrics: Dict[str, List[float]]) -> Dict[str, float]:
@@ -12,10 +13,28 @@ def reduce_metrics(metrics: Dict[str, List[float]]) -> Dict[str, float]:
for k, v in metrics.items():
assert len(v) > 0, f"No metrics for key {k}"
assert all(isinstance(x, (int, float)) for x in v), f"Metrics for key {k} are not all numbers"
- reduced_metrics[k] = sum(v) / len(v)
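+ # Keys ending in "_max"/"_min" are reduced with max/min, everything else with the mean,
+ # e.g. "foo_max": [1.0, 3.0] -> 3.0 while "foo": [1.0, 3.0] -> 2.0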
+ if k.endswith("_max"):
+ reduced_metrics[k] = max(v)
+ elif k.endswith("_min"):
+ reduced_metrics[k] = min(v)
+ else:
+ reduced_metrics[k] = sum(v) / len(v)
return reduced_metrics
+def all_reduce_metrics(metrics: Dict[str, float], strategy: DistributedStrategy) -> Dict[str, float]:
+ """All reduce metrics across all processes."""
+ min_metrics = {k: v for k, v in metrics.items() if k.endswith("_min")}
+ max_metrics = {k: v for k, v in metrics.items() if k.endswith("_max")}
+ mean_metrics = {k: v for k, v in metrics.items() if k not in min_metrics and k not in max_metrics}
+ status_mean = strategy.all_reduce(mean_metrics, op="mean")
+ status_min = strategy.all_reduce(min_metrics, op="min")
+ status_max = strategy.all_reduce(max_metrics, op="max")
+ status_mean.update(status_min)
+ status_mean.update(status_max)
+ return status_mean
+
+
class BatchIterator:
"""A simple iterator to yield micro batches of data from the training batch."""
diff --git a/skyrl-train/tests/cpu/algorithms/test_losses.py b/skyrl-train/tests/cpu/algorithms/test_losses.py
index f5904b595..f3aa2a0e8 100644
--- a/skyrl-train/tests/cpu/algorithms/test_losses.py
+++ b/skyrl-train/tests/cpu/algorithms/test_losses.py
@@ -8,7 +8,21 @@
import torch
from omegaconf import DictConfig
-from skyrl_train.utils.ppo_utils import PolicyLossRegistry, masked_mean
+from skyrl_train.utils.ppo_utils import (
+ PolicyLossRegistry,
+ masked_mean,
+ compute_tis_ratio,
+ compute_sequence_mask,
+ compute_outlier_token_mask,
+ compute_off_policy_correction,
+)
+
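+# Off-policy correction config with every correction mode disabled, shared by tests that don't exercise it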
+NULL_OFF_POLICY_CORR = {
+ "tis_ratio_type": None,
+ "sequence_mask_metric": None,
+ "outlier_token_is_threshold_low": 1e-4,
+ "outlier_token_is_threshold_high": 100.0,
+}
# Adapted a good test from NeMO-RL
@@ -33,7 +47,7 @@ def test_policy_loss_dual_clip():
"policy_loss_type": "dual_clip",
"loss_reduction": "token_mean",
"max_seq_len": 4,
- "use_tis": False,
+ "off_policy_correction": NULL_OFF_POLICY_CORR,
}
)
@@ -86,7 +100,7 @@ def test_policy_loss_cispo():
"policy_loss_type": "cispo",
"loss_reduction": "token_mean",
"max_seq_len": 4,
- "use_tis": False,
+ "off_policy_correction": NULL_OFF_POLICY_CORR,
}
)
@@ -164,7 +178,7 @@ def test_policy_loss_reduction_modes():
"policy_loss_type": "regular",
"loss_reduction": "token_mean",
"max_seq_len": 4,
- "use_tis": False,
+ "off_policy_correction": NULL_OFF_POLICY_CORR,
}
)
@@ -176,7 +190,7 @@ def test_policy_loss_reduction_modes():
"policy_loss_type": "regular",
"loss_reduction": "sequence_mean",
"max_seq_len": 4,
- "use_tis": False,
+ "off_policy_correction": NULL_OFF_POLICY_CORR,
}
)
@@ -249,7 +263,7 @@ def test_policy_loss_reduction_edge_cases():
"policy_loss_type": "regular",
"loss_reduction": "token_mean",
"max_seq_len": 4,
- "use_tis": False,
+ "off_policy_correction": NULL_OFF_POLICY_CORR,
}
)
@@ -261,7 +275,7 @@ def test_policy_loss_reduction_edge_cases():
"policy_loss_type": "regular",
"loss_reduction": "sequence_mean",
"max_seq_len": 4,
- "use_tis": False,
+ "off_policy_correction": NULL_OFF_POLICY_CORR,
}
)
@@ -347,7 +361,7 @@ def test_gspo_importance_sampling_levels():
"policy_loss_type": "regular",
"loss_reduction": "token_mean",
"max_seq_len": 4,
- "use_tis": False,
+ "off_policy_correction": NULL_OFF_POLICY_CORR,
}
)
ppo_loss_fn = PolicyLossRegistry.get("regular")
@@ -362,7 +376,7 @@ def test_gspo_importance_sampling_levels():
"policy_loss_type": "gspo",
"loss_reduction": "sequence_mean", # GSPO recommended reduction
"max_seq_len": 4,
- "use_tis": False,
+ "off_policy_correction": NULL_OFF_POLICY_CORR,
}
)
gspo_loss_fn = PolicyLossRegistry.get("gspo")
@@ -469,6 +483,7 @@ def test_clip_cov_policy_loss():
"loss_reduction": "token_mean",
"max_seq_len": 4,
"clip_cov": {"clip_ratio": 0.5, "clip_cov_lb": -5.0, "clip_cov_ub": 5.0}, # Large ratio for testing
+ "off_policy_correction": NULL_OFF_POLICY_CORR,
}
)
@@ -476,11 +491,12 @@ def test_clip_cov_policy_loss():
clip_cov_fn = PolicyLossRegistry.get("clip_cov")
# Calculate loss
- loss, clip_frac = clip_cov_fn(log_probs, old_log_probs, advantages, config, loss_mask)
+ loss, loss_metrics = clip_cov_fn(log_probs, old_log_probs, advantages, config, loss_mask)
+ clip_ratio = loss_metrics["clip_ratio"]
# Basic sanity checks
assert torch.isfinite(loss), "Loss should be finite"
- assert 0 <= clip_frac <= 1, f"Clip fraction should be between 0 and 1, got {clip_frac}"
+ assert 0 <= clip_ratio <= 1, f"Clip ratio should be between 0 and 1, got {clip_ratio}"
# Compare with regular PPO (should be different due to covariance correction)
regular_config = DictConfig(
@@ -490,12 +506,12 @@ def test_clip_cov_policy_loss():
"policy_loss_type": "regular",
"loss_reduction": "token_mean",
"max_seq_len": 4,
- "use_tis": False,
+ "off_policy_correction": NULL_OFF_POLICY_CORR,
}
)
regular_fn = PolicyLossRegistry.get("regular")
- regular_loss, regular_clip_frac = regular_fn(log_probs, old_log_probs, advantages, regular_config, loss_mask)
+ regular_loss, regular_loss_metrics = regular_fn(log_probs, old_log_probs, advantages, regular_config, loss_mask)
# Clip-Cov should give different results due to covariance-based correction
assert not torch.allclose(
@@ -531,6 +547,7 @@ def test_kl_cov_policy_loss():
"loss_reduction": "token_mean",
"max_seq_len": 4,
"kl_cov": {"kl_cov_frac": 0.5, "ppo_kl_coef": 1.0}, # Apply KL to 50% of tokens
+ "off_policy_correction": NULL_OFF_POLICY_CORR,
}
)
@@ -538,11 +555,11 @@ def test_kl_cov_policy_loss():
kl_cov_fn = PolicyLossRegistry.get("kl_cov")
# Calculate loss
- loss, clip_frac = kl_cov_fn(log_probs, old_log_probs, advantages, config, loss_mask)
+ loss, loss_metrics = kl_cov_fn(log_probs, old_log_probs, advantages, config, loss_mask)
# Basic sanity checks
assert torch.isfinite(loss), "Loss should be finite"
- assert clip_frac == 0.0, "KL-Cov should return 0.0 for clipfrac value"
+ assert loss_metrics["clip_ratio"] == 0.0, "KL-Cov should return 0.0 for clip_ratio value"
# Compare with regular PPO (should be different due to KL regularization)
regular_config = DictConfig(
@@ -552,7 +569,7 @@ def test_kl_cov_policy_loss():
"policy_loss_type": "regular",
"loss_reduction": "token_mean",
"max_seq_len": 4,
- "use_tis": False,
+ "off_policy_correction": NULL_OFF_POLICY_CORR,
}
)
@@ -585,13 +602,14 @@ def test_sapo_policy_loss_basic():
"loss_reduction": "sequence_mean",
"max_seq_len": 4,
"sapo": {"tau_pos": 1.0, "tau_neg": 2.0},
+ "off_policy_correction": NULL_OFF_POLICY_CORR,
}
)
loss_fn = PolicyLossRegistry.get("sapo")
# Actual SAPO loss
- actual_loss, actual_clip_ratio = loss_fn(
+ actual_loss, loss_metrics = loss_fn(
log_probs=log_probs,
old_log_probs=old_log_probs,
advantages=advantages,
@@ -620,4 +638,575 @@ def gate_function(x, tau):
torch.testing.assert_close(actual_loss, expected_loss, rtol=1e-5, atol=1e-8)
# SAPO should always report clip_ratio = 0.0
- assert actual_clip_ratio == 0.0
+ assert loss_metrics["clip_ratio"] == 0.0
+
+
+# ============================================================================
+# Rollout Correction Tests
+# ============================================================================
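+# The helpers below all operate on (old_log_probs, rollout_logprobs, loss_mask) plus a config:
+#   - compute_tis_ratio returns truncated importance-sampling ratios (per token, or [batch, 1] per sequence)
+#   - compute_sequence_mask and compute_outlier_token_mask return [batch, 1] keep/drop masks
+#   - compute_off_policy_correction combines them into (tis_ratio, metrics, new_loss_mask)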
+
+
+def test_compute_tis_ratio_token_level():
+ """Tests token-level TIS ratio computation with capping."""
+ device = "cpu"
+
+ # old_log_probs - rollout_logprobs gives the log importance ratio
+ # Token ratios: exp([0.5, -0.5, 1.0]) = [1.6487, 0.6065, 2.7183]
+ old_log_probs = torch.tensor([[-1.0, -1.5, -0.5]], device=device)
+ rollout_logprobs = torch.tensor([[-1.5, -1.0, -1.5]], device=device)
+ loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device)
+
+ config = DictConfig(
+ {
+ "tis_ratio_type": "token",
+ "token_tis_ratio_clip_high": 2.0,
+ }
+ )
+
+ tis_ratio, metrics = compute_tis_ratio(old_log_probs, rollout_logprobs, loss_mask, "token", config)
+
+ # Expected: [1.6487, 0.6065, 2.0] (third token capped at 2.0)
+ expected = torch.tensor([[1.6487, 0.6065, 2.0]], device=device)
+ torch.testing.assert_close(tis_ratio, expected, rtol=1e-3, atol=1e-4)
+ # One token out of 3 was capped
+ assert "tis_token_clip_high_ratio" in metrics
+ assert abs(metrics["tis_token_clip_high_ratio"] - 1 / 3) < 0.01
+
+
+def test_compute_tis_ratio_sequence_level():
+ """Tests sequence-level TIS ratio computation with capping."""
+ device = "cpu"
+
+ # Token log ratios: [0.5, -0.5, 1.0]
+ # Sequence log ratio (sum of masked): 0.5 + (-0.5) + 1.0 = 1.0
+ # Sequence ratio: exp(1.0) = 2.7183
+ old_log_probs = torch.tensor([[-1.0, -1.5, -0.5]], device=device)
+ rollout_logprobs = torch.tensor([[-1.5, -1.0, -1.5]], device=device)
+ loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device)
+
+ config = DictConfig(
+ {
+ "tis_ratio_type": "sequence",
+ "sequence_tis_ratio_clip_high": 5.0,
+ }
+ )
+
+ tis_ratio, metrics = compute_tis_ratio(old_log_probs, rollout_logprobs, loss_mask, "sequence", config)
+
+ # Expected: exp(1.0) = 2.7183, shape [batch, 1] for sequence-level
+ expected = torch.tensor([[2.7183]], device=device)
+ torch.testing.assert_close(tis_ratio, expected, rtol=1e-3, atol=1e-4)
+ # No sequence was capped (2.7183 < 5.0)
+ assert "tis_seq_clip_high_ratio" in metrics
+ assert metrics["tis_seq_clip_high_ratio"] == 0.0
+
+
+def test_compute_tis_ratio_sequence_level_with_cap():
+ """Tests sequence-level TIS ratio capping."""
+ device = "cpu"
+
+ # Token log ratios: [1.0, 1.0, 1.0]
+ # Sequence log ratio: 3.0
+ # Sequence ratio: exp(3.0) = 20.09, should be capped at 5.0
+ old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device)
+ rollout_logprobs = torch.tensor([[-2.0, -2.0, -2.0]], device=device)
+ loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device)
+
+ config = DictConfig(
+ {
+ "tis_ratio_type": "sequence",
+ "sequence_tis_ratio_clip_high": 5.0,
+ }
+ )
+
+ tis_ratio, metrics = compute_tis_ratio(old_log_probs, rollout_logprobs, loss_mask, "sequence", config)
+
+ # Expected: capped at 5.0, shape [batch, 1] for sequence-level
+ expected = torch.tensor([[5.0]], device=device)
+ torch.testing.assert_close(tis_ratio, expected, rtol=1e-3, atol=1e-4)
+ # One sequence out of 1 was capped
+ assert "tis_seq_clip_high_ratio" in metrics
+ assert metrics["tis_seq_clip_high_ratio"] == 1.0
+
+
+def test_compute_tis_ratio_with_mask():
+ """Tests that loss_mask correctly excludes tokens from sequence-level computation."""
+ device = "cpu"
+
+ # Token log ratios: [0.5, -0.5, 1.0]
+ # With mask [1, 0, 1], sequence log ratio = 0.5 + 1.0 = 1.5
+ old_log_probs = torch.tensor([[-1.0, -1.5, -0.5]], device=device)
+ rollout_logprobs = torch.tensor([[-1.5, -1.0, -1.5]], device=device)
+ loss_mask = torch.tensor([[1.0, 0.0, 1.0]], device=device)
+
+ config = DictConfig(
+ {
+ "tis_ratio_type": "sequence",
+ "sequence_tis_ratio_clip_high": 10.0,
+ }
+ )
+
+ tis_ratio, metrics = compute_tis_ratio(old_log_probs, rollout_logprobs, loss_mask, "sequence", config)
+
+ # Expected: exp(1.5) = 4.4817, shape [batch, 1] for sequence-level
+ expected_val = torch.exp(torch.tensor(1.5))
+ expected = expected_val.reshape(1, 1)
+ torch.testing.assert_close(tis_ratio, expected, rtol=1e-3, atol=1e-4)
+ # No sequence was capped (4.4817 < 10.0)
+ assert "tis_seq_clip_high_ratio" in metrics
+ assert metrics["tis_seq_clip_high_ratio"] == 0.0
+
+
+def test_compute_sequence_mask_geometric():
+ """Tests geometric sequence mask computation."""
+ device = "cpu"
+
+ # Token log ratios: [0.1, -0.1, 0.0] -> sum = 0.0, geometric mean = exp(0/3) = 1.0
+ old_log_probs = torch.tensor([[-1.0, -1.1, -1.0]], device=device)
+ rollout_logprobs = torch.tensor([[-1.1, -1.0, -1.0]], device=device)
+ loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device)
+
+ config = DictConfig(
+ {
+ "sequence_mask_metric": "geometric",
+ "geo_mask_high": 1.1,
+ "geo_mask_low": 0.9,
+ }
+ )
+
+ sequence_mask, metrics = compute_sequence_mask(old_log_probs, rollout_logprobs, loss_mask, "geometric", config)
+
+ # Geometric mean ≈ 1.0, which is within [0.9, 1.1], so mask should be 1.0
+ # Shape is [batch, 1] for sequence-level mask
+ expected = torch.tensor([[1.0]], device=device)
+ torch.testing.assert_close(sequence_mask, expected, rtol=1e-3, atol=1e-4)
+ # No sequence was masked
+ assert metrics["geo_sequence_mask_masked_ratio"] == 0.0
+ assert metrics["geo_sequence_mask_over_high_ratio"] == 0.0
+ assert metrics["geo_sequence_mask_under_low_ratio"] == 0.0
+
+
+def test_compute_sequence_mask_geometric_rejects():
+ """Tests geometric sequence mask correctly rejects sequences outside bounds."""
+ device = "cpu"
+
+ # Token log ratios: [0.5, 0.5, 0.5] -> sum = 1.5, geometric mean = exp(1.5/3) = exp(0.5) ≈ 1.6487
+ old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device)
+ rollout_logprobs = torch.tensor([[-1.5, -1.5, -1.5]], device=device)
+ loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device)
+
+ config = DictConfig(
+ {
+ "sequence_mask_metric": "geometric",
+ "geo_mask_high": 1.1,
+ "geo_mask_low": 0.9,
+ }
+ )
+
+ sequence_mask, metrics = compute_sequence_mask(old_log_probs, rollout_logprobs, loss_mask, "geometric", config)
+
+ # Geometric mean ≈ 1.6487, which is outside [0.9, 1.1], so mask should be 0.0
+ # Shape is [batch, 1] for sequence-level mask
+ expected = torch.tensor([[0.0]], device=device)
+ torch.testing.assert_close(sequence_mask, expected, rtol=1e-3, atol=1e-4)
+ # One sequence masked, over high cap
+ assert metrics["geo_sequence_mask_masked_ratio"] == 1.0
+ assert metrics["geo_sequence_mask_over_high_ratio"] == 1.0
+ assert metrics["geo_sequence_mask_under_low_ratio"] == 0.0
+
+
+def test_compute_sequence_mask_product():
+ """Tests product sequence mask computation."""
+ device = "cpu"
+
+ # Token log ratios: [0.2, 0.1, 0.0] -> sum = 0.3, seq ratio = exp(0.3) ≈ 1.35
+ old_log_probs = torch.tensor([[-1.0, -1.1, -1.2]], device=device)
+ rollout_logprobs = torch.tensor([[-1.2, -1.2, -1.2]], device=device)
+ loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device)
+
+ config = DictConfig(
+ {
+ "sequence_mask_metric": "product",
+ "product_mask_high": 2.0,
+ "product_mask_low": 0.5,
+ }
+ )
+
+ sequence_mask, metrics = compute_sequence_mask(old_log_probs, rollout_logprobs, loss_mask, "product", config)
+
+ # Sequence ratio ≈ 1.35, which is within [0.5, 2.0]
+ # Shape is [batch, 1] for sequence-level mask
+ expected = torch.tensor([[1.0]], device=device)
+ torch.testing.assert_close(sequence_mask, expected, rtol=1e-3, atol=1e-4)
+ # No sequence was masked
+ assert metrics["product_sequence_mask_masked_ratio"] == 0.0
+ assert metrics["product_sequence_mask_over_high_ratio"] == 0.0
+ assert metrics["product_sequence_mask_under_low_ratio"] == 0.0
+
+
+def test_compute_sequence_mask_product_rejects_by_seq_ratio():
+ """Tests product sequence mask rejects when product ratio is out of bounds."""
+ device = "cpu"
+
+ # Token log ratios: [1.0, 1.0, 1.0] -> sum = 3.0, seq ratio = exp(3.0) ≈ 20.09
+ old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device)
+ rollout_logprobs = torch.tensor([[-2.0, -2.0, -2.0]], device=device)
+ loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device)
+
+ config = DictConfig(
+ {
+ "sequence_mask_metric": "product",
+ "product_mask_high": 2.0,
+ "product_mask_low": 0.5,
+ }
+ )
+
+ sequence_mask, metrics = compute_sequence_mask(old_log_probs, rollout_logprobs, loss_mask, "product", config)
+
+ # Sequence ratio ≈ 20.09, which is outside [0.5, 2.0], so mask should be 0.0
+ # Shape is [batch, 1] for sequence-level mask
+ expected = torch.tensor([[0.0]], device=device)
+ torch.testing.assert_close(sequence_mask, expected, rtol=1e-3, atol=1e-4)
+ # One sequence masked, over high cap
+ assert metrics["product_sequence_mask_masked_ratio"] == 1.0
+ assert metrics["product_sequence_mask_over_high_ratio"] == 1.0
+ assert metrics["product_sequence_mask_under_low_ratio"] == 0.0
+
+
+def test_compute_outlier_token_mask_masks_by_token_bounds():
+ """Tests outlier token mask rejects when a token ratio is out of bounds."""
+ device = "cpu"
+
+ # Token log ratios: [0.0, 0.0, 5.0] -> token ratios = [1.0, 1.0, 148.4]
+ # Third token ratio 148.4 > 100.0, so should reject
+ old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device)
+ rollout_logprobs = torch.tensor([[-1.0, -1.0, -6.0]], device=device)
+ loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device)
+
+ config = DictConfig(
+ {
+ "outlier_token_is_threshold_low": 1e-4,
+ "outlier_token_is_threshold_high": 100.0, # This should cause masking
+ }
+ )
+
+ outlier_mask, metrics = compute_outlier_token_mask(old_log_probs, rollout_logprobs, loss_mask, config)
+
+ # Token ratio 148.4 > 100.0, so mask should be 0.0
+ # Shape is [batch, 1] for sequence-level mask
+ expected = torch.tensor([[0.0]], device=device)
+ torch.testing.assert_close(outlier_mask, expected, rtol=1e-3, atol=1e-4)
+ # One sequence masked, has token over high threshold
+ assert metrics["outlier_seq_masked_ratio"] == 1.0
+ assert metrics["outlier_seq_over_high_ratio"] == 1.0
+ assert metrics["outlier_seq_under_low_ratio"] == 0.0
+
+
+def test_compute_outlier_token_mask_accepts_in_bounds():
+ """Tests outlier token mask accepts when all token ratios are in bounds."""
+ device = "cpu"
+
+ # Token log ratios: [0.5, -0.5, 0.0] -> token ratios = [1.65, 0.61, 1.0]
+ # All token ratios within [1e-4, 100.0], so should accept
+ old_log_probs = torch.tensor([[-1.0, -1.5, -1.0]], device=device)
+ rollout_logprobs = torch.tensor([[-1.5, -1.0, -1.0]], device=device)
+ loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device)
+
+ config = DictConfig(
+ {
+ "outlier_token_is_threshold_low": 1e-4,
+ "outlier_token_is_threshold_high": 100.0,
+ }
+ )
+
+ outlier_mask, metrics = compute_outlier_token_mask(old_log_probs, rollout_logprobs, loss_mask, config)
+
+ # All token ratios in bounds, so mask should be 1.0
+ # Shape is [batch, 1] for sequence-level mask
+ expected = torch.tensor([[1.0]], device=device)
+ torch.testing.assert_close(outlier_mask, expected, rtol=1e-3, atol=1e-4)
+ # No sequence was masked
+ assert metrics["outlier_seq_masked_ratio"] == 0.0
+ assert metrics["outlier_seq_over_high_ratio"] == 0.0
+ assert metrics["outlier_seq_under_low_ratio"] == 0.0
+
+
+def test_compute_outlier_token_mask_respects_loss_mask():
+ """Tests outlier token mask ignores out-of-bounds tokens that are masked."""
+ device = "cpu"
+
+ # Token log ratios: [0.0, 0.0, 5.0] -> token ratios = [1.0, 1.0, 148.4]
+ # Third token ratio 148.4 > 100.0, but it's masked, so should accept
+ old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device)
+ rollout_logprobs = torch.tensor([[-1.0, -1.0, -6.0]], device=device)
+ loss_mask = torch.tensor([[1.0, 1.0, 0.0]], device=device) # Third token masked
+
+ config = DictConfig(
+ {
+ "outlier_token_is_threshold_low": 1e-4,
+ "outlier_token_is_threshold_high": 100.0,
+ }
+ )
+
+ outlier_mask, metrics = compute_outlier_token_mask(old_log_probs, rollout_logprobs, loss_mask, config)
+
+ # Third token is masked, so even though ratio is out of bounds, sequence should be accepted
+ expected = torch.tensor([[1.0]], device=device)
+ torch.testing.assert_close(outlier_mask, expected, rtol=1e-3, atol=1e-4)
+ # No sequence was masked (the out-of-bounds token was in a masked position)
+ assert metrics["outlier_seq_masked_ratio"] == 0.0
+
+
+def test_compute_off_policy_correction_null_configs():
+ """Tests that compute_off_policy_correction returns None tis_ratio when both configs are null."""
+ device = "cpu"
+
+ old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device)
+ rollout_logprobs = torch.tensor([[-2.0, -2.0, -2.0]], device=device)
+ loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device)
+
+ config = DictConfig(
+ {
+ "tis_ratio_type": None,
+ "sequence_mask_metric": None,
+ }
+ )
+
+ tis_ratio, metrics, new_loss_mask = compute_off_policy_correction(
+ old_log_probs, rollout_logprobs, loss_mask, config
+ )
+
+ # Should return None tis_ratio (early return) and empty metrics
+ assert tis_ratio is None
+ assert metrics == {}
+
+
+def test_compute_off_policy_correction_tis_only():
+ """Tests compute_off_policy_correction with only TIS enabled."""
+ device = "cpu"
+
+ # Token log ratios: [0.5, 0.5, 0.5] -> token ratios = [1.6487, 1.6487, 1.6487]
+ old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device)
+ rollout_logprobs = torch.tensor([[-1.5, -1.5, -1.5]], device=device)
+ loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device)
+
+ config = DictConfig(
+ {
+ "tis_ratio_type": "token",
+ "token_tis_ratio_clip_high": 2.0,
+ "sequence_mask_metric": None,
+ "outlier_token_is_threshold_low": 1e-4,
+ "outlier_token_is_threshold_high": 100.0,
+ }
+ )
+
+ tis_ratio, metrics, new_loss_mask = compute_off_policy_correction(
+ old_log_probs, rollout_logprobs, loss_mask, config
+ )
+
+ # Expected tis_ratio: 1.6487 (no capping needed)
+ expected_tis_ratio = torch.exp(torch.tensor(0.5))
+ torch.testing.assert_close(
+ tis_ratio, torch.full_like(old_log_probs, expected_tis_ratio.item()), rtol=1e-3, atol=1e-4
+ )
+ # Check metrics are populated
+ assert "is_ratio_mean" in metrics
+ assert "tis_token_clip_high_ratio" in metrics
+
+
+def test_compute_off_policy_correction_sequence_mask_only():
+ """Tests compute_off_policy_correction with only geometric sequence mask enabled."""
+ device = "cpu"
+
+ # Token log ratios: [0.0, 0.0, 0.0] -> geometric mean = 1.0
+ old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device)
+ rollout_logprobs = torch.tensor([[-1.0, -1.0, -1.0]], device=device)
+ loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device)
+
+ config = DictConfig(
+ {
+ "tis_ratio_type": None,
+ "sequence_mask_metric": "geometric",
+ "geo_mask_high": 1.1,
+ "geo_mask_low": 0.9,
+ "outlier_token_is_threshold_low": 1e-4,
+ "outlier_token_is_threshold_high": 100.0,
+ }
+ )
+
+ tis_ratio, metrics, new_loss_mask = compute_off_policy_correction(
+ old_log_probs, rollout_logprobs, loss_mask, config
+ )
+
+ # Geometric mean = 1.0, within bounds, so loss_mask unchanged
+ # tis_ratio is None since tis_ratio_type is None
+ assert tis_ratio is None
+ torch.testing.assert_close(new_loss_mask, loss_mask, rtol=1e-3, atol=1e-4)
+ # Check metrics are populated
+ assert "is_ratio_mean" in metrics
+ assert "geo_sequence_mask_masked_ratio" in metrics
+
+
+def test_compute_off_policy_correction_both_enabled():
+ """Tests compute_off_policy_correction with both TIS and geometric sequence mask enabled."""
+ device = "cpu"
+
+ # Token log ratios: [0.1, 0.1, 0.1] -> token ratios ≈ [1.105, 1.105, 1.105]
+ # Geometric mean ≈ 1.105
+ old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device)
+ rollout_logprobs = torch.tensor([[-1.1, -1.1, -1.1]], device=device)
+ loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device)
+
+ config = DictConfig(
+ {
+ "tis_ratio_type": "token",
+ "token_tis_ratio_clip_high": 2.0,
+ "sequence_mask_metric": "geometric",
+ "geo_mask_high": 1.2,
+ "geo_mask_low": 0.8,
+ "outlier_token_is_threshold_low": 1e-4,
+ "outlier_token_is_threshold_high": 100.0,
+ }
+ )
+
+ tis_ratio, metrics, new_loss_mask = compute_off_policy_correction(
+ old_log_probs, rollout_logprobs, loss_mask, config
+ )
+
+ # TIS ratio ≈ 1.105, geometric mean ≈ 1.105 (within bounds, mask=1)
+ expected_tis_ratio = torch.exp(torch.tensor(0.1))
+ torch.testing.assert_close(
+ tis_ratio, torch.full_like(old_log_probs, expected_tis_ratio.item()), rtol=1e-3, atol=1e-4
+ )
+ # Check metrics from both TIS and sequence mask are populated
+ assert "tis_token_clip_high_ratio" in metrics
+ assert "geo_sequence_mask_masked_ratio" in metrics
+
+
+def test_compute_off_policy_correction_sequence_mask_zeros_loss():
+ """Tests that sequence mask can zero out the loss_mask entirely."""
+ device = "cpu"
+
+ # Token log ratios: [1.0, 1.0, 1.0] -> geometric mean = exp(1.0) ≈ 2.718
+ old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device)
+ rollout_logprobs = torch.tensor([[-2.0, -2.0, -2.0]], device=device)
+ loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device)
+
+ config = DictConfig(
+ {
+ "tis_ratio_type": None,
+ "sequence_mask_metric": "geometric",
+ "geo_mask_high": 1.1,
+ "geo_mask_low": 0.9,
+ "outlier_token_is_threshold_low": 1e-4,
+ "outlier_token_is_threshold_high": 100.0,
+ }
+ )
+
+ tis_ratio, metrics, new_loss_mask = compute_off_policy_correction(
+ old_log_probs, rollout_logprobs, loss_mask, config
+ )
+
+ # Geometric mean ≈ 2.718, outside [0.9, 1.1], so loss_mask should be zeroed
+ expected_mask = torch.tensor([[0.0, 0.0, 0.0]], device=device)
+ torch.testing.assert_close(new_loss_mask, expected_mask, rtol=1e-3, atol=1e-4)
+ # Check that the sequence-mask metrics report that the whole sequence was masked
+ assert metrics["geo_sequence_mask_masked_ratio"] == 1.0
+
+
+def test_ppo_policy_loss_with_off_policy_correction():
+ """Integration test for PPO policy loss with rollout correction enabled."""
+ device = "cpu"
+
+ advantages = torch.tensor([[1.0, -1.0, 0.5]], device=device)
+ old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device)
+ log_probs = torch.tensor([[-1.1, -0.9, -1.0]], device=device)
+ rollout_logprobs = torch.tensor([[-1.05, -1.05, -1.05]], device=device)
+ loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device)
+
+ config = DictConfig(
+ {
+ "eps_clip_low": 0.2,
+ "eps_clip_high": 0.2,
+ "policy_loss_type": "regular",
+ "loss_reduction": "token_mean",
+ "max_seq_len": 4,
+ "off_policy_correction": {
+ "tis_ratio_type": "token",
+ "token_tis_ratio_clip_high": 2.0,
+ "sequence_mask_metric": None,
+ "outlier_token_is_threshold_low": 1e-4,
+ "outlier_token_is_threshold_high": 100.0,
+ },
+ }
+ )
+
+ loss_fn = PolicyLossRegistry.get("regular")
+
+ # Loss with rollout correction
+ loss_with_correction, _ = loss_fn(
+ log_probs=log_probs,
+ old_log_probs=old_log_probs,
+ advantages=advantages,
+ config=config,
+ loss_mask=loss_mask,
+ rollout_logprobs=rollout_logprobs,
+ )
+
+ # Loss without rollout correction
+ config_no_correction = DictConfig(
+ {
+ "eps_clip_low": 0.2,
+ "eps_clip_high": 0.2,
+ "policy_loss_type": "regular",
+ "loss_reduction": "token_mean",
+ "max_seq_len": 4,
+ "off_policy_correction": {
+ "tis_ratio_type": None,
+ "sequence_mask_metric": None,
+ },
+ }
+ )
+
+ loss_without_correction, _ = loss_fn(
+ log_probs=log_probs,
+ old_log_probs=old_log_probs,
+ advantages=advantages,
+ config=config_no_correction,
+ loss_mask=loss_mask,
+ rollout_logprobs=rollout_logprobs,
+ )
+
+ # TIS correction should modify the loss
+ assert not torch.allclose(loss_with_correction, loss_without_correction, rtol=1e-3), (
+ f"Rollout correction should change the loss: "
+ f"with={loss_with_correction:.6f} vs without={loss_without_correction:.6f}"
+ )
+
+
+def test_compute_tis_ratio_invalid_type():
+ """Tests that compute_tis_ratio raises error for invalid tis_ratio_type."""
+ device = "cpu"
+
+ old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device)
+ rollout_logprobs = torch.tensor([[-1.5, -1.5, -1.5]], device=device)
+ loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device)
+
+ config = DictConfig({"tis_ratio_type": "invalid"})
+
+ with pytest.raises(ValueError, match="Unknown tis_ratio_type"):
+ compute_tis_ratio(old_log_probs, rollout_logprobs, loss_mask, "invalid", config)
+
+
+def test_compute_sequence_mask_invalid_type():
+ """Tests that compute_sequence_mask raises error for invalid sequence_mask_metric."""
+ device = "cpu"
+
+ old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device)
+ rollout_logprobs = torch.tensor([[-1.5, -1.5, -1.5]], device=device)
+ loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device)
+
+ config = DictConfig({"sequence_mask_metric": "invalid"})
+
+ with pytest.raises(ValueError, match="Unknown sequence_mask_metric"):
+ compute_sequence_mask(old_log_probs, rollout_logprobs, loss_mask, "invalid", config)
diff --git a/skyrl-train/tests/cpu/test_trainer.py b/skyrl-train/tests/cpu/test_trainer.py
index 599f8a3f4..1d8924eb4 100644
--- a/skyrl-train/tests/cpu/test_trainer.py
+++ b/skyrl-train/tests/cpu/test_trainer.py
@@ -550,7 +550,12 @@ def create_test_worker(worker_class):
def mock_policy_forward_backward(experience, microbatch_weight):
policy_forward_backward_calls.append({"microbatch_weight": microbatch_weight})
- return {"policy_loss": 0.5, "ppo_clip_ratio": 0.1, "policy_entropy": 2.0, "response_length": response_length}
+ return {
+ "policy_loss": 0.5,
+ "loss_metrics/clip_ratio": 0.1,
+ "policy_entropy": 2.0,
+ "response_length": response_length,
+ }
policy_worker.forward_backward = mock_policy_forward_backward
policy_worker.optim_step = MagicMock(return_value=None)
diff --git a/skyrl-train/tests/cpu/utils/test_ppo_utils.py b/skyrl-train/tests/cpu/utils/test_ppo_utils.py
index fb69ce15e..f9744790c 100644
--- a/skyrl-train/tests/cpu/utils/test_ppo_utils.py
+++ b/skyrl-train/tests/cpu/utils/test_ppo_utils.py
@@ -17,6 +17,7 @@
AdvantageEstimatorRegistry,
register_advantage_estimator,
PolicyLossRegistry,
+ LossMetrics,
register_policy_loss,
compute_reinforce_plus_plus_outcome_advantage,
compute_rloo_outcome_advantage,
@@ -376,7 +377,7 @@ def test_policy_loss_registry_specific():
@register_policy_loss("test_policy_decorator")
def decorated_policy_loss(log_probs, old_log_probs, advantages, config, loss_mask=None, rollout_log_probs=None):
- return torch.tensor(1.5), 0.3
+ return torch.tensor(1.5), LossMetrics(clip_ratio=0.3)
# Test decorator worked
assert "test_policy_decorator" in PolicyLossRegistry.list_available()
@@ -385,14 +386,14 @@ def decorated_policy_loss(log_probs, old_log_probs, advantages, config, loss_mas
# Test function execution
config = DictConfig({"policy_loss_type": "test_policy_decorator"})
- loss, clip_ratio = retrieved(
+ loss, loss_metrics = retrieved(
log_probs=torch.tensor([[0.1]]),
old_log_probs=torch.tensor([[0.2]]),
advantages=torch.tensor([[1.0]]),
config=config,
)
assert loss.item() == 1.5
- assert clip_ratio == 0.3
+ assert loss_metrics["clip_ratio"] == 0.3
# Test error message includes "Policy loss"
with pytest.raises(ValueError, match="Unknown policy loss"):
@@ -413,10 +414,10 @@ def test_registry_cross_ray_process():
# Create test functions
def test_policy_loss(log_probs, old_log_probs, advantages, config, loss_mask=None):
- return torch.tensor(2.0), 0.5
+ return torch.tensor(2.0), LossMetrics(clip_ratio=0.5)
def test_policy_loss_2(log_probs, old_log_probs, advantages, config, loss_mask=None):
- return torch.tensor(3.0), 0.6
+ return torch.tensor(3.0), LossMetrics(clip_ratio=0.6)
def test_advantage_estimator(**kwargs):
rewards = kwargs["token_level_rewards"]
@@ -432,7 +433,7 @@ def test_ray_registry_access():
policy_loss = PolicyLossRegistry.get("cross_process_test")
adv_estimator = AdvantageEstimatorRegistry.get("cross_process_adv_test")
- loss, clip_ratio = policy_loss(
+ loss, loss_metrics = policy_loss(
log_probs=torch.tensor([[0.1]]),
old_log_probs=torch.tensor([[0.2]]),
advantages=torch.tensor([[1.0]]),
@@ -444,25 +445,25 @@ def test_ray_registry_access():
response_mask=torch.tensor([[1.0, 1.0]]),
index=np.array(["0", "0"]),
)
- return loss, clip_ratio, adv, ret
+ return loss, loss_metrics, adv, ret
# Run Ray task
- loss, clip_ratio, adv, ret = ray.get(test_ray_registry_access.remote())
+ loss, loss_metrics, adv, ret = ray.get(test_ray_registry_access.remote())
assert loss.item() == 2.0
- assert clip_ratio == 0.5
+ assert loss_metrics["clip_ratio"] == 0.5
assert adv.shape == torch.Size([1, 2])
assert ret.shape == torch.Size([1, 2])
# test that registration works after ray init as well
PolicyLossRegistry.register("cross_process_test_2", test_policy_loss_2)
- loss_2, clip_ratio_2 = PolicyLossRegistry.get("cross_process_test_2")(
+ loss_2, loss_metrics_2 = PolicyLossRegistry.get("cross_process_test_2")(
log_probs=torch.tensor([[0.1]]),
old_log_probs=torch.tensor([[0.2]]),
advantages=torch.tensor([[1.0]]),
config=DictConfig({"policy_loss_type": "cross_process_test_2"}),
)
assert loss_2.item() == 3.0
- assert clip_ratio_2 == 0.6
+ assert loss_metrics_2["clip_ratio"] == 0.6
finally:
PolicyLossRegistry.reset()
AdvantageEstimatorRegistry.reset()
diff --git a/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py b/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py
index 592a0518b..028b2d0bd 100644
--- a/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py
+++ b/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py
@@ -472,6 +472,11 @@ async def test_megatron_train(
)
transformer_config_kwargs["num_layers"] = 2
cfg.trainer.policy.megatron_config.transformer_config_kwargs = transformer_config_kwargs
+ # Test that the off_policy_correction config propagates correctly to the Megatron worker
+ cfg.trainer.algorithm.off_policy_correction.tis_ratio_type = "sequence"
+ cfg.trainer.algorithm.off_policy_correction.sequence_mask_metric = "geometric"
+ cfg.trainer.algorithm.off_policy_correction.geo_mask_high = 1.02
+ cfg.trainer.algorithm.off_policy_correction.geo_mask_low = 0.98
# set batch sizes correctly
cfg.trainer.train_batch_size = gpus_per_node
@@ -501,7 +506,7 @@ async def test_megatron_train(
assert isinstance(result, dict), "Result should be a dictionary of training stats"
assert "policy_loss" in result
assert "policy_lr" in result
- assert "ppo_clip_ratio" in result
+ assert "loss_metrics/clip_ratio" in result
assert "policy_entropy" in result
for k, v in result.items():
assert isinstance(v, (int, float)), f"{k} should be an int or float"
@@ -539,7 +544,22 @@ async def test_megatron_train(
print("\n\n")
print("fsdp results: ", results_fsdp[0])
- keys_to_compare = ["policy_loss", "policy_lr", "ppo_clip_ratio", "policy_entropy", "policy_kl", "final_loss"]
+ keys_to_compare = [
+ "policy_loss",
+ "policy_lr",
+ "loss_metrics/clip_ratio",
+ "policy_entropy",
+ "policy_kl",
+ "final_loss",
+ ]
+ if ep > 1:
+ keys_to_compare.extend(
+ [
+ "loss_metrics/is_ratio_mean",
+ "loss_metrics/outlier_seq_masked_ratio",
+ "loss_metrics/geo_sequence_mask_masked_ratio",
+ ]
+ )
for i, result in enumerate(results_fsdp):
for k in keys_to_compare:
if k == "policy_entropy":
@@ -605,7 +625,7 @@ async def test_megatron_dp(ray_init_fixture, worker_type, tp, pp, gpus_per_node)
assert isinstance(result, dict), "Result should be a dictionary of training stats"
assert "policy_loss" in result
assert "policy_lr" in result
- assert "ppo_clip_ratio" in result
+ assert "loss_metrics/clip_ratio" in result
assert "policy_entropy" in result
for k, v in result.items():
assert isinstance(v, (int, float)), f"{k} should be an int or float"
@@ -641,7 +661,14 @@ async def test_megatron_dp(ray_init_fixture, worker_type, tp, pp, gpus_per_node)
print("\n\n")
print("megatron results dp: ", results_megatron_dp)
- keys_to_compare = ["policy_loss", "policy_lr", "ppo_clip_ratio", "policy_entropy", "policy_kl", "raw_grad_norm"]
+ keys_to_compare = [
+ "policy_loss",
+ "policy_lr",
+ "loss_metrics/clip_ratio",
+ "policy_entropy",
+ "policy_kl",
+ "raw_grad_norm",
+ ]
for i, result in enumerate(results_megatron_dp):
for k in keys_to_compare:
assert isinstance(result[k], (int, float)), f"{k} should be an int or float"
diff --git a/skyrl-train/tests/gpu/gpu_ci/test_ppo_train.py b/skyrl-train/tests/gpu/gpu_ci/test_ppo_train.py
index c9880a017..429ff5ede 100644
--- a/skyrl-train/tests/gpu/gpu_ci/test_ppo_train.py
+++ b/skyrl-train/tests/gpu/gpu_ci/test_ppo_train.py
@@ -48,6 +48,11 @@ def test_ppo_train_basic_execution(ray_init_fixture, cfg, use_entropy_loss, use_
cfg.trainer.algorithm.use_kl_loss = True
cfg.trainer.algorithm.kl_loss_coef = 0.001
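+ # Enable sequence-level TIS and a geometric sequence mask so the off_policy_correction loss metrics are emitted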
+ cfg.trainer.algorithm.off_policy_correction.tis_ratio_type = "sequence"
+ cfg.trainer.algorithm.off_policy_correction.sequence_mask_metric = "geometric"
+ cfg.trainer.algorithm.off_policy_correction.geo_mask_high = 1.02
+ cfg.trainer.algorithm.off_policy_correction.geo_mask_low = 0.98
+
actor_group = init_worker_with_type(
"policy",
shared_pg=None,
@@ -74,7 +79,10 @@ def test_ppo_train_basic_execution(ray_init_fixture, cfg, use_entropy_loss, use_
"policy_loss",
"policy_update_steps",
"policy_lr",
- "ppo_clip_ratio",
+ "loss_metrics/clip_ratio",
+ "loss_metrics/is_ratio_mean",
+ "loss_metrics/outlier_seq_masked_ratio",
+ "loss_metrics/geo_sequence_mask_masked_ratio",
"policy_entropy",
"final_loss",
]
diff --git a/skyrl-train/tests/gpu/gpu_ci/test_training_step.py b/skyrl-train/tests/gpu/gpu_ci/test_training_step.py
index e81103434..f72f87dfa 100644
--- a/skyrl-train/tests/gpu/gpu_ci/test_training_step.py
+++ b/skyrl-train/tests/gpu/gpu_ci/test_training_step.py
@@ -73,7 +73,7 @@ async def test_policy_forward_backward_and_optim_step(ray_init_fixture, cfg, pac
for result in results:
assert isinstance(result, dict), "Result should be a dictionary of training stats"
assert "policy_loss" in result
- assert "ppo_clip_ratio" in result
+ assert "loss_metrics/clip_ratio" in result
assert "policy_entropy" in result
for k, v in result.items():
assert isinstance(v, (int, float)), f"{k} should be an int or float"
diff --git a/skyrl-train/tests/gpu/utils.py b/skyrl-train/tests/gpu/utils.py
index d819604f7..0c1c3097d 100644
--- a/skyrl-train/tests/gpu/utils.py
+++ b/skyrl-train/tests/gpu/utils.py
@@ -77,6 +77,7 @@ def make_dummy_training_batch(batch_size=2, seq_len=10, num_actions=4) -> Traini
"advantages": 0.6 * torch.ones((batch_size, num_actions), device="cpu"),
"loss_mask": torch.ones((batch_size, num_actions), dtype=int, device="cpu"),
"response_mask": torch.ones((batch_size, num_actions), dtype=int, device="cpu"),
+ "rollout_logprobs": 0.2 * torch.ones((batch_size, num_actions), device="cpu"),
}
)
data.metadata = {"response_length": num_actions}