diff --git a/skyrl-train/docs/algorithms/custom_algorithms.rst b/skyrl-train/docs/algorithms/custom_algorithms.rst index 09030ef79..174cf3964 100644 --- a/skyrl-train/docs/algorithms/custom_algorithms.rst +++ b/skyrl-train/docs/algorithms/custom_algorithms.rst @@ -48,14 +48,14 @@ Similarly, you can register custom policy loss functions: .. code-block:: python - from skyrl_train.utils.ppo_utils import register_policy_loss, PolicyLossRegistry + from skyrl_train.utils.ppo_utils import register_policy_loss, PolicyLossRegistry, LossMetrics @register_policy_loss("reinforce") def compute_reinforce_policy_loss(log_probs, old_log_probs, advantages, config, loss_mask=None, rollout_log_probs=None): # Your custom policy loss implementation (like REINFORCE) loss = (-log_probs * advantages).mean() - # return loss and clip ratio - return loss, 0.0 + # return loss and loss metrics + return loss, LossMetrics(clip_ratio=0.0) Registry Ray Distribution ~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/skyrl-train/docs/configuration/config.rst b/skyrl-train/docs/configuration/config.rst index ab12315fe..740be7bb2 100644 --- a/skyrl-train/docs/configuration/config.rst +++ b/skyrl-train/docs/configuration/config.rst @@ -389,6 +389,45 @@ Algorithm Configuration # dual clip parameters clip_ratio_c: 3.0 + # To be deprecated in favor of off_policy_correction.tis_ratio_type = "token" + # and "token_tis_ratio_clip_high" + use_tis: false + tis_imp_ratio_cap: -1.0 + + # references + # - https://github.com/szrlee/verl/blob/yingru/rollout_correction/docs/advance/rollout_corr_math.md + # - https://fengyao.notion.site/off-policy-rl + off_policy_correction: + # type of importance sampling ratio to use for ppo loss correction + # here importance sampling ratio refers to exp(logprobs_{policy_old} - logprobs_{rollout_policy}) + tis_ratio_type: null # null, "token", "sequence" + + # used if tis_ratio_type = "token", 1.5-5.0 is recommended for "token" tis_ratio_type + token_tis_ratio_clip_high: 2.0 + # used if tis_ratio_type = "sequence", 2.0-10.0 is recommended for "sequence" tis_ratio_type + sequence_tis_ratio_clip_high: 5.0 + + # method of masking out sequences with cumulative importance sampling ratios outside the cap + # "product" masks out sequences with product of importance ratios outside the cap + # "geometric" masks out sequences with geometric mean of importance ratios outside the cap + sequence_mask_metric: null # null, "product", "geometric" + + # used if sequence_mask_metric = "geometric" + # values around 0.99-1.01 are recommended for "geometric" sequence_mask_metric - MoE models may need larger allowed ranges due to higher mismatch + geo_mask_high: 1.01 + geo_mask_low: 0.99 + + # used if sequence_mask_metric = "product" + # values around 0.5-2.0 are recommended for "product" sequence_mask_metric + product_mask_high: 2.0 + product_mask_low: 0.5 + + # separate from sequence_mask_metric and tis_ratio_type + # if any off_policy_correction is enabled, masks out sequences with any token having importance ratio + # far outside an acceptable range (low and high thresholds) + outlier_token_is_threshold_low: 1e-4 + outlier_token_is_threshold_high: 100 + # clip-cov parameters (only used when policy_loss_type: "clip_cov") clip_cov: clip_ratio: 0.0002 # fraction of tokens to clip based on covariance @@ -413,10 +452,6 @@ Algorithm Configuration type: null # filter (DAPO), replace (POLARIS/WebSailor), or null max_sample_batches: 30 # sample at most this many batches before stopping, -1 to sample forever min_replace_ratio: 0.3 # minimum 
proportion of good samples with which to replace bad samples (for replace strategy only) - - # Truncated Importance Sampling as proposed in https://fengyao.notion.site/off-policy-rl - use_tis: false - tis_imp_ratio_cap: -1.0 # SAPO parameters (only used when policy_loss_type: "sapo") (https://arxiv.org/pdf/2511.20347) sapo: @@ -466,8 +501,8 @@ Algorithm Configuration - ``algorithm.dynamic_sampling.type``: Type of dynamic sampling to use. Currently, we support ``filter`` (`DAPO `_), ``replace`` (`POLARIS `_ / `WebSailor `_), or ``null`` for no dynamic sampling. - ``algorithm.dynamic_sampling.max_sample_batches``: Maximum number of batches to sample before stopping. Set to ``-1`` to sample forever. - ``algorithm.dynamic_sampling.min_replace_ratio``: Minimum proportion of good samples with which to replace bad samples for ``replace`` strategy. -- ``algorithm.use_tis``: Whether to use Truncated Importance Sampling (TIS) as proposed in `this blog `_. -- ``algorithm.tis_imp_ratio_cap``: Cap parameter for the importance ratio in TIS. +- ``algorithm.use_tis``: Whether to use Truncated Importance Sampling (TIS) as proposed in `this blog `_. This flag is to be deprecated, use ``off_policy_correction.tis_ratio_type = "token"`` instead. +- ``algorithm.tis_imp_ratio_cap``: Cap parameter for the importance ratio in TIS. This flag is to be deprecated, use ``off_policy_correction.token_tis_ratio_clip_high`` instead. - ``algorithm.clip_cov``: Clip-Cov parameters (only used when ``policy_loss_type`` is ``clip_cov``): - ``clip_ratio``: Fraction of tokens to clip based on covariance values. @@ -489,6 +524,35 @@ Algorithm Configuration - ``tau_pos``: Temperature for gating function for tokens with positive advantages. - ``tau_neg``: Temperature for gating function for tokens with negative (or zero) advantages. +Off Policy Correction Configuration +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +- ``algorithm.off_policy_correction``: Off policy correction configuration. See the full configuration below + +.. code-block:: yaml + + off_policy_correction: + tis_ratio_type: null # null, "token", "sequence" + token_tis_ratio_clip_high: 2.0 + sequence_tis_ratio_clip_high: 5.0 + sequence_mask_metric: null # null, "product", "geometric" + geo_mask_high: 1.01 + geo_mask_low: 0.99 + product_mask_high: 2.0 + product_mask_low: 0.5 + outlier_token_is_threshold_low: 1e-4 + outlier_token_is_threshold_high: 100 + +- ``algorithm.off_policy_correction.tis_ratio_type``: Type of importance sampling ratio to use for ppo loss correction. Options include: ``null``, ``token``, ``sequence``. +- ``algorithm.off_policy_correction.token_tis_ratio_clip_high``: Cap parameter for "token" tis_ratio_type. +- ``algorithm.off_policy_correction.sequence_tis_ratio_clip_high``: Cap parameter for "sequence" tis_ratio_type. +- ``algorithm.off_policy_correction.sequence_mask_metric``: Method of masking out sequences with cumulative importance sampling ratios outside the cap. Options include: ``null``, ``product``, ``geometric``. +- ``algorithm.off_policy_correction.geo_mask_high``: High threshold for "geometric" sequence_mask_metric. +- ``algorithm.off_policy_correction.geo_mask_low``: Low threshold for "geometric" sequence_mask_metric. +- ``algorithm.off_policy_correction.product_mask_high``: High threshold for "product" sequence_mask_metric. +- ``algorithm.off_policy_correction.product_mask_low``: Low threshold for "product" sequence_mask_metric. 
+- ``algorithm.off_policy_correction.outlier_token_is_threshold_low``: Low threshold for outlier token mask - masks out sequences with any token having importance ratio far outside an acceptable range (low and high thresholds). +- ``algorithm.off_policy_correction.outlier_token_is_threshold_high``: High threshold for outlier token mask - masks out sequences with any token having importance ratio far outside an acceptable range (low and high thresholds). + Policy Loss Formulation ~~~~~~~~~~~~~~~~~~~~~~~ @@ -502,7 +566,7 @@ It can be helpful to understand the final loss formulation to see how the differ advantages: torch.Tensor, config: DictConfig, # trainer.algorithm config loss_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, LossMetrics]: ratio = (log_probs - old_log_probs).exp() surr1 = ratio * advantages @@ -515,7 +579,7 @@ It can be helpful to understand the final loss formulation to see how the differ clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1) loss = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1) loss = reduce_loss(loss, loss_mask, config.loss_reduction) - return loss, clip_ratio + return loss, LossMetrics(clip_ratio=clip_ratio) Generator Configuration diff --git a/skyrl-train/docs/examples/flash_rl.rst b/skyrl-train/docs/examples/flash_rl.rst index a34c18ef3..210859f5f 100644 --- a/skyrl-train/docs/examples/flash_rl.rst +++ b/skyrl-train/docs/examples/flash_rl.rst @@ -60,13 +60,13 @@ We highlight some important training parameters configured for FlashRL from our DATA_DIR="$HOME/data/dapo" # TIS parameters - USE_TIS=true + TIS_TYPE=token TIS_IMP_RATIO_CAP=8.0 uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.int8 -m examples.flash_rl.main_dapo_flashrl \ ... - trainer.algorithm.use_tis=$USE_TIS \ - trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \ generator.sampling_params.logprobs=0 \ ... 
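To make the new ``off_policy_correction`` options documented above concrete, here is a minimal standalone sketch. It is illustrative only — the tensor shapes, random inputs, and cap values are assumptions; the actual implementation is the ``compute_tis_ratio`` / ``compute_sequence_mask`` helpers added to ``skyrl_train/utils/ppo_utils.py`` later in this diff. It shows how the token-level and sequence-level TIS ratios and the geometric sequence mask are derived from ``exp(logprobs_{policy_old} - logprobs_{rollout_policy})``:

.. code-block:: python

   import torch

   old_log_probs = torch.randn(2, 5)      # logprobs under the training (old) policy
   rollout_log_probs = torch.randn(2, 5)  # logprobs returned by the rollout engine
   loss_mask = torch.ones(2, 5)           # 1 for response tokens, 0 for padding

   log_ratio = old_log_probs - rollout_log_probs

   # tis_ratio_type = "token": per-token ratio capped at token_tis_ratio_clip_high
   token_tis_ratio = torch.clamp(log_ratio.exp(), max=2.0)

   # tis_ratio_type = "sequence": product of token ratios (sum of log ratios),
   # capped at sequence_tis_ratio_clip_high
   seq_log_ratio = (log_ratio * loss_mask).sum(dim=-1, keepdim=True)
   seq_tis_ratio = torch.clamp(seq_log_ratio.exp(), max=5.0)

   # sequence_mask_metric = "geometric": keep only sequences whose geometric-mean
   # ratio lies inside [geo_mask_low, geo_mask_high]
   geo_mean = (seq_log_ratio / loss_mask.sum(dim=-1, keepdim=True)).exp()
   geo_mask = ((geo_mean > 0.99) & (geo_mean < 1.01)).float()

In the loss functions below, the capped ratio multiplies the per-token policy loss, while the sequence and outlier masks are folded into the loss mask so that rejected sequences simply drop out of the loss reduction.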
diff --git a/skyrl-train/examples/algorithms/custom_policy_loss/main_custom_policy_loss.py b/skyrl-train/examples/algorithms/custom_policy_loss/main_custom_policy_loss.py index 6b9be8b13..b4a172034 100644 --- a/skyrl-train/examples/algorithms/custom_policy_loss/main_custom_policy_loss.py +++ b/skyrl-train/examples/algorithms/custom_policy_loss/main_custom_policy_loss.py @@ -9,7 +9,7 @@ from omegaconf import DictConfig from skyrl_train.utils import initialize_ray from skyrl_train.entrypoints.main_base import BasePPOExp, config_dir, validate_cfg -from skyrl_train.utils.ppo_utils import PolicyLossRegistry +from skyrl_train.utils.ppo_utils import PolicyLossRegistry, LossMetrics # Example of custom policy loss: "reinforce" @@ -27,7 +27,7 @@ def compute_reinforce_policy_loss( loss = (-log_probs * advantages).mean() # Return loss and dummy clip_ratio (no clipping in REINFORCE) - return loss, 0.0 + return loss, LossMetrics(clip_ratio=0.0) # Register the custom policy loss diff --git a/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_fp8.sh b/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_fp8.sh index d0aff21ab..9d16cc709 100644 --- a/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_fp8.sh +++ b/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_fp8.sh @@ -33,6 +33,8 @@ MAX_RESPONSE_LENGTH=1024 CKPT_PATH="$HOME/ckpts/gsm8k_0.5B_ckpt" +TIS_TYPE=token +TIS_IMP_RATIO_CAP=2.0 uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.fp8 --with vllm@https://github.com/NovaSky-AI/SkyRL/releases/download/skyrl_train-v0.1.0/vllm-0.1.dev7509+gcc487699a.d20250821-cp312-cp312-linux_x86_64.whl --with transformers==4.53.3 -- python -m examples.flash_rl.main_dapo_flashrl \ data.train_data="['$DATA_DIR/train.parquet']" \ @@ -53,8 +55,8 @@ uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.fp8 --with v generator.eval_sampling_params.top_p=$EVAL_TOP_P \ trainer.algorithm.use_kl_loss=$USE_KL_LOSS \ trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \ - trainer.algorithm.use_tis=true \ - trainer.algorithm.tis_imp_ratio_cap=2.0 \ + trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \ trainer.policy.model.path="Qwen/Qwen2.5-0.5B-Instruct" \ trainer.placement.colocate_all=true \ trainer.strategy=fsdp2 \ diff --git a/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_int8.sh b/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_int8.sh index f547788f5..5c31645bc 100644 --- a/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_int8.sh +++ b/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_int8.sh @@ -33,6 +33,9 @@ MAX_RESPONSE_LENGTH=1024 CKPT_PATH="$HOME/ckpts/gsm8k_0.5B_ckpt" +TIS_TYPE=token +TIS_IMP_RATIO_CAP=2.0 + uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.int8 --with vllm@https://github.com/NovaSky-AI/SkyRL/releases/download/skyrl_train-v0.1.0/vllm-0.1.dev7509+gcc487699a.d20250821-cp312-cp312-linux_x86_64.whl --with transformers==4.53.3 -- python -m examples.flash_rl.main_dapo_flashrl \ data.train_data="['$DATA_DIR/train.parquet']" \ data.val_data="['$DATA_DIR/validation.parquet']" \ @@ -52,8 +55,8 @@ uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.int8 --with trainer.algorithm.dynamic_sampling.max_sample_batches=$DYNAMIC_SAMPLING_MAX_SAMPLE_BATCHES \ trainer.algorithm.use_kl_loss=$USE_KL_LOSS \ trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \ - trainer.algorithm.use_tis=true \ - 
trainer.algorithm.tis_imp_ratio_cap=2.0 \ + trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \ trainer.policy.model.path="Qwen/Qwen2.5-0.5B-Instruct" \ trainer.placement.colocate_all=true \ trainer.strategy=fsdp2 \ diff --git a/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_32b_int8.sh b/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_32b_int8.sh index 7c6c7fb4a..383f87ba8 100644 --- a/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_32b_int8.sh +++ b/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_32b_int8.sh @@ -33,6 +33,9 @@ MAX_RESPONSE_LENGTH=1024 CKPT_PATH="$HOME/ckpts/gsm8k_32B_ckpt" +TIS_TYPE=token +TIS_IMP_RATIO_CAP=2.0 + uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.int8 --with vllm@https://github.com/NovaSky-AI/SkyRL/releases/download/skyrl_train-v0.1.0/vllm-0.1.dev7509+gcc487699a.d20250821-cp312-cp312-linux_x86_64.whl --with transformers==4.53.3 -m examples.flash_rl.main_dapo_flashrl \ data.train_data="['$DATA_DIR/train.parquet']" \ data.val_data="['$DATA_DIR/validation.parquet']" \ @@ -52,8 +55,8 @@ uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.int8 --with generator.eval_sampling_params.top_p=$EVAL_TOP_P \ trainer.algorithm.use_kl_loss=$USE_KL_LOSS \ trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \ - trainer.algorithm.use_tis=true \ - trainer.algorithm.tis_imp_ratio_cap=2.0 \ + trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \ trainer.policy.model.path="Qwen/Qwen2.5-32B" \ trainer.placement.colocate_all=true \ trainer.strategy=fsdp2 \ diff --git a/skyrl-train/examples/flash_rl/run_dapo_repro_flashrl_0.5b_int8.sh b/skyrl-train/examples/flash_rl/run_dapo_repro_flashrl_0.5b_int8.sh index 97ca62fde..70db7e13f 100644 --- a/skyrl-train/examples/flash_rl/run_dapo_repro_flashrl_0.5b_int8.sh +++ b/skyrl-train/examples/flash_rl/run_dapo_repro_flashrl_0.5b_int8.sh @@ -31,6 +31,8 @@ EVAL_TOP_P=0.7 CLIP_RATIO_C=10.0 MAX_RESPONSE_LENGTH=4096 +TIS_TYPE=token +TIS_IMP_RATIO_CAP=2.0 uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.0.5b_int8 --with vllm@https://github.com/NovaSky-AI/SkyRL/releases/download/skyrl_train-v0.1.0/vllm-0.1.dev7509+gcc487699a.d20250821-cp312-cp312-linux_x86_64.whl --with transformers==4.53.3 -- python -m examples.flash_rl.main_dapo_flashrl \ data.train_data="['$DATA_DIR/dapo-math-17k-cleaned.parquet']" \ data.val_data="['$DATA_DIR/aime-2024-cleaned.parquet']" \ @@ -50,8 +52,8 @@ uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.0.5b_int8 -- trainer.algorithm.dynamic_sampling.max_sample_batches=$DYNAMIC_SAMPLING_MAX_SAMPLE_BATCHES \ trainer.algorithm.use_kl_loss=$USE_KL_LOSS \ trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \ - trainer.algorithm.use_tis=true \ - trainer.algorithm.tis_imp_ratio_cap=2.0 \ + trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \ trainer.policy.model.path="Qwen/Qwen2.5-0.5B-Instruct" \ trainer.placement.colocate_all=true \ trainer.strategy=fsdp2 \ diff --git a/skyrl-train/examples/flash_rl/run_dapo_repro_flashrl_32b_int8.sh b/skyrl-train/examples/flash_rl/run_dapo_repro_flashrl_32b_int8.sh index bf4a033f4..83842dfc4 100644 --- a/skyrl-train/examples/flash_rl/run_dapo_repro_flashrl_32b_int8.sh +++ 
b/skyrl-train/examples/flash_rl/run_dapo_repro_flashrl_32b_int8.sh @@ -33,7 +33,7 @@ MAX_RESPONSE_LENGTH=20480 MAX_PROMPT_LENGTH=2048 TIS_IMP_RATIO_CAP=8.0 -USE_TIS=true +TIS_TYPE=token LOGPROBS=0 CKPT_PATH="$HOME/ckpts/dapo_32b_ckpt" @@ -57,8 +57,8 @@ uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.int8 --with generator.eval_sampling_params.top_p=$EVAL_TOP_P \ trainer.algorithm.use_kl_loss=$USE_KL_LOSS \ trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \ - trainer.algorithm.use_tis=$USE_TIS \ - trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \ trainer.policy.model.path="Qwen/Qwen2.5-32B" \ trainer.placement.colocate_all=true \ trainer.strategy=fsdp2 \ diff --git a/skyrl-train/examples/fully_async/async_run_gsm8k.sh b/skyrl-train/examples/fully_async/async_run_gsm8k.sh index 5a4404b3b..2b2412875 100644 --- a/skyrl-train/examples/fully_async/async_run_gsm8k.sh +++ b/skyrl-train/examples/fully_async/async_run_gsm8k.sh @@ -28,10 +28,10 @@ set -x : "${MAX_STALENESS_STEPS:=4}" : "${NUM_PARALLEL_GENERATION_WORKERS:=$(( MINI_BATCH_SIZE * (MAX_STALENESS_STEPS + 1) ))}" -USE_TIS=true +TIS_TYPE=token TIS_IMP_RATIO_CAP=2.0 -RUN_NAME=gsm8k-async-qwen2.5_1.5B-useTIS_${USE_TIS}-maxStale${MAX_STALENESS_STEPS}-numCon${NUM_PARALLEL_GENERATION_WORKERS}-${NUM_POLICY_GPUS}train${NUM_INFERENCE_GPUS}gen +RUN_NAME=gsm8k-async-qwen2.5_1.5B-TIS_TYPE_${TIS_TYPE}-maxStale${MAX_STALENESS_STEPS}-numCon${NUM_PARALLEL_GENERATION_WORKERS}-${NUM_POLICY_GPUS}train${NUM_INFERENCE_GPUS}gen uv run --isolated --extra $INFERENCE_BACKEND -m examples.fully_async.main_async \ data.train_data="['$DATA_DIR/train.parquet']" \ @@ -39,8 +39,8 @@ uv run --isolated --extra $INFERENCE_BACKEND -m examples.fully_async.main_async trainer.fully_async.max_staleness_steps=${MAX_STALENESS_STEPS} \ trainer.fully_async.num_parallel_generation_workers=${NUM_PARALLEL_GENERATION_WORKERS} \ trainer.algorithm.advantage_estimator="grpo" \ - trainer.algorithm.use_tis=$USE_TIS \ - trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \ trainer.policy.model.path="Qwen/Qwen2.5-1.5B-Instruct" \ trainer.placement.colocate_all=false \ trainer.strategy=fsdp2 \ diff --git a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_30b_a3b.sh b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_30b_a3b.sh index d2af35b5b..3c10e4f18 100644 --- a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_30b_a3b.sh +++ b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_30b_a3b.sh @@ -53,7 +53,7 @@ MEGATRON_ETP=1 # TIS parameters TIS_IMP_RATIO_CAP=2.0 -USE_TIS=true +TIS_TYPE=token uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \ data.train_data="['$TRAIN_FILE']" \ @@ -83,8 +83,8 @@ uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \ trainer.policy.megatron_config.context_parallel_size=$MEGATRON_CP \ trainer.policy.megatron_config.expert_model_parallel_size=$MEGATRON_EP \ trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \ - trainer.algorithm.use_tis=$USE_TIS \ - trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ + 
trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \ trainer.epochs=20 \ trainer.algorithm.eps_clip_low=$CLIP_RATIO_LOW \ trainer.algorithm.eps_clip_high=$CLIP_RATIO_HIGH \ diff --git a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_30b_a3b_lora.sh b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_30b_a3b_lora.sh index 7ae14bc4d..d9d9c90e3 100644 --- a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_30b_a3b_lora.sh +++ b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_30b_a3b_lora.sh @@ -55,8 +55,13 @@ LORA_RANK=32 LORA_ALPHA=64 # TIS parameters -TIS_IMP_RATIO_CAP=2.0 -USE_TIS=true +TIS_IMP_RATIO_CAP=3.0 + +# rollout correction parameters +TIS_RATIO_TYPE="sequence" +sequence_mask_metric="geometric" +geo_mask_high=1.05 +geo_mask_low=0.95 uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \ data.train_data="['$TRAIN_FILE']" \ @@ -88,8 +93,11 @@ uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \ trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \ trainer.policy.model.lora.rank=$LORA_RANK \ trainer.policy.model.lora.alpha=$LORA_ALPHA \ - trainer.algorithm.use_tis=$USE_TIS \ - trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_RATIO_TYPE \ + trainer.algorithm.off_policy_correction.sequence_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.off_policy_correction.sequence_mask_metric=$sequence_mask_metric \ + trainer.algorithm.off_policy_correction.geo_mask_high=$geo_mask_high \ + trainer.algorithm.off_policy_correction.geo_mask_low=$geo_mask_low \ trainer.epochs=20 \ trainer.algorithm.eps_clip_low=$CLIP_RATIO_LOW \ trainer.algorithm.eps_clip_high=$CLIP_RATIO_HIGH \ diff --git a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_4b.sh b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_4b.sh index 014ee567f..86a6a7ec6 100644 --- a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_4b.sh +++ b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_4b.sh @@ -50,7 +50,7 @@ MEGATRON_ETP=null # TIS parameters TIS_IMP_RATIO_CAP=2.0 -USE_TIS=true +TIS_TYPE=token uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \ data.train_data="['$TRAIN_FILE']" \ @@ -80,8 +80,8 @@ uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \ trainer.policy.megatron_config.context_parallel_size=$MEGATRON_CP \ trainer.policy.megatron_config.expert_model_parallel_size=$MEGATRON_EP \ trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \ - trainer.algorithm.use_tis=$USE_TIS \ - trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \ trainer.epochs=20 \ trainer.algorithm.eps_clip_low=$CLIP_RATIO_LOW \ trainer.algorithm.eps_clip_high=$CLIP_RATIO_HIGH \ diff --git a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_4b_lora.sh b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_4b_lora.sh index 25dc8e6d3..801c69ab0 100644 --- a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_4b_lora.sh +++ b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_4b_lora.sh @@ -55,7 +55,7 @@ LORA_A_INIT_METHOD="kaiming" # TIS parameters TIS_IMP_RATIO_CAP=2.0 -USE_TIS=true +TIS_TYPE=token uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \ data.train_data="['$TRAIN_FILE']" \ @@ -85,8 +85,8 @@ uv run 
--isolated --extra mcore -m examples.algorithms.dapo.main_dapo \ trainer.policy.megatron_config.context_parallel_size=$MEGATRON_CP \ trainer.policy.megatron_config.expert_model_parallel_size=$MEGATRON_EP \ trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \ - trainer.algorithm.use_tis=$USE_TIS \ - trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \ trainer.policy.model.lora.rank=$LORA_RANK \ trainer.policy.model.lora.alpha=$LORA_ALPHA \ trainer.policy.model.lora.init_method=$LORA_A_INIT_METHOD \ diff --git a/skyrl-train/examples/megatron/run_megatron_lora_qwen3-30b-a3b.sh b/skyrl-train/examples/megatron/run_megatron_lora_qwen3-30b-a3b.sh index 6c2f3a899..efe6eaf4c 100644 --- a/skyrl-train/examples/megatron/run_megatron_lora_qwen3-30b-a3b.sh +++ b/skyrl-train/examples/megatron/run_megatron_lora_qwen3-30b-a3b.sh @@ -38,7 +38,7 @@ LORA_A_INIT_METHOD="kaiming" # TIS parameters TIS_IMP_RATIO_CAP=2.0 -USE_TIS=true +TIS_TYPE=token uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_base \ data.train_data="['$DATA_DIR/train.parquet']" \ @@ -63,8 +63,8 @@ uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_base \ trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \ trainer.ref.megatron_config.expert_model_parallel_size=$MEGATRON_EP \ trainer.ref.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \ - trainer.algorithm.use_tis=$USE_TIS \ - trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \ trainer.policy.model.lora.rank=$LORA_RANK \ trainer.policy.model.lora.alpha=$LORA_ALPHA \ trainer.policy.model.lora.init_method=$LORA_A_INIT_METHOD \ diff --git a/skyrl-train/examples/on_policy_distillation/main_on_policy_distill.py b/skyrl-train/examples/on_policy_distillation/main_on_policy_distill.py index c7030cbcc..4afa470af 100644 --- a/skyrl-train/examples/on_policy_distillation/main_on_policy_distill.py +++ b/skyrl-train/examples/on_policy_distillation/main_on_policy_distill.py @@ -10,6 +10,7 @@ register_advantage_estimator, register_policy_loss, reduce_loss, + LossMetrics, ) from skyrl_train.training_batch import TrainingInputBatch @@ -53,7 +54,7 @@ def compute_importance_sampling_policy_loss( loss = reduce_loss(loss, loss_mask, "seq_mean_token_sum_norm", config.max_seq_len) # return loss and a dummy clip ratio value as we aren't clipping here - return loss, 0.0 + return loss, LossMetrics(clip_ratio=0.0) class OnPolicyDistillationExp(BasePPOExp): diff --git a/skyrl-train/examples/search/run_search.sh b/skyrl-train/examples/search/run_search.sh index 1703e455e..7166f6283 100755 --- a/skyrl-train/examples/search/run_search.sh +++ b/skyrl-train/examples/search/run_search.sh @@ -11,7 +11,7 @@ DATA_DIR="$HOME/data/searchR1" RUN_NAME="skyrl-search_4turns_maxgeneratelen_500-multiturn-sync-TIS_2.0" -USE_TIS=true +TIS_TYPE=token TIS_IMP_RATIO_CAP=2.0 uv run --isolated --frozen --extra vllm -m skyrl_train.entrypoints.main_base \ @@ -23,8 +23,8 @@ uv run --isolated --frozen --extra vllm -m skyrl_train.entrypoints.main_base \ trainer.policy.optimizer_config.num_warmup_steps=94 \ trainer.algorithm.use_kl_loss=true \ trainer.algorithm.kl_loss_coef=0.001 \ - trainer.algorithm.use_tis=$USE_TIS \ - 
trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \ trainer.policy.model.path="Qwen/Qwen2.5-3B-Instruct" \ trainer.placement.colocate_all=true \ trainer.strategy=fsdp2 \ diff --git a/skyrl-train/examples/search/run_search_conversation_format.sh b/skyrl-train/examples/search/run_search_conversation_format.sh index 9346efc13..285bafd29 100755 --- a/skyrl-train/examples/search/run_search_conversation_format.sh +++ b/skyrl-train/examples/search/run_search_conversation_format.sh @@ -18,7 +18,7 @@ DATA_DIR="$HOME/data/searchR1" RUN_NAME="skyrl-search_4turns_maxgeneratelen_500" -USE_TIS=true +TIS_TYPE=token TIS_IMP_RATIO_CAP=2.0 uv run --isolated --frozen --extra vllm -m skyrl_train.entrypoints.main_base \ @@ -30,8 +30,8 @@ uv run --isolated --frozen --extra vllm -m skyrl_train.entrypoints.main_base \ trainer.policy.optimizer_config.num_warmup_steps=94 \ trainer.algorithm.use_kl_loss=true \ trainer.algorithm.kl_loss_coef=0.001 \ - trainer.algorithm.use_tis=$USE_TIS \ - trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \ trainer.policy.model.path="Qwen/Qwen2.5-3B-Instruct" \ trainer.placement.colocate_all=true \ trainer.strategy=fsdp2 \ diff --git a/skyrl-train/examples/text_to_sql/run_skyrl_sql_megatron_lora.sh b/skyrl-train/examples/text_to_sql/run_skyrl_sql_megatron_lora.sh index f33cb7e65..47b87d85c 100644 --- a/skyrl-train/examples/text_to_sql/run_skyrl_sql_megatron_lora.sh +++ b/skyrl-train/examples/text_to_sql/run_skyrl_sql_megatron_lora.sh @@ -32,7 +32,7 @@ MEGATRON_ETP=null # TIS parameters TIS_IMP_RATIO_CAP=2.0 -USE_TIS=true +TIS_TYPE=token uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_base \ trainer.algorithm.advantage_estimator="grpo" \ @@ -63,8 +63,8 @@ uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_base \ trainer.policy.optimizer_config.lr=3.0e-5 \ trainer.policy_mini_batch_size=256 \ trainer.algorithm.use_kl_loss=false \ - trainer.algorithm.use_tis=$USE_TIS \ - trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \ trainer.ckpt_interval=60 \ trainer.hf_save_interval=30 \ trainer.dump_data_batch=true \ diff --git a/skyrl-train/examples/tis_correction/run_dapo_tis.sh b/skyrl-train/examples/tis_correction/run_dapo_tis.sh index a0995dd3c..674d0aaa3 100644 --- a/skyrl-train/examples/tis_correction/run_dapo_tis.sh +++ b/skyrl-train/examples/tis_correction/run_dapo_tis.sh @@ -12,7 +12,7 @@ LOGGER="wandb" # change to "console" to print to stdout # TIS parameters TIS_IMP_RATIO_CAP=2.0 -USE_TIS=true +TIS_TYPE=token # returns rollout logprobs for the generated tokens; required for TIS LOGPROBS=0 @@ -55,8 +55,8 @@ uv run --isolated --extra vllm -m examples.tis_correction.main_tis_dapo \ generator.eval_sampling_params.top_p=$EVAL_TOP_P \ trainer.algorithm.use_kl_loss=$USE_KL_LOSS \ trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \ - trainer.algorithm.use_tis=$USE_TIS \ - trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \ 
trainer.policy.model.path="Qwen/Qwen2.5-1.5B-Instruct" \ trainer.placement.colocate_all=true \ trainer.strategy=fsdp2 \ diff --git a/skyrl-train/skyrl_train/config/ppo_base_config.yaml b/skyrl-train/skyrl_train/config/ppo_base_config.yaml index 1f8c1b43b..0e8b6385c 100644 --- a/skyrl-train/skyrl_train/config/ppo_base_config.yaml +++ b/skyrl-train/skyrl_train/config/ppo_base_config.yaml @@ -120,9 +120,46 @@ trainer: eps_clip_high: 0.2 # dual clip parameters clip_ratio_c: 3.0 - # Truncated Importance Sampling as proposed in https://fengyao.notion.site/off-policy-rl + + # To be deprecated in favor of off_policy_correction.tis_ratio_type = "token" + # and "token_tis_ratio_clip_high" tis_imp_ratio_cap: -1.0 use_tis: false + + # references + # - https://github.com/szrlee/verl/blob/yingru/rollout_correction/docs/advance/rollout_corr_math.md + # - https://fengyao.notion.site/off-policy-rl + off_policy_correction: + # type of importance sampling ratio to use for ppo loss correction + # here importance sampling ratio refers to exp(logprobs_{policy_old} - logprobs_{rollout_policy}) + tis_ratio_type: null # null, "token", "sequence" + + # used if tis_ratio_type = "token", 1.5-5.0 is recommended for "token" tis_ratio_type + token_tis_ratio_clip_high: 2.0 + # used if tis_ratio_type = "sequence", 2.0-10.0 is recommended for "sequence" tis_ratio_type + sequence_tis_ratio_clip_high: 5.0 + + # method of masking out sequences with cumulative importance sampling ratios outside the cap + # "product" masks out sequences with product of importance ratios outside the cap + # "geometric" masks out sequences with geometric mean of importance ratios outside the cap + sequence_mask_metric: null # null, "product", "geometric" + + # used if sequence_mask_metric = "geometric" + # values around 0.99-1.01 are recommended for "geometric" sequence_mask_metric - MoE models may need larger allowed ranges due to higher mismatch + geo_mask_high: 1.01 + geo_mask_low: 0.99 + + # used if sequence_mask_metric = "product" + # values around 0.5-2.0 are recommended for "product" sequence_mask_metric + product_mask_high: 2.0 + product_mask_low: 0.5 + + # separate from sequence_mask_metric and tis_ratio_type + # if any off_policy_correction is enabled, masks out sequences with any token having importance ratio + # far outside an acceptable range (low and high thresholds) + outlier_token_is_threshold_low: 1e-4 + outlier_token_is_threshold_high: 100 + # SAPO parameters (only used when policy_loss_type: "sapo") (https://arxiv.org/pdf/2511.20347) sapo: tau_pos: 1.0 diff --git a/skyrl-train/skyrl_train/distributed/strategy.py b/skyrl-train/skyrl_train/distributed/strategy.py index acceccb45..fb8269e97 100644 --- a/skyrl-train/skyrl_train/distributed/strategy.py +++ b/skyrl-train/skyrl_train/distributed/strategy.py @@ -69,12 +69,9 @@ def get_rank(self) -> int: def all_reduce(self, data: DataT, op="mean") -> DataT: """Perform all_reduce across all processes""" - assert op in ("mean", "max", "sum") + assert op in ("mean", "max", "sum", "min") if isinstance(data, dict): - ret = {} - for k, v in data.items(): - ret[k] = self.all_reduce(v, op) - return ret + return {k: self.all_reduce(v, op) for k, v in data.items()} else: is_tensor = True if not isinstance(data, torch.Tensor): @@ -86,7 +83,13 @@ def all_reduce(self, data: DataT, op="mean") -> DataT: data = data.to(torch.cuda.current_device()) if op == "mean": data /= self.world_size - dist.all_reduce(data, op=dist.ReduceOp.MAX if op == "max" else dist.ReduceOp.SUM) + dist.all_reduce(data, 
op=dist.ReduceOp.SUM) + elif op == "max": + dist.all_reduce(data, op=dist.ReduceOp.MAX) + elif op == "min": + dist.all_reduce(data, op=dist.ReduceOp.MIN) + elif op == "sum": + dist.all_reduce(data, op=dist.ReduceOp.SUM) if is_cpu_tensor: data = data.cpu() return data.item() if not is_tensor else data diff --git a/skyrl-train/skyrl_train/trainer.py b/skyrl-train/skyrl_train/trainer.py index faa578f65..9b27112c4 100644 --- a/skyrl-train/skyrl_train/trainer.py +++ b/skyrl-train/skyrl_train/trainer.py @@ -579,12 +579,17 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis loss_masks, logprobs, ) - # sanity check for tis - if self.cfg.trainer.algorithm.use_tis: + + # sanity check for off_policy_correction + off_policy_correction = self.cfg.trainer.algorithm.off_policy_correction + tis_ratio_type = off_policy_correction.tis_ratio_type + sequence_mask_metric = off_policy_correction.sequence_mask_metric + if tis_ratio_type is not None or sequence_mask_metric is not None: assert ( rollout_logprobs_tensor is not None - ), "expected non-null rollout logprobs tensor with `trainer.algorithm.use_tis` as `True`" + ), "expected non-null rollout logprobs tensor when off_policy_correction is enabled" assert rollout_logprobs_tensor.shape == loss_masks_tensor.shape, "Logprobs should look like responses" + training_input = TrainingInputBatch( { "sequences": sequences_tensor, # Full trajectories (padded and concatenated prompts and responses) diff --git a/skyrl-train/skyrl_train/utils/ppo_utils.py b/skyrl-train/skyrl_train/utils/ppo_utils.py index d814aedd7..c192e8e4e 100644 --- a/skyrl-train/skyrl_train/utils/ppo_utils.py +++ b/skyrl-train/skyrl_train/utils/ppo_utils.py @@ -19,7 +19,7 @@ from collections import defaultdict from enum import StrEnum from functools import wraps -from typing import Callable, List, Literal, Optional, Tuple, Union +from typing import Callable, List, Literal, Optional, Tuple, Union, TypedDict, NotRequired import numpy as np import ray @@ -196,6 +196,10 @@ def ppo_critic_loss( return 0.5 * loss, clipfrac +class LossMetrics(TypedDict, total=False): + clip_ratio: NotRequired[float] + + # Shared registry actor class for both policy loss and advantage estimator registries @ray.remote class RegistryActor: @@ -552,6 +556,252 @@ def _safe_exp_delta(delta: torch.Tensor, clip: float = 20.0, out_dtype=None) -> return y.to(out_dtype or delta.dtype) +def compute_tis_ratio( + old_log_probs: torch.Tensor, + rollout_logprobs: torch.Tensor, + loss_mask: torch.Tensor, + tis_ratio_type: str, + off_policy_correction: DictConfig, +) -> Tuple[torch.Tensor, dict]: + """ + Compute truncated importance sampling (TIS) ratio for off policy correction. + + Args: + old_log_probs: Log probabilities from the old policy (before update). + rollout_logprobs: Log probabilities from the rollout policy. + loss_mask: Mask indicating valid tokens. + tis_ratio_type: Type of TIS ratio ("token" or "sequence"). + off_policy_correction: Off-policy correction config containing cap values. 
+ + Returns: + Tuple of (tis_ratio, metrics): + - tis_ratio: Tensor (float) to multiply with the loss + - metrics: Dict with masking statistics + + Reference: https://github.com/szrlee/verl/blob/yingru/rollout_correction/docs/advance/rollout_corr_math.md + """ + # Compute token-level importance ratio: pi_old / pi_rollout + # In log space: old_log_probs - rollout_logprobs + token_tis_log_ratio = old_log_probs - rollout_logprobs + token_tis_ratio = _safe_exp_delta(token_tis_log_ratio, clip=20.0, out_dtype=old_log_probs.dtype) + + metrics = {} + if tis_ratio_type == "token": + token_tis_ratio_cap = off_policy_correction.token_tis_ratio_clip_high + # Compute proportion of tokens capped + tokens_capped = (token_tis_ratio > token_tis_ratio_cap) & (loss_mask > 0) + total_tokens = (loss_mask > 0).sum() + metrics["tis_token_clip_high_ratio"] = (tokens_capped.sum() / total_tokens.clamp(min=1)).detach().item() + return torch.clamp(token_tis_ratio, max=token_tis_ratio_cap).detach(), metrics + elif tis_ratio_type == "sequence": + # Compute sequence-level importance ratio as product of token ratios (sum of log ratios) + seq_tis_log_ratio = (token_tis_log_ratio * loss_mask).sum(dim=-1, keepdim=True) + seq_tis_ratio = _safe_exp_delta(seq_tis_log_ratio, clip=20.0, out_dtype=old_log_probs.dtype) + seq_tis_ratio_cap = off_policy_correction.sequence_tis_ratio_clip_high + # Compute proportion of sequences capped + num_sequences = seq_tis_ratio.shape[0] + seqs_capped = (seq_tis_ratio > seq_tis_ratio_cap).sum() + metrics["tis_seq_clip_high_ratio"] = (seqs_capped / num_sequences).detach().item() + return torch.clamp(seq_tis_ratio, max=seq_tis_ratio_cap).detach(), metrics + else: + raise ValueError(f"Unknown tis_ratio_type: {tis_ratio_type}") + + +def compute_outlier_token_mask( + old_log_probs: torch.Tensor, + rollout_logprobs: torch.Tensor, + loss_mask: torch.Tensor, + off_policy_correction: DictConfig, +) -> Tuple[torch.Tensor, dict]: + """ + Compute outlier token mask that masks out sequences with any token having + importance ratio outside acceptable bounds. + + This is applied independently of TIS ratio type or sequence mask type, + whenever off policy correction is enabled. + + Args: + old_log_probs: Log probabilities from the old policy (before update). + rollout_logprobs: Log probabilities from the rollout policy. + loss_mask: Mask indicating valid tokens. + off_policy_correction: Off-policy correction config containing threshold values. 
+ + Returns: + Tuple of (outlier_mask, metrics): + - outlier_mask: Tensor (bool) to mask out sequences with any token having importance ratio outside acceptable bounds + - metrics: Dict with masking statistics + """ + metrics = {} + # Compute token-level importance ratio + token_tis_log_ratio = old_log_probs - rollout_logprobs + token_tis_ratio = _safe_exp_delta(token_tis_log_ratio, clip=20.0, out_dtype=old_log_probs.dtype) + + # Check per-token bounds + token_mask_low = off_policy_correction.outlier_token_is_threshold_low + token_mask_high = off_policy_correction.outlier_token_is_threshold_high + token_over_high = (token_tis_ratio > token_mask_high) & (loss_mask > 0) + token_under_low = (token_tis_ratio < token_mask_low) & (loss_mask > 0) + token_in_bounds = ~token_over_high & ~token_under_low + + # A sequence is valid if all tokens are in bounds (considering only masked positions) + all_tokens_valid = (token_in_bounds | (loss_mask == 0)).all(dim=-1, keepdim=True) + + # Compute metrics + num_sequences = float(all_tokens_valid.shape[0]) + # Sequence has any token over high threshold + seq_has_over_high = token_over_high.any(dim=-1) + # Sequence has any token under low threshold + seq_has_under_low = token_under_low.any(dim=-1) + + metrics["outlier_seq_masked_ratio"] = ((~all_tokens_valid.squeeze(-1)).sum() / num_sequences).detach().item() + metrics["outlier_seq_over_high_ratio"] = (seq_has_over_high.sum() / num_sequences).detach().item() + metrics["outlier_seq_under_low_ratio"] = (seq_has_under_low.sum() / num_sequences).detach().item() + + return all_tokens_valid.float(), metrics + + +def compute_sequence_mask( + old_log_probs: torch.Tensor, + rollout_logprobs: torch.Tensor, + loss_mask: torch.Tensor, + sequence_mask_metric: str, + off_policy_correction: DictConfig, +) -> Tuple[torch.Tensor, dict]: + """ + Compute sequence mask for off policy correction. + + This masks out sequences with importance ratios that fall outside acceptable bounds, + helping to filter out off-policy samples that may destabilize training. + + Args: + old_log_probs: Log probabilities from the old policy (before update). + rollout_logprobs: Log probabilities from the rollout policy. + loss_mask: Mask indicating valid tokens. + sequence_mask_metric: Metric to use for sequence masking ("geometric" or "product"). + off_policy_correction: Off-policy correction config containing cap values. 
+ + Returns: + Tuple of (sequence_mask, metrics): + - sequence_mask: Tensor (float) to multiply with the loss + - metrics: Dict with masking statistics + """ + # Compute token-level importance ratio + token_tis_log_ratio = old_log_probs - rollout_logprobs + metrics = {} + + if sequence_mask_metric == "geometric": + # Compute geometric mean of importance ratios per sequence + num_tokens = loss_mask.sum(dim=-1, keepdim=True).clamp(min=1.0) + seq_tis_log_ratio = (token_tis_log_ratio * loss_mask).sum(dim=-1, keepdim=True) + geo_mean_ratio = _safe_exp_delta(seq_tis_log_ratio / num_tokens, clip=20.0, out_dtype=old_log_probs.dtype) + geo_cap_high = off_policy_correction.geo_mask_high + geo_cap_low = off_policy_correction.geo_mask_low + seq_over_high = geo_mean_ratio > geo_cap_high + seq_under_low = geo_mean_ratio < geo_cap_low + geo_sequence_mask = ~seq_over_high & ~seq_under_low + + num_sequences = float(geo_mean_ratio.shape[0]) + metrics["geo_sequence_mask_masked_ratio"] = ((~geo_sequence_mask).sum() / num_sequences).detach().item() + metrics["geo_sequence_mask_over_high_ratio"] = (seq_over_high.sum() / num_sequences).detach().item() + metrics["geo_sequence_mask_under_low_ratio"] = (seq_under_low.sum() / num_sequences).detach().item() + + return geo_sequence_mask.float(), metrics + elif sequence_mask_metric == "product": + # Mask out sequences with product of importance ratios outside the cap + seq_tis_log_ratio = (token_tis_log_ratio * loss_mask).sum(dim=-1, keepdim=True) + seq_tis_ratio = _safe_exp_delta(seq_tis_log_ratio, clip=20.0, out_dtype=old_log_probs.dtype) + seq_cap_high = off_policy_correction.product_mask_high + seq_cap_low = off_policy_correction.product_mask_low + seq_over_high = seq_tis_ratio > seq_cap_high + seq_under_low = seq_tis_ratio < seq_cap_low + seq_in_bounds = ~seq_over_high & ~seq_under_low + + num_sequences = float(seq_tis_ratio.shape[0]) + metrics["product_sequence_mask_masked_ratio"] = ((~seq_in_bounds).sum() / num_sequences).detach().item() + metrics["product_sequence_mask_over_high_ratio"] = (seq_over_high.sum() / num_sequences).detach().item() + metrics["product_sequence_mask_under_low_ratio"] = (seq_under_low.sum() / num_sequences).detach().item() + + return seq_in_bounds.float(), metrics + else: + raise ValueError(f"Unknown sequence_mask_metric: {sequence_mask_metric}") + + +def compute_off_policy_correction( + old_log_probs: torch.Tensor, + rollout_logprobs: torch.Tensor, + loss_mask: torch.Tensor, + off_policy_correction: DictConfig, +) -> Tuple[Optional[torch.Tensor], dict, torch.Tensor]: + """ + Compute TIS ratio, sequence mask, and outlier token mask for off policy correction. + + This is a convenience function that combines compute_tis_ratio, compute_sequence_mask, + and compute_outlier_token_mask. + + Args: + old_log_probs: Log probabilities from the old policy (before update). + rollout_logprobs: Log probabilities from the rollout policy. + loss_mask: Mask indicating valid tokens. + off_policy_correction: Off-policy correction config. 
+ + Returns: + Tuple of (tis_ratio, metrics, loss_mask): + - tis_ratio: Tensor (float) to multiply with the loss + - metrics: Dict with masking statistics + - loss_mask: Mask indicating valid tokens after applying off policy correction + + References: + - https://github.com/szrlee/verl/blob/yingru/rollout_correction/docs/advance/rollout_corr_math.md + - https://fengyao.notion.site/off-policy-rl + """ + tis_ratio_type = off_policy_correction.tis_ratio_type + sequence_mask_metric = off_policy_correction.sequence_mask_metric + + # Check if TIS ratio correction is enabled + apply_tis = tis_ratio_type is not None + # Check if sequence mask is enabled + apply_sequence_mask = sequence_mask_metric is not None + + # Early return if no correction needed + if not apply_tis and not apply_sequence_mask: + return None, {}, loss_mask + + is_ratio = _safe_exp_delta(old_log_probs - rollout_logprobs, clip=20.0, out_dtype=old_log_probs.dtype) + metrics = {} + metrics["is_ratio_mean"] = masked_mean(is_ratio, loss_mask).mean().detach().item() + metrics["is_ratio_std"] = (is_ratio * loss_mask).std().detach().item() + metrics["is_ratio_max"] = (is_ratio * loss_mask).max().detach().item() + metrics["is_ratio_min"] = (is_ratio * loss_mask).min().detach().item() + + # Apply outlier token mask whenever off policy correction is enabled + # This rejects sequences with any token having importance ratio outside acceptable bounds + outlier_mask, outlier_metrics = compute_outlier_token_mask( + old_log_probs, rollout_logprobs, loss_mask, off_policy_correction + ) + loss_mask = loss_mask * outlier_mask + metrics.update(outlier_metrics) + + # Initialize tis_ratio to None (only set if TIS is enabled) + tis_ratio = None + + # Apply TIS ratio if enabled + if apply_tis: + tis_ratio, tis_metrics = compute_tis_ratio( + old_log_probs, rollout_logprobs, loss_mask, tis_ratio_type, off_policy_correction + ) + metrics.update(tis_metrics) + + # Apply sequence mask if enabled + if apply_sequence_mask: + sequence_mask, sequence_mask_metrics = compute_sequence_mask( + old_log_probs, rollout_logprobs, loss_mask, sequence_mask_metric, off_policy_correction + ) + loss_mask = loss_mask * sequence_mask + metrics.update(sequence_mask_metrics) + + return tis_ratio, metrics, loss_mask + + @register_policy_loss(PolicyLossType.REGULAR) @register_policy_loss(PolicyLossType.DUAL_CLIP) def ppo_policy_loss( @@ -581,17 +831,20 @@ def ppo_policy_loss( clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1) loss = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1) - if config.use_tis: - from loguru import logger as logger_ + loss_metrics = LossMetrics(clip_ratio=clip_ratio) - logger_.debug(f"Using TIS with dtype: {rollout_logprobs.dtype}") - # Apply truncated importance sampling -> https://fengyao.notion.site/off-policy-rl - tis_imp_ratio = _safe_exp_delta(old_log_probs - rollout_logprobs, clip=20.0, out_dtype=log_probs.dtype) - tis_imp_ratio = torch.clamp(tis_imp_ratio, max=config.tis_imp_ratio_cap) - loss = loss * tis_imp_ratio + # apply off policy correction + off_policy_correction = config.off_policy_correction + if rollout_logprobs is not None: + tis_ratio, off_policy_correction_metrics, loss_mask = compute_off_policy_correction( + old_log_probs, rollout_logprobs, loss_mask, off_policy_correction + ) + if tis_ratio is not None: + loss = loss * tis_ratio + loss_metrics.update(off_policy_correction_metrics) loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len) - return loss, clip_ratio + return loss, loss_metrics 
@register_policy_loss(PolicyLossType.SAPO) @@ -653,13 +906,20 @@ def gate_function(x, tau): # compute policy gradient loss loss = -gates * advantages + # apply off policy correction + off_policy_correction = config.off_policy_correction + loss_metrics = LossMetrics(clip_ratio=0.0) + if rollout_logprobs is not None: + tis_ratio, off_policy_correction_metrics, loss_mask = compute_off_policy_correction( + old_log_probs, rollout_logprobs, loss_mask, off_policy_correction + ) + if tis_ratio is not None: + loss = loss * tis_ratio + loss_metrics.update(off_policy_correction_metrics) # for SAPO, we need to aggregate the loss at the sequence level (seq-mean-token-mean) loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len) - # SAPO does not use clipping, so we set clip_ratio to 0.0 for compatibility - clip_ratio = 0.0 - - return loss, clip_ratio + return loss, loss_metrics @register_policy_loss(PolicyLossType.GSPO) @@ -717,9 +977,20 @@ def gspo_policy_loss( # Compute clipping ratio for monitoring clip_ratio = masked_mean((-surr2 > -surr1).float(), loss_mask).mean().detach().item() + # apply off policy correction + loss_metrics = LossMetrics(clip_ratio=clip_ratio) + off_policy_correction = config.off_policy_correction + if rollout_logprobs is not None: + tis_ratio, off_policy_correction_metrics, loss_mask = compute_off_policy_correction( + old_log_probs, rollout_logprobs, loss_mask, off_policy_correction + ) + if tis_ratio is not None: + loss = loss * tis_ratio + loss_metrics.update(off_policy_correction_metrics) + loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len) - return loss, clip_ratio + return loss, loss_metrics @register_policy_loss(PolicyLossType.CISPO) @@ -747,8 +1018,19 @@ def compute_policy_loss_cispo( is_clipped = (ratio < 1 - config.cispo.cispo_eps_clip_low) | (ratio > 1 + config.cispo.cispo_eps_clip_high) clip_ratio = masked_mean(is_clipped.float(), loss_mask).mean().detach().item() + # apply off policy correction + off_policy_correction = config.off_policy_correction + loss_metrics = LossMetrics(clip_ratio=clip_ratio) + if rollout_logprobs is not None: + tis_ratio, off_policy_correction_metrics, loss_mask = compute_off_policy_correction( + old_log_probs, rollout_logprobs, loss_mask, off_policy_correction + ) + if tis_ratio is not None: + loss = loss * tis_ratio + loss_metrics.update(off_policy_correction_metrics) + loss = reduce_loss(loss, loss_mask, config.loss_reduction, config.max_seq_len) - return loss, clip_ratio + return loss, loss_metrics @register_policy_loss(PolicyLossType.CLIP_COV) @@ -808,6 +1090,7 @@ def compute_policy_loss_clip_cov( # Apply correction mask to losses pg_losses = torch.maximum(pg_losses1, pg_losses2) * corr + pg_loss = reduce_loss( loss=pg_losses, loss_mask=loss_mask, @@ -815,7 +1098,7 @@ def compute_policy_loss_clip_cov( max_seq_len=config.max_seq_len, ) - return pg_loss, clip_frac.item() + return pg_loss, LossMetrics(clip_ratio=clip_frac.item()) @register_policy_loss(PolicyLossType.KL_COV) @@ -875,7 +1158,7 @@ def compute_policy_loss_kl_cov( ) # NOTE (sumanthrh): Since the pg clip ratio is not applicable for KL-COV so we just use 0.0 - return pg_loss, 0.0 + return pg_loss, LossMetrics(clip_ratio=0.0) def reduce_loss( diff --git a/skyrl-train/skyrl_train/utils/utils.py b/skyrl-train/skyrl_train/utils/utils.py index 349d5778a..8c3039c54 100644 --- a/skyrl-train/skyrl_train/utils/utils.py +++ b/skyrl-train/skyrl_train/utils/utils.py @@ -281,26 +281,53 @@ def validate_cfg(cfg: DictConfig): 
algorithm_config.kl_estimator_type = "k3" cfg.trainer.algorithm = algorithm_config + # TODO (erictang000): remove this after deprecation period if cfg.trainer.algorithm.use_tis: - if cfg.trainer.algorithm.tis_imp_ratio_cap <= 0: - raise ValueError( - f"If `trainer.algorithm.use_tis` is `True` then `cfg.trainer.algorithm.tis_imp_ratio_cap` " - f"should be > 0, got {cfg.trainer.algorithm.tis_imp_ratio_cap }" - ) + logger.warning( + f"`trainer.algorithm.use_tis` is deprecated. Setting `trainer.algorithm.off_policy_correction` to `token` instead." + f"with `token_tis_ratio_clip_high`={cfg.trainer.algorithm.tis_imp_ratio_cap}" + ) + cfg.trainer.algorithm.off_policy_correction.tis_ratio_type = "token" + cfg.trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high = cfg.trainer.algorithm.tis_imp_ratio_cap + + # off_policy_correction config validation + off_policy_correction = cfg.trainer.algorithm.off_policy_correction + tis_ratio_type = off_policy_correction.tis_ratio_type + sequence_mask_metric = off_policy_correction.sequence_mask_metric + + uses_off_policy_correction = tis_ratio_type is not None or sequence_mask_metric is not None + + if uses_off_policy_correction: + # Validate tis_ratio_type + if tis_ratio_type: + assert tis_ratio_type in [ + "token", + "sequence", + ], f"`tis_ratio_type` must be 'None', 'token', or 'sequence', got {tis_ratio_type}" + + # Validate sequence_mask_metric + if sequence_mask_metric: + assert sequence_mask_metric in [ + "product", + "geometric", + ], f"`sequence_mask_metric` must be 'product', or 'geometric', got {sequence_mask_metric}" + + # Ensure logprobs are enabled for rollout correction if cfg.generator.sampling_params.logprobs is None: logger.warning( - "`generator.sampling_params.logprobs` is `None` but `trainer.algorithm.use_tis` is `True`." + "`generator.sampling_params.logprobs` is `None` but off_policy_correction is enabled." " Setting `logprobs` to `True`." 
) - # just set to 0 for better user exp cfg.generator.sampling_params.logprobs = 0 if cfg.generator.backend == "sglang": - raise NotImplementedError("`trainer.algorithm.use_tis` doesn't support Sglang backend, please use vLLM") - assert cfg.trainer.algorithm.policy_loss_type in [ - "regular", - "dual_clip", - ], "TIS is only implemented for regular and dual_clip policy loss types" + raise NotImplementedError( + "`trainer.algorithm.off_policy_correction` doesn't support Sglang backend, please use vLLM" + ) + if cfg.trainer.algorithm.policy_loss_type in ["clip_cov", "kl_cov"]: + raise NotImplementedError( + "`trainer.algorithm.off_policy_correction` doesn't support clip_cov or kl_cov policy loss types" + ) if cfg.trainer.policy.model.lora.rank > 0: # LoRA enabled diff --git a/skyrl-train/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl-train/skyrl_train/workers/megatron/megatron_model_wrapper.py index d84673e66..b3d14c65f 100644 --- a/skyrl-train/skyrl_train/workers/megatron/megatron_model_wrapper.py +++ b/skyrl-train/skyrl_train/workers/megatron/megatron_model_wrapper.py @@ -219,7 +219,7 @@ def loss_func(logits, data): action_log_probs = token_logprobs[:, -num_actions:] # policy loss should be calculated based on the selected token logprobs - policy_loss, clip_ratio = self.policy_loss_fn( + policy_loss, loss_metrics = self.policy_loss_fn( action_log_probs, old_action_log_probs, advantages, @@ -256,9 +256,10 @@ def loss_func(logits, data): "final_loss": loss.detach().item(), "policy_loss": policy_loss.detach().item(), "policy_entropy": entropy.detach().item(), - "ppo_clip_ratio": clip_ratio, "policy_kl": kl_loss.detach().item(), } + for k, v in loss_metrics.items(): + metrics["loss_metrics/" + k] = v return loss, metrics def forward_step(batch_iter, model): diff --git a/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py b/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py index 080b9bb42..e308b4526 100644 --- a/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py @@ -30,7 +30,7 @@ from skyrl_train.utils.utils import update_model_config, str_to_torch_dtype from skyrl_train.utils.constants import SKYRL_WORKER_NCCL_TIMEOUT_IN_S from skyrl_train.training_batch import TrainingOutputBatch -from skyrl_train.workers.worker_utils import BatchIterator, reduce_metrics +from skyrl_train.workers.worker_utils import BatchIterator, reduce_metrics, all_reduce_metrics from skyrl_train.workers.worker import ( PolicyWorkerBase, RefWorkerBase, @@ -583,13 +583,11 @@ def ppo_train(self, train_data) -> "TrainingOutputBatch": # within a DP group, metrics are already the same across all workers - we then just all reduce across # the whole world size to get the metrics for the global micro batch for i, metrics in enumerate(metrics_list): - status = { - "final_loss": metrics["final_loss"], - "policy_loss": metrics["policy_loss"], - "policy_lr": self.optimizer.param_groups[0]["lr"], - "ppo_clip_ratio": metrics["ppo_clip_ratio"], - "policy_entropy": metrics["policy_entropy"], - } + status = {} + for k, v in metrics.items(): + status[k] = v + + status["policy_lr"] = self.optimizer.param_groups[0]["lr"] if self.cfg.trainer.algorithm.use_kl_loss: status["policy_kl"] = metrics["policy_kl"] @@ -600,7 +598,8 @@ def ppo_train(self, train_data) -> "TrainingOutputBatch": # attach response_length status["response_length"] = micro_buffer[i]["num_actions"] - status = self.strategy.all_reduce(status) + status = 
all_reduce_metrics(status, self.strategy) + status_list.append(status) for k, v in status.items(): all_metrics[k].append(v) diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py index 361a0777e..f062f9b20 100644 --- a/skyrl-train/skyrl_train/workers/worker.py +++ b/skyrl-train/skyrl_train/workers/worker.py @@ -32,7 +32,7 @@ from loguru import logger from skyrl_train.distributed.ulysses import set_ulysses_sequence_parallel_group, apply_monkey_patch from skyrl_train.utils.ppo_utils import PolicyLossRegistry, ppo_critic_loss, compute_approx_kl -from skyrl_train.workers.worker_utils import BatchIterator, reduce_metrics +from skyrl_train.workers.worker_utils import BatchIterator, reduce_metrics, all_reduce_metrics from skyrl_train.dataset.replay_buffer import Experience from skyrl_train.training_batch import TrainingInputBatch, TrainingOutputBatch from skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient @@ -663,7 +663,7 @@ def forward_backward(self, experience: Experience, microbatch_weight: float) -> ) # loss function # TODO: recompute advantages - policy_loss, clip_ratio = self.policy_loss_fn( + policy_loss, loss_metrics = self.policy_loss_fn( action_log_probs, old_action_log_probs, advantages, @@ -704,10 +704,11 @@ def forward_backward(self, experience: Experience, microbatch_weight: float) -> status = { "final_loss": loss.item(), "policy_loss": policy_loss.item(), - "ppo_clip_ratio": clip_ratio, "policy_entropy": entropy.item(), "response_length": num_actions, } + for k, v in loss_metrics.items(): + status["loss_metrics/" + k] = v if self.cfg.trainer.algorithm.use_kl_loss: status["policy_kl"] = kl_loss.item() @@ -741,7 +742,7 @@ def record_status(status: Dict[str, float]): # for DP # TODO (sumanthrh): this assumes all workers are data parallel. # We assume that outputs are replicated within tp or sp group, otherwise this is not correct. - status = self.strategy.all_reduce(status) + status = all_reduce_metrics(status, self.strategy) # weighted mean for kl # TODO (sumanthrh): this weighted mean is no longer correct since we use the max response length in the batch. 
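The worker changes above stop reporting a bare `ppo_clip_ratio` float and instead forward every entry of the returned `loss_metrics` dict under a `loss_metrics/` prefix, then aggregate the per-microbatch status dicts with `all_reduce_metrics` rather than a plain mean all-reduce. The `worker_utils.py` hunk below defines the suffix convention this relies on: keys ending in `_max`/`_min` are reduced with max/min, everything else with a mean. A minimal single-process sketch of that convention follows; the metric names are illustrative only, and the real `all_reduce_metrics` dispatches the same three groups to `strategy.all_reduce` with `op="mean"`, `"min"`, or `"max"`.

.. code-block:: python

    from typing import Dict, List

    def reduce_by_suffix(metrics: Dict[str, List[float]]) -> Dict[str, float]:
        # Keys ending in "_max"/"_min" are reduced with max/min,
        # everything else with a mean (mirrors reduce_metrics below).
        reduced: Dict[str, float] = {}
        for key, values in metrics.items():
            if key.endswith("_max"):
                reduced[key] = max(values)
            elif key.endswith("_min"):
                reduced[key] = min(values)
            else:
                reduced[key] = sum(values) / len(values)
        return reduced

    # Hypothetical per-microbatch loss metrics logged under the "loss_metrics/" prefix.
    status_per_microbatch = {
        "loss_metrics/clip_ratio": [0.10, 0.20],     # mean -> 0.15
        "loss_metrics/is_ratio_max": [1.8, 2.4],     # "_max" rule -> 2.4
        "loss_metrics/is_ratio_min": [0.6, 0.4],     # "_min" rule -> 0.4
    }
    print(reduce_by_suffix(status_per_microbatch))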
@@ -817,7 +818,6 @@ def record_status(status: Dict[str, float]): torch.distributed.barrier() # not needed beyond status logging all_metrics.pop("response_length", None) - status_mean = reduce_metrics(all_metrics) status_mean["policy_update_steps"] = num_minibatches * self.cfg.trainer.update_epochs_per_batch diff --git a/skyrl-train/skyrl_train/workers/worker_utils.py b/skyrl-train/skyrl_train/workers/worker_utils.py index 897d032ea..ebc25f746 100644 --- a/skyrl-train/skyrl_train/workers/worker_utils.py +++ b/skyrl-train/skyrl_train/workers/worker_utils.py @@ -2,6 +2,7 @@ from skyrl_train.dataset.replay_buffer import Experience from typing import List, Dict from skyrl_train.training_batch import TrainingInputBatch +from skyrl_train.distributed.strategy import DistributedStrategy def reduce_metrics(metrics: Dict[str, List[float]]) -> Dict[str, float]: @@ -12,10 +13,28 @@ def reduce_metrics(metrics: Dict[str, List[float]]) -> Dict[str, float]: for k, v in metrics.items(): assert len(v) > 0, f"No metrics for key {k}" assert all(isinstance(x, (int, float)) for x in v), f"Metrics for key {k} are not all numbers" - reduced_metrics[k] = sum(v) / len(v) + if k.endswith("_max"): + reduced_metrics[k] = max(v) + elif k.endswith("_min"): + reduced_metrics[k] = min(v) + else: + reduced_metrics[k] = sum(v) / len(v) return reduced_metrics +def all_reduce_metrics(metrics: Dict[str, List[float]], strategy: DistributedStrategy) -> Dict[str, float]: + """All reduce metrics across all processes.""" + min_metrics = {k: v for k, v in metrics.items() if k.endswith("_min")} + max_metrics = {k: v for k, v in metrics.items() if k.endswith("_max")} + mean_metrics = {k: v for k, v in metrics.items() if k not in min_metrics and k not in max_metrics} + status_mean = strategy.all_reduce(mean_metrics, op="mean") + status_min = strategy.all_reduce(min_metrics, op="min") + status_max = strategy.all_reduce(max_metrics, op="max") + status_mean.update(status_min) + status_mean.update(status_max) + return status_mean + + class BatchIterator: """A simple iterator to yield micro batches of data from the training batch.""" diff --git a/skyrl-train/tests/cpu/algorithms/test_losses.py b/skyrl-train/tests/cpu/algorithms/test_losses.py index f5904b595..f3aa2a0e8 100644 --- a/skyrl-train/tests/cpu/algorithms/test_losses.py +++ b/skyrl-train/tests/cpu/algorithms/test_losses.py @@ -8,7 +8,21 @@ import torch from omegaconf import DictConfig -from skyrl_train.utils.ppo_utils import PolicyLossRegistry, masked_mean +from skyrl_train.utils.ppo_utils import ( + PolicyLossRegistry, + masked_mean, + compute_tis_ratio, + compute_sequence_mask, + compute_outlier_token_mask, + compute_off_policy_correction, +) + +NULL_OFF_POLICY_CORR = { + "tis_ratio_type": None, + "sequence_mask_metric": None, + "outlier_token_is_threshold_low": 1e-4, + "outlier_token_is_threshold_high": 100.0, +} # Adapted a good test from NeMO-RL @@ -33,7 +47,7 @@ def test_policy_loss_dual_clip(): "policy_loss_type": "dual_clip", "loss_reduction": "token_mean", "max_seq_len": 4, - "use_tis": False, + "off_policy_correction": NULL_OFF_POLICY_CORR, } ) @@ -86,7 +100,7 @@ def test_policy_loss_cispo(): "policy_loss_type": "cispo", "loss_reduction": "token_mean", "max_seq_len": 4, - "use_tis": False, + "off_policy_correction": NULL_OFF_POLICY_CORR, } ) @@ -164,7 +178,7 @@ def test_policy_loss_reduction_modes(): "policy_loss_type": "regular", "loss_reduction": "token_mean", "max_seq_len": 4, - "use_tis": False, + "off_policy_correction": NULL_OFF_POLICY_CORR, } ) @@ -176,7 +190,7 @@ 
def test_policy_loss_reduction_modes(): "policy_loss_type": "regular", "loss_reduction": "sequence_mean", "max_seq_len": 4, - "use_tis": False, + "off_policy_correction": NULL_OFF_POLICY_CORR, } ) @@ -249,7 +263,7 @@ def test_policy_loss_reduction_edge_cases(): "policy_loss_type": "regular", "loss_reduction": "token_mean", "max_seq_len": 4, - "use_tis": False, + "off_policy_correction": NULL_OFF_POLICY_CORR, } ) @@ -261,7 +275,7 @@ def test_policy_loss_reduction_edge_cases(): "policy_loss_type": "regular", "loss_reduction": "sequence_mean", "max_seq_len": 4, - "use_tis": False, + "off_policy_correction": NULL_OFF_POLICY_CORR, } ) @@ -347,7 +361,7 @@ def test_gspo_importance_sampling_levels(): "policy_loss_type": "regular", "loss_reduction": "token_mean", "max_seq_len": 4, - "use_tis": False, + "off_policy_correction": NULL_OFF_POLICY_CORR, } ) ppo_loss_fn = PolicyLossRegistry.get("regular") @@ -362,7 +376,7 @@ def test_gspo_importance_sampling_levels(): "policy_loss_type": "gspo", "loss_reduction": "sequence_mean", # GSPO recommended reduction "max_seq_len": 4, - "use_tis": False, + "off_policy_correction": NULL_OFF_POLICY_CORR, } ) gspo_loss_fn = PolicyLossRegistry.get("gspo") @@ -469,6 +483,7 @@ def test_clip_cov_policy_loss(): "loss_reduction": "token_mean", "max_seq_len": 4, "clip_cov": {"clip_ratio": 0.5, "clip_cov_lb": -5.0, "clip_cov_ub": 5.0}, # Large ratio for testing + "off_policy_correction": NULL_OFF_POLICY_CORR, } ) @@ -476,11 +491,12 @@ def test_clip_cov_policy_loss(): clip_cov_fn = PolicyLossRegistry.get("clip_cov") # Calculate loss - loss, clip_frac = clip_cov_fn(log_probs, old_log_probs, advantages, config, loss_mask) + loss, loss_metrics = clip_cov_fn(log_probs, old_log_probs, advantages, config, loss_mask) + clip_ratio = loss_metrics["clip_ratio"] # Basic sanity checks assert torch.isfinite(loss), "Loss should be finite" - assert 0 <= clip_frac <= 1, f"Clip fraction should be between 0 and 1, got {clip_frac}" + assert 0 <= clip_ratio <= 1, f"Clip ratio should be between 0 and 1, got {clip_ratio}" # Compare with regular PPO (should be different due to covariance correction) regular_config = DictConfig( @@ -490,12 +506,12 @@ def test_clip_cov_policy_loss(): "policy_loss_type": "regular", "loss_reduction": "token_mean", "max_seq_len": 4, - "use_tis": False, + "off_policy_correction": NULL_OFF_POLICY_CORR, } ) regular_fn = PolicyLossRegistry.get("regular") - regular_loss, regular_clip_frac = regular_fn(log_probs, old_log_probs, advantages, regular_config, loss_mask) + regular_loss, regular_loss_metrics = regular_fn(log_probs, old_log_probs, advantages, regular_config, loss_mask) # Clip-Cov should give different results due to covariance-based correction assert not torch.allclose( @@ -531,6 +547,7 @@ def test_kl_cov_policy_loss(): "loss_reduction": "token_mean", "max_seq_len": 4, "kl_cov": {"kl_cov_frac": 0.5, "ppo_kl_coef": 1.0}, # Apply KL to 50% of tokens + "off_policy_correction": NULL_OFF_POLICY_CORR, } ) @@ -538,11 +555,11 @@ def test_kl_cov_policy_loss(): kl_cov_fn = PolicyLossRegistry.get("kl_cov") # Calculate loss - loss, clip_frac = kl_cov_fn(log_probs, old_log_probs, advantages, config, loss_mask) + loss, loss_metrics = kl_cov_fn(log_probs, old_log_probs, advantages, config, loss_mask) # Basic sanity checks assert torch.isfinite(loss), "Loss should be finite" - assert clip_frac == 0.0, "KL-Cov should return 0.0 for clipfrac value" + assert loss_metrics["clip_ratio"] == 0.0, "KL-Cov should return 0.0 for clip_ratio value" # Compare with regular PPO (should be 
different due to KL regularization) regular_config = DictConfig( @@ -552,7 +569,7 @@ def test_kl_cov_policy_loss(): "policy_loss_type": "regular", "loss_reduction": "token_mean", "max_seq_len": 4, - "use_tis": False, + "off_policy_correction": NULL_OFF_POLICY_CORR, } ) @@ -585,13 +602,14 @@ def test_sapo_policy_loss_basic(): "loss_reduction": "sequence_mean", "max_seq_len": 4, "sapo": {"tau_pos": 1.0, "tau_neg": 2.0}, + "off_policy_correction": NULL_OFF_POLICY_CORR, } ) loss_fn = PolicyLossRegistry.get("sapo") # Actual SAPO loss - actual_loss, actual_clip_ratio = loss_fn( + actual_loss, loss_metrics = loss_fn( log_probs=log_probs, old_log_probs=old_log_probs, advantages=advantages, @@ -620,4 +638,575 @@ def gate_function(x, tau): torch.testing.assert_close(actual_loss, expected_loss, rtol=1e-5, atol=1e-8) # SAPO should always report clip_ratio = 0.0 - assert actual_clip_ratio == 0.0 + assert loss_metrics["clip_ratio"] == 0.0 + + +# ============================================================================ +# Rollout Correction Tests +# ============================================================================ + + +def test_compute_tis_ratio_token_level(): + """Tests token-level TIS ratio computation with capping.""" + device = "cpu" + + # old_log_probs - rollout_logprobs gives the log importance ratio + # Token ratios: exp([0.5, -0.5, 1.0]) = [1.6487, 0.6065, 2.7183] + old_log_probs = torch.tensor([[-1.0, -1.5, -0.5]], device=device) + rollout_logprobs = torch.tensor([[-1.5, -1.0, -1.5]], device=device) + loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device) + + config = DictConfig( + { + "tis_ratio_type": "token", + "token_tis_ratio_clip_high": 2.0, + } + ) + + tis_ratio, metrics = compute_tis_ratio(old_log_probs, rollout_logprobs, loss_mask, "token", config) + + # Expected: [1.6487, 0.6065, 2.0] (third token capped at 2.0) + expected = torch.tensor([[1.6487, 0.6065, 2.0]], device=device) + torch.testing.assert_close(tis_ratio, expected, rtol=1e-3, atol=1e-4) + # One token out of 3 was capped + assert "tis_token_clip_high_ratio" in metrics + assert abs(metrics["tis_token_clip_high_ratio"] - 1 / 3) < 0.01 + + +def test_compute_tis_ratio_sequence_level(): + """Tests sequence-level TIS ratio computation with capping.""" + device = "cpu" + + # Token log ratios: [0.5, -0.5, 1.0] + # Sequence log ratio (sum of masked): 0.5 + (-0.5) + 1.0 = 1.0 + # Sequence ratio: exp(1.0) = 2.7183 + old_log_probs = torch.tensor([[-1.0, -1.5, -0.5]], device=device) + rollout_logprobs = torch.tensor([[-1.5, -1.0, -1.5]], device=device) + loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device) + + config = DictConfig( + { + "tis_ratio_type": "sequence", + "sequence_tis_ratio_clip_high": 5.0, + } + ) + + tis_ratio, metrics = compute_tis_ratio(old_log_probs, rollout_logprobs, loss_mask, "sequence", config) + + # Expected: exp(1.0) = 2.7183, shape [batch, 1] for sequence-level + expected = torch.tensor([[2.7183]], device=device) + torch.testing.assert_close(tis_ratio, expected, rtol=1e-3, atol=1e-4) + # No sequence was capped (2.7183 < 5.0) + assert "tis_seq_clip_high_ratio" in metrics + assert metrics["tis_seq_clip_high_ratio"] == 0.0 + + +def test_compute_tis_ratio_sequence_level_with_cap(): + """Tests sequence-level TIS ratio capping.""" + device = "cpu" + + # Token log ratios: [1.0, 1.0, 1.0] + # Sequence log ratio: 3.0 + # Sequence ratio: exp(3.0) = 20.09, should be capped at 5.0 + old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + rollout_logprobs = torch.tensor([[-2.0, -2.0, 
-2.0]], device=device) + loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device) + + config = DictConfig( + { + "tis_ratio_type": "sequence", + "sequence_tis_ratio_clip_high": 5.0, + } + ) + + tis_ratio, metrics = compute_tis_ratio(old_log_probs, rollout_logprobs, loss_mask, "sequence", config) + + # Expected: capped at 5.0, shape [batch, 1] for sequence-level + expected = torch.tensor([[5.0]], device=device) + torch.testing.assert_close(tis_ratio, expected, rtol=1e-3, atol=1e-4) + # One sequence out of 1 was capped + assert "tis_seq_clip_high_ratio" in metrics + assert metrics["tis_seq_clip_high_ratio"] == 1.0 + + +def test_compute_tis_ratio_with_mask(): + """Tests that loss_mask correctly excludes tokens from sequence-level computation.""" + device = "cpu" + + # Token log ratios: [0.5, -0.5, 1.0] + # With mask [1, 0, 1], sequence log ratio = 0.5 + 1.0 = 1.5 + old_log_probs = torch.tensor([[-1.0, -1.5, -0.5]], device=device) + rollout_logprobs = torch.tensor([[-1.5, -1.0, -1.5]], device=device) + loss_mask = torch.tensor([[1.0, 0.0, 1.0]], device=device) + + config = DictConfig( + { + "tis_ratio_type": "sequence", + "sequence_tis_ratio_clip_high": 10.0, + } + ) + + tis_ratio, metrics = compute_tis_ratio(old_log_probs, rollout_logprobs, loss_mask, "sequence", config) + + # Expected: exp(1.5) = 4.4817, shape [batch, 1] for sequence-level + expected_val = torch.exp(torch.tensor(1.5)) + expected = expected_val.reshape(1, 1) + torch.testing.assert_close(tis_ratio, expected, rtol=1e-3, atol=1e-4) + # No sequence was capped (4.4817 < 10.0) + assert "tis_seq_clip_high_ratio" in metrics + assert metrics["tis_seq_clip_high_ratio"] == 0.0 + + +def test_compute_sequence_mask_geometric(): + """Tests geometric sequence mask computation.""" + device = "cpu" + + # Token log ratios: [0.1, -0.1, 0.0] -> sum = 0.0, geometric mean = exp(0/3) = 1.0 + old_log_probs = torch.tensor([[-1.0, -1.1, -1.0]], device=device) + rollout_logprobs = torch.tensor([[-1.1, -1.0, -1.0]], device=device) + loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device) + + config = DictConfig( + { + "sequence_mask_metric": "geometric", + "geo_mask_high": 1.1, + "geo_mask_low": 0.9, + } + ) + + sequence_mask, metrics = compute_sequence_mask(old_log_probs, rollout_logprobs, loss_mask, "geometric", config) + + # Geometric mean ≈ 1.0, which is within [0.9, 1.1], so mask should be 1.0 + # Shape is [batch, 1] for sequence-level mask + expected = torch.tensor([[1.0]], device=device) + torch.testing.assert_close(sequence_mask, expected, rtol=1e-3, atol=1e-4) + # No sequence was masked + assert metrics["geo_sequence_mask_masked_ratio"] == 0.0 + assert metrics["geo_sequence_mask_over_high_ratio"] == 0.0 + assert metrics["geo_sequence_mask_under_low_ratio"] == 0.0 + + +def test_compute_sequence_mask_geometric_rejects(): + """Tests geometric sequence mask correctly rejects sequences outside bounds.""" + device = "cpu" + + # Token log ratios: [0.5, 0.5, 0.5] -> sum = 1.5, geometric mean = exp(1.5/3) = exp(0.5) ≈ 1.6487 + old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + rollout_logprobs = torch.tensor([[-1.5, -1.5, -1.5]], device=device) + loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device) + + config = DictConfig( + { + "sequence_mask_metric": "geometric", + "geo_mask_high": 1.1, + "geo_mask_low": 0.9, + } + ) + + sequence_mask, metrics = compute_sequence_mask(old_log_probs, rollout_logprobs, loss_mask, "geometric", config) + + # Geometric mean ≈ 1.6487, which is outside [0.9, 1.1], so mask should be 0.0 + # Shape is 
[batch, 1] for sequence-level mask + expected = torch.tensor([[0.0]], device=device) + torch.testing.assert_close(sequence_mask, expected, rtol=1e-3, atol=1e-4) + # One sequence masked, over high cap + assert metrics["geo_sequence_mask_masked_ratio"] == 1.0 + assert metrics["geo_sequence_mask_over_high_ratio"] == 1.0 + assert metrics["geo_sequence_mask_under_low_ratio"] == 0.0 + + +def test_compute_sequence_mask_product(): + """Tests product sequence mask computation.""" + device = "cpu" + + # Token log ratios: [0.2, 0.1, 0.0] -> sum = 0.3, seq ratio = exp(0.3) ≈ 1.35 + old_log_probs = torch.tensor([[-1.0, -1.1, -1.2]], device=device) + rollout_logprobs = torch.tensor([[-1.2, -1.2, -1.2]], device=device) + loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device) + + config = DictConfig( + { + "sequence_mask_metric": "product", + "product_mask_high": 2.0, + "product_mask_low": 0.5, + } + ) + + sequence_mask, metrics = compute_sequence_mask(old_log_probs, rollout_logprobs, loss_mask, "product", config) + + # Sequence ratio ≈ 1.35, which is within [0.5, 2.0] + # Shape is [batch, 1] for sequence-level mask + expected = torch.tensor([[1.0]], device=device) + torch.testing.assert_close(sequence_mask, expected, rtol=1e-3, atol=1e-4) + # No sequence was masked + assert metrics["product_sequence_mask_masked_ratio"] == 0.0 + assert metrics["product_sequence_mask_over_high_ratio"] == 0.0 + assert metrics["product_sequence_mask_under_low_ratio"] == 0.0 + + +def test_compute_sequence_mask_product_rejects_by_seq_ratio(): + """Tests product sequence mask rejects when product ratio is out of bounds.""" + device = "cpu" + + # Token log ratios: [1.0, 1.0, 1.0] -> sum = 3.0, seq ratio = exp(3.0) ≈ 20.09 + old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + rollout_logprobs = torch.tensor([[-2.0, -2.0, -2.0]], device=device) + loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device) + + config = DictConfig( + { + "sequence_mask_metric": "product", + "product_mask_high": 2.0, + "product_mask_low": 0.5, + } + ) + + sequence_mask, metrics = compute_sequence_mask(old_log_probs, rollout_logprobs, loss_mask, "product", config) + + # Sequence ratio ≈ 20.09, which is outside [0.5, 2.0], so mask should be 0.0 + # Shape is [batch, 1] for sequence-level mask + expected = torch.tensor([[0.0]], device=device) + torch.testing.assert_close(sequence_mask, expected, rtol=1e-3, atol=1e-4) + # One sequence masked, over high cap + assert metrics["product_sequence_mask_masked_ratio"] == 1.0 + assert metrics["product_sequence_mask_over_high_ratio"] == 1.0 + assert metrics["product_sequence_mask_under_low_ratio"] == 0.0 + + +def test_compute_outlier_token_mask_masks_by_token_bounds(): + """Tests outlier token mask rejects when a token ratio is out of bounds.""" + device = "cpu" + + # Token log ratios: [0.0, 0.0, 5.0] -> token ratios = [1.0, 1.0, 148.4] + # Third token ratio 148.4 > 100.0, so should reject + old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + rollout_logprobs = torch.tensor([[-1.0, -1.0, -6.0]], device=device) + loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device) + + config = DictConfig( + { + "outlier_token_is_threshold_low": 1e-4, + "outlier_token_is_threshold_high": 100.0, # This should cause masking + } + ) + + outlier_mask, metrics = compute_outlier_token_mask(old_log_probs, rollout_logprobs, loss_mask, config) + + # Token ratio 148.4 > 100.0, so mask should be 0.0 + # Shape is [batch, 1] for sequence-level mask + expected = torch.tensor([[0.0]], device=device) + 
torch.testing.assert_close(outlier_mask, expected, rtol=1e-3, atol=1e-4) + # One sequence masked, has token over high threshold + assert metrics["outlier_seq_masked_ratio"] == 1.0 + assert metrics["outlier_seq_over_high_ratio"] == 1.0 + assert metrics["outlier_seq_under_low_ratio"] == 0.0 + + +def test_compute_outlier_token_mask_accepts_in_bounds(): + """Tests outlier token mask accepts when all token ratios are in bounds.""" + device = "cpu" + + # Token log ratios: [0.5, -0.5, 0.0] -> token ratios = [1.65, 0.61, 1.0] + # All token ratios within [1e-4, 100.0], so should accept + old_log_probs = torch.tensor([[-1.0, -1.5, -1.0]], device=device) + rollout_logprobs = torch.tensor([[-1.5, -1.0, -1.0]], device=device) + loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device) + + config = DictConfig( + { + "outlier_token_is_threshold_low": 1e-4, + "outlier_token_is_threshold_high": 100.0, + } + ) + + outlier_mask, metrics = compute_outlier_token_mask(old_log_probs, rollout_logprobs, loss_mask, config) + + # All token ratios in bounds, so mask should be 1.0 + # Shape is [batch, 1] for sequence-level mask + expected = torch.tensor([[1.0]], device=device) + torch.testing.assert_close(outlier_mask, expected, rtol=1e-3, atol=1e-4) + # No sequence was masked + assert metrics["outlier_seq_masked_ratio"] == 0.0 + assert metrics["outlier_seq_over_high_ratio"] == 0.0 + assert metrics["outlier_seq_under_low_ratio"] == 0.0 + + +def test_compute_outlier_token_mask_respects_loss_mask(): + """Tests outlier token mask ignores out-of-bounds tokens that are masked.""" + device = "cpu" + + # Token log ratios: [0.0, 0.0, 5.0] -> token ratios = [1.0, 1.0, 148.4] + # Third token ratio 148.4 > 100.0, but it's masked, so should accept + old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + rollout_logprobs = torch.tensor([[-1.0, -1.0, -6.0]], device=device) + loss_mask = torch.tensor([[1.0, 1.0, 0.0]], device=device) # Third token masked + + config = DictConfig( + { + "outlier_token_is_threshold_low": 1e-4, + "outlier_token_is_threshold_high": 100.0, + } + ) + + outlier_mask, metrics = compute_outlier_token_mask(old_log_probs, rollout_logprobs, loss_mask, config) + + # Third token is masked, so even though ratio is out of bounds, sequence should be accepted + expected = torch.tensor([[1.0]], device=device) + torch.testing.assert_close(outlier_mask, expected, rtol=1e-3, atol=1e-4) + # No sequence was masked (the out-of-bounds token was in a masked position) + assert metrics["outlier_seq_masked_ratio"] == 0.0 + + +def test_compute_off_policy_correction_null_configs(): + """Tests that compute_off_policy_correction returns None tis_ratio when both configs are null.""" + device = "cpu" + + old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + rollout_logprobs = torch.tensor([[-2.0, -2.0, -2.0]], device=device) + loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device) + + config = DictConfig( + { + "tis_ratio_type": None, + "sequence_mask_metric": None, + } + ) + + tis_ratio, metrics, new_loss_mask = compute_off_policy_correction( + old_log_probs, rollout_logprobs, loss_mask, config + ) + + # Should return None tis_ratio (early return) and empty metrics + assert tis_ratio is None + assert metrics == {} + + +def test_compute_off_policy_correction_tis_only(): + """Tests compute_off_policy_correction with only TIS enabled.""" + device = "cpu" + + # Token log ratios: [0.5, 0.5, 0.5] -> token ratios = [1.6487, 1.6487, 1.6487] + old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + 
rollout_logprobs = torch.tensor([[-1.5, -1.5, -1.5]], device=device) + loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device) + + config = DictConfig( + { + "tis_ratio_type": "token", + "token_tis_ratio_clip_high": 2.0, + "sequence_mask_metric": None, + "outlier_token_is_threshold_low": 1e-4, + "outlier_token_is_threshold_high": 100.0, + } + ) + + tis_ratio, metrics, new_loss_mask = compute_off_policy_correction( + old_log_probs, rollout_logprobs, loss_mask, config + ) + + # Expected tis_ratio: 1.6487 (no capping needed) + expected_tis_ratio = torch.exp(torch.tensor(0.5)) + torch.testing.assert_close( + tis_ratio, torch.full_like(old_log_probs, expected_tis_ratio.item()), rtol=1e-3, atol=1e-4 + ) + # Check metrics are populated + assert "is_ratio_mean" in metrics + assert "tis_token_clip_high_ratio" in metrics + + +def test_compute_off_policy_correction_sequence_mask_only(): + """Tests compute_off_policy_correction with only geometric sequence mask enabled.""" + device = "cpu" + + # Token log ratios: [0.0, 0.0, 0.0] -> geometric mean = 1.0 + old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + rollout_logprobs = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device) + + config = DictConfig( + { + "tis_ratio_type": None, + "sequence_mask_metric": "geometric", + "geo_mask_high": 1.1, + "geo_mask_low": 0.9, + "outlier_token_is_threshold_low": 1e-4, + "outlier_token_is_threshold_high": 100.0, + } + ) + + tis_ratio, metrics, new_loss_mask = compute_off_policy_correction( + old_log_probs, rollout_logprobs, loss_mask, config + ) + + # Geometric mean = 1.0, within bounds, so loss_mask unchanged + # tis_ratio is None since tis_ratio_type is None + assert tis_ratio is None + torch.testing.assert_close(new_loss_mask, loss_mask, rtol=1e-3, atol=1e-4) + # Check metrics are populated + assert "is_ratio_mean" in metrics + assert "geo_sequence_mask_masked_ratio" in metrics + + +def test_compute_off_policy_correction_both_enabled(): + """Tests compute_off_policy_correction with both TIS and geometric sequence mask enabled.""" + device = "cpu" + + # Token log ratios: [0.1, 0.1, 0.1] -> token ratios ≈ [1.105, 1.105, 1.105] + # Geometric mean ≈ 1.105 + old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + rollout_logprobs = torch.tensor([[-1.1, -1.1, -1.1]], device=device) + loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device) + + config = DictConfig( + { + "tis_ratio_type": "token", + "token_tis_ratio_clip_high": 2.0, + "sequence_mask_metric": "geometric", + "geo_mask_high": 1.2, + "geo_mask_low": 0.8, + "outlier_token_is_threshold_low": 1e-4, + "outlier_token_is_threshold_high": 100.0, + } + ) + + tis_ratio, metrics, new_loss_mask = compute_off_policy_correction( + old_log_probs, rollout_logprobs, loss_mask, config + ) + + # TIS ratio ≈ 1.105, geometric mean ≈ 1.105 (within bounds, mask=1) + expected_tis_ratio = torch.exp(torch.tensor(0.1)) + torch.testing.assert_close( + tis_ratio, torch.full_like(old_log_probs, expected_tis_ratio.item()), rtol=1e-3, atol=1e-4 + ) + # Check metrics from both TIS and sequence mask are populated + assert "tis_token_clip_high_ratio" in metrics + assert "geo_sequence_mask_masked_ratio" in metrics + + +def test_compute_off_policy_correction_sequence_mask_zeros_loss(): + """Tests that sequence mask can zero out the loss_mask entirely.""" + device = "cpu" + + # Token log ratios: [1.0, 1.0, 1.0] -> geometric mean = exp(1.0) ≈ 2.718 + old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], 
device=device) + rollout_logprobs = torch.tensor([[-2.0, -2.0, -2.0]], device=device) + loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device) + + config = DictConfig( + { + "tis_ratio_type": None, + "sequence_mask_metric": "geometric", + "geo_mask_high": 1.1, + "geo_mask_low": 0.9, + "outlier_token_is_threshold_low": 1e-4, + "outlier_token_is_threshold_high": 100.0, + } + ) + + tis_ratio, metrics, new_loss_mask = compute_off_policy_correction( + old_log_probs, rollout_logprobs, loss_mask, config + ) + + # Geometric mean ≈ 2.718, outside [0.9, 1.1], so loss_mask should be zeroed + expected_mask = torch.tensor([[0.0, 0.0, 0.0]], device=device) + torch.testing.assert_close(new_loss_mask, expected_mask, rtol=1e-3, atol=1e-4) + # Check that the sequence mask metrics show sequence mask happened + assert metrics["geo_sequence_mask_masked_ratio"] == 1.0 + + +def test_ppo_policy_loss_with_off_policy_correction(): + """Integration test for PPO policy loss with rollout correction enabled.""" + device = "cpu" + + advantages = torch.tensor([[1.0, -1.0, 0.5]], device=device) + old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + log_probs = torch.tensor([[-1.1, -0.9, -1.0]], device=device) + rollout_logprobs = torch.tensor([[-1.05, -1.05, -1.05]], device=device) + loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device) + + config = DictConfig( + { + "eps_clip_low": 0.2, + "eps_clip_high": 0.2, + "policy_loss_type": "regular", + "loss_reduction": "token_mean", + "max_seq_len": 4, + "off_policy_correction": { + "tis_ratio_type": "token", + "token_tis_ratio_clip_high": 2.0, + "sequence_mask_metric": None, + "outlier_token_is_threshold_low": 1e-4, + "outlier_token_is_threshold_high": 100.0, + }, + } + ) + + loss_fn = PolicyLossRegistry.get("regular") + + # Loss with rollout correction + loss_with_correction, _ = loss_fn( + log_probs=log_probs, + old_log_probs=old_log_probs, + advantages=advantages, + config=config, + loss_mask=loss_mask, + rollout_logprobs=rollout_logprobs, + ) + + # Loss without rollout correction + config_no_correction = DictConfig( + { + "eps_clip_low": 0.2, + "eps_clip_high": 0.2, + "policy_loss_type": "regular", + "loss_reduction": "token_mean", + "max_seq_len": 4, + "off_policy_correction": { + "tis_ratio_type": None, + "sequence_mask_metric": None, + }, + } + ) + + loss_without_correction, _ = loss_fn( + log_probs=log_probs, + old_log_probs=old_log_probs, + advantages=advantages, + config=config_no_correction, + loss_mask=loss_mask, + rollout_logprobs=rollout_logprobs, + ) + + # TIS correction should modify the loss + assert not torch.allclose(loss_with_correction, loss_without_correction, rtol=1e-3), ( + f"Rollout correction should change the loss: " + f"with={loss_with_correction:.6f} vs without={loss_without_correction:.6f}" + ) + + +def test_compute_tis_ratio_invalid_type(): + """Tests that compute_tis_ratio raises error for invalid tis_ratio_type.""" + device = "cpu" + + old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + rollout_logprobs = torch.tensor([[-1.5, -1.5, -1.5]], device=device) + loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device) + + config = DictConfig({"tis_ratio_type": "invalid"}) + + with pytest.raises(ValueError, match="Unknown tis_ratio_type"): + compute_tis_ratio(old_log_probs, rollout_logprobs, loss_mask, "invalid", config) + + +def test_compute_sequence_mask_invalid_type(): + """Tests that compute_sequence_mask raises error for invalid sequence_mask_metric.""" + device = "cpu" + + old_log_probs = 
torch.tensor([[-1.0, -1.0, -1.0]], device=device) + rollout_logprobs = torch.tensor([[-1.5, -1.5, -1.5]], device=device) + loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device) + + config = DictConfig({"sequence_mask_metric": "invalid"}) + + with pytest.raises(ValueError, match="Unknown sequence_mask_metric"): + compute_sequence_mask(old_log_probs, rollout_logprobs, loss_mask, "invalid", config) diff --git a/skyrl-train/tests/cpu/test_trainer.py b/skyrl-train/tests/cpu/test_trainer.py index 599f8a3f4..1d8924eb4 100644 --- a/skyrl-train/tests/cpu/test_trainer.py +++ b/skyrl-train/tests/cpu/test_trainer.py @@ -550,7 +550,12 @@ def create_test_worker(worker_class): def mock_policy_forward_backward(experience, microbatch_weight): policy_forward_backward_calls.append({"microbatch_weight": microbatch_weight}) - return {"policy_loss": 0.5, "ppo_clip_ratio": 0.1, "policy_entropy": 2.0, "response_length": response_length} + return { + "policy_loss": 0.5, + "loss_metrics/clip_ratio": 0.1, + "policy_entropy": 2.0, + "response_length": response_length, + } policy_worker.forward_backward = mock_policy_forward_backward policy_worker.optim_step = MagicMock(return_value=None) diff --git a/skyrl-train/tests/cpu/utils/test_ppo_utils.py b/skyrl-train/tests/cpu/utils/test_ppo_utils.py index fb69ce15e..f9744790c 100644 --- a/skyrl-train/tests/cpu/utils/test_ppo_utils.py +++ b/skyrl-train/tests/cpu/utils/test_ppo_utils.py @@ -17,6 +17,7 @@ AdvantageEstimatorRegistry, register_advantage_estimator, PolicyLossRegistry, + LossMetrics, register_policy_loss, compute_reinforce_plus_plus_outcome_advantage, compute_rloo_outcome_advantage, @@ -376,7 +377,7 @@ def test_policy_loss_registry_specific(): @register_policy_loss("test_policy_decorator") def decorated_policy_loss(log_probs, old_log_probs, advantages, config, loss_mask=None, rollout_log_probs=None): - return torch.tensor(1.5), 0.3 + return torch.tensor(1.5), LossMetrics(clip_ratio=0.3) # Test decorator worked assert "test_policy_decorator" in PolicyLossRegistry.list_available() @@ -385,14 +386,14 @@ def decorated_policy_loss(log_probs, old_log_probs, advantages, config, loss_mas # Test function execution config = DictConfig({"policy_loss_type": "test_policy_decorator"}) - loss, clip_ratio = retrieved( + loss, loss_metrics = retrieved( log_probs=torch.tensor([[0.1]]), old_log_probs=torch.tensor([[0.2]]), advantages=torch.tensor([[1.0]]), config=config, ) assert loss.item() == 1.5 - assert clip_ratio == 0.3 + assert loss_metrics["clip_ratio"] == 0.3 # Test error message includes "Policy loss" with pytest.raises(ValueError, match="Unknown policy loss"): @@ -413,10 +414,10 @@ def test_registry_cross_ray_process(): # Create test functions def test_policy_loss(log_probs, old_log_probs, advantages, config, loss_mask=None): - return torch.tensor(2.0), 0.5 + return torch.tensor(2.0), LossMetrics(clip_ratio=0.5) def test_policy_loss_2(log_probs, old_log_probs, advantages, config, loss_mask=None): - return torch.tensor(3.0), 0.6 + return torch.tensor(3.0), LossMetrics(clip_ratio=0.6) def test_advantage_estimator(**kwargs): rewards = kwargs["token_level_rewards"] @@ -432,7 +433,7 @@ def test_ray_registry_access(): policy_loss = PolicyLossRegistry.get("cross_process_test") adv_estimator = AdvantageEstimatorRegistry.get("cross_process_adv_test") - loss, clip_ratio = policy_loss( + loss, loss_metrics = policy_loss( log_probs=torch.tensor([[0.1]]), old_log_probs=torch.tensor([[0.2]]), advantages=torch.tensor([[1.0]]), @@ -444,25 +445,25 @@ def test_ray_registry_access(): 
response_mask=torch.tensor([[1.0, 1.0]]), index=np.array(["0", "0"]), ) - return loss, clip_ratio, adv, ret + return loss, loss_metrics, adv, ret # Run Ray task - loss, clip_ratio, adv, ret = ray.get(test_ray_registry_access.remote()) + loss, loss_metrics, adv, ret = ray.get(test_ray_registry_access.remote()) assert loss.item() == 2.0 - assert clip_ratio == 0.5 + assert loss_metrics["clip_ratio"] == 0.5 assert adv.shape == torch.Size([1, 2]) assert ret.shape == torch.Size([1, 2]) # test that registration works after ray init as well PolicyLossRegistry.register("cross_process_test_2", test_policy_loss_2) - loss_2, clip_ratio_2 = PolicyLossRegistry.get("cross_process_test_2")( + loss_2, loss_metrics_2 = PolicyLossRegistry.get("cross_process_test_2")( log_probs=torch.tensor([[0.1]]), old_log_probs=torch.tensor([[0.2]]), advantages=torch.tensor([[1.0]]), config=DictConfig({"policy_loss_type": "cross_process_test_2"}), ) assert loss_2.item() == 3.0 - assert clip_ratio_2 == 0.6 + assert loss_metrics_2["clip_ratio"] == 0.6 finally: PolicyLossRegistry.reset() AdvantageEstimatorRegistry.reset() diff --git a/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py b/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py index 592a0518b..028b2d0bd 100644 --- a/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py +++ b/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py @@ -472,6 +472,11 @@ async def test_megatron_train( ) transformer_config_kwargs["num_layers"] = 2 cfg.trainer.policy.megatron_config.transformer_config_kwargs = transformer_config_kwargs + # test off policy correction config propagates correctly + cfg.trainer.algorithm.off_policy_correction.tis_ratio_type = "sequence" + cfg.trainer.algorithm.off_policy_correction.sequence_mask_metric = "geometric" + cfg.trainer.algorithm.off_policy_correction.geo_mask_high = 1.02 + cfg.trainer.algorithm.off_policy_correction.geo_mask_low = 0.98 # set batch sizes correctly cfg.trainer.train_batch_size = gpus_per_node @@ -501,7 +506,7 @@ async def test_megatron_train( assert isinstance(result, dict), "Result should be a dictionary of training stats" assert "policy_loss" in result assert "policy_lr" in result - assert "ppo_clip_ratio" in result + assert "loss_metrics/clip_ratio" in result assert "policy_entropy" in result for k, v in result.items(): assert isinstance(v, (int, float)), f"{k} should be an int or float" @@ -539,7 +544,22 @@ async def test_megatron_train( print("\n\n") print("fsdp results: ", results_fsdp[0]) - keys_to_compare = ["policy_loss", "policy_lr", "ppo_clip_ratio", "policy_entropy", "policy_kl", "final_loss"] + keys_to_compare = [ + "policy_loss", + "policy_lr", + "loss_metrics/clip_ratio", + "policy_entropy", + "policy_kl", + "final_loss", + ] + if ep > 1: + keys_to_compare.extend( + [ + "loss_metrics/is_ratio_mean", + "loss_metrics/outlier_seq_masked_ratio", + "loss_metrics/geo_sequence_mask_masked_ratio", + ] + ) for i, result in enumerate(results_fsdp): for k in keys_to_compare: if k == "policy_entropy": @@ -605,7 +625,7 @@ async def test_megatron_dp(ray_init_fixture, worker_type, tp, pp, gpus_per_node) assert isinstance(result, dict), "Result should be a dictionary of training stats" assert "policy_loss" in result assert "policy_lr" in result - assert "ppo_clip_ratio" in result + assert "loss_metrics/clip_ratio" in result assert "policy_entropy" in result for k, v in result.items(): assert isinstance(v, (int, float)), f"{k} should be an int or float" @@ -641,7 +661,14 @@ async def test_megatron_dp(ray_init_fixture, worker_type, tp, 
pp, gpus_per_node) print("\n\n") print("megatron results dp: ", results_megatron_dp) - keys_to_compare = ["policy_loss", "policy_lr", "ppo_clip_ratio", "policy_entropy", "policy_kl", "raw_grad_norm"] + keys_to_compare = [ + "policy_loss", + "policy_lr", + "loss_metrics/clip_ratio", + "policy_entropy", + "policy_kl", + "raw_grad_norm", + ] for i, result in enumerate(results_megatron_dp): for k in keys_to_compare: assert isinstance(result[k], (int, float)), f"{k} should be an int or float" diff --git a/skyrl-train/tests/gpu/gpu_ci/test_ppo_train.py b/skyrl-train/tests/gpu/gpu_ci/test_ppo_train.py index c9880a017..429ff5ede 100644 --- a/skyrl-train/tests/gpu/gpu_ci/test_ppo_train.py +++ b/skyrl-train/tests/gpu/gpu_ci/test_ppo_train.py @@ -48,6 +48,11 @@ def test_ppo_train_basic_execution(ray_init_fixture, cfg, use_entropy_loss, use_ cfg.trainer.algorithm.use_kl_loss = True cfg.trainer.algorithm.kl_loss_coef = 0.001 + cfg.trainer.algorithm.off_policy_correction.tis_ratio_type = "sequence" + cfg.trainer.algorithm.off_policy_correction.sequence_mask_metric = "geometric" + cfg.trainer.algorithm.off_policy_correction.geo_mask_high = 1.02 + cfg.trainer.algorithm.off_policy_correction.geo_mask_low = 0.98 + actor_group = init_worker_with_type( "policy", shared_pg=None, @@ -74,7 +79,10 @@ def test_ppo_train_basic_execution(ray_init_fixture, cfg, use_entropy_loss, use_ "policy_loss", "policy_update_steps", "policy_lr", - "ppo_clip_ratio", + "loss_metrics/clip_ratio", + "loss_metrics/is_ratio_mean", + "loss_metrics/outlier_seq_masked_ratio", + "loss_metrics/geo_sequence_mask_masked_ratio", "policy_entropy", "final_loss", ] diff --git a/skyrl-train/tests/gpu/gpu_ci/test_training_step.py b/skyrl-train/tests/gpu/gpu_ci/test_training_step.py index e81103434..f72f87dfa 100644 --- a/skyrl-train/tests/gpu/gpu_ci/test_training_step.py +++ b/skyrl-train/tests/gpu/gpu_ci/test_training_step.py @@ -73,7 +73,7 @@ async def test_policy_forward_backward_and_optim_step(ray_init_fixture, cfg, pac for result in results: assert isinstance(result, dict), "Result should be a dictionary of training stats" assert "policy_loss" in result - assert "ppo_clip_ratio" in result + assert "loss_metrics/clip_ratio" in result assert "policy_entropy" in result for k, v in result.items(): assert isinstance(v, (int, float)), f"{k} should be an int or float" diff --git a/skyrl-train/tests/gpu/utils.py b/skyrl-train/tests/gpu/utils.py index d819604f7..0c1c3097d 100644 --- a/skyrl-train/tests/gpu/utils.py +++ b/skyrl-train/tests/gpu/utils.py @@ -77,6 +77,7 @@ def make_dummy_training_batch(batch_size=2, seq_len=10, num_actions=4) -> Traini "advantages": 0.6 * torch.ones((batch_size, num_actions), device="cpu"), "loss_mask": torch.ones((batch_size, num_actions), dtype=int, device="cpu"), "response_mask": torch.ones((batch_size, num_actions), dtype=int, device="cpu"), + "rollout_logprobs": 0.2 * torch.ones((batch_size, num_actions), device="cpu"), } ) data.metadata = {"response_length": num_actions}
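Taken together, the new tests pin down the numerics of the off-policy-correction helpers. As a reference, here is a simplified, self-contained sketch that reproduces the values asserted above. It is not the library implementation: the real `compute_tis_ratio`, `compute_sequence_mask`, and `compute_outlier_token_mask` in `skyrl_train.utils.ppo_utils` take a `DictConfig` and also return a metrics dict, and the inclusive-vs-exclusive handling of the thresholds here is an assumption.

.. code-block:: python

    import torch

    def tis_ratio_sketch(old_log_probs, rollout_log_probs, loss_mask, ratio_type, clip_high):
        """Importance ratio exp(logp_old - logp_rollout), capped from above."""
        log_ratio = (old_log_probs - rollout_log_probs) * loss_mask
        if ratio_type == "token":
            return torch.exp(log_ratio).clamp(max=clip_high)       # [batch, num_actions]
        if ratio_type == "sequence":
            seq_log_ratio = log_ratio.sum(dim=-1, keepdim=True)    # product of token ratios, in log space
            return torch.exp(seq_log_ratio).clamp(max=clip_high)   # [batch, 1]
        raise ValueError(f"Unknown tis_ratio_type: {ratio_type}")

    def sequence_mask_sketch(old_log_probs, rollout_log_probs, loss_mask, metric, low, high):
        """1.0 for sequences whose cumulative ratio lies inside [low, high], else 0.0."""
        log_ratio = (old_log_probs - rollout_log_probs) * loss_mask
        if metric == "product":
            value = torch.exp(log_ratio.sum(dim=-1, keepdim=True))
        elif metric == "geometric":
            # geometric mean of token ratios = exp(mean of log ratios over unmasked tokens)
            value = torch.exp(log_ratio.sum(dim=-1, keepdim=True) / loss_mask.sum(dim=-1, keepdim=True))
        else:
            raise ValueError(f"Unknown sequence_mask_metric: {metric}")
        return ((value >= low) & (value <= high)).float()          # [batch, 1]

    def outlier_token_mask_sketch(old_log_probs, rollout_log_probs, loss_mask, low, high):
        """Drop a sequence if any unmasked token ratio falls outside [low, high]."""
        token_ratio = torch.exp(old_log_probs - rollout_log_probs)
        in_bounds = (token_ratio >= low) & (token_ratio <= high)
        keep = (in_bounds | (loss_mask == 0)).all(dim=-1, keepdim=True)
        return keep.float()                                        # [batch, 1]

    # Mirrors test_compute_tis_ratio_token_level: ratios exp([0.5, -0.5, 1.0]) with a 2.0 cap.
    old = torch.tensor([[-1.0, -1.5, -0.5]])
    rollout = torch.tensor([[-1.5, -1.0, -1.5]])
    mask = torch.ones_like(old)
    print(tis_ratio_sketch(old, rollout, mask, "token", clip_high=2.0))  # ~[[1.6487, 0.6065, 2.0000]]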