From f033e65573654ddac7aa5ca605783f3693c8dd0f Mon Sep 17 00:00:00 2001
From: Eric Tang
Date: Wed, 7 Jan 2026 00:40:47 +0000
Subject: [PATCH 01/23] x

---
 .../skyrl_train/config/ppo_base_config.yaml |  33 +++-
 skyrl-train/skyrl_train/trainer.py          |  13 +-
 skyrl-train/skyrl_train/utils/ppo_utils.py  | 152 ++++++++++++++++++
 skyrl-train/skyrl_train/utils/utils.py      |  45 ++++++
 4 files changed, 241 insertions(+), 2 deletions(-)

diff --git a/skyrl-train/skyrl_train/config/ppo_base_config.yaml b/skyrl-train/skyrl_train/config/ppo_base_config.yaml
index f4dd7fbd3..f6d41fdc1 100644
--- a/skyrl-train/skyrl_train/config/ppo_base_config.yaml
+++ b/skyrl-train/skyrl_train/config/ppo_base_config.yaml
@@ -125,9 +125,40 @@ trainer:
     eps_clip_high: 0.2
     # dual clip parameters
     clip_ratio_c: 3.0
-    # Truncated Importance Sampling as proposed in https://fengyao.notion.site/off-policy-rl
+
+    # mark for deprecation
     tis_imp_ratio_cap: -1.0
     use_tis: false
+
+    # reference: https://github.com/szrlee/verl/blob/yingru/rollout_correction/docs/advance/rollout_corr_math.md
+    rollout_correction:
+      # type of importance ratio to use for ppo loss correction
+      # here importance ratio refers to logprobs_{rollout} - logprobs_{policy_old}
+      tis_ratio_type: "null" # "null", "token", "sequence"
+
+      # cap for the importance ratio
+      # 1.5-5.0 is recommended for "token" tis_ratio_type
+      token_tis_ratio_cap_high: 2.0
+      # 2.0-10.0 is recommended for "sequence" tis_ratio_type
+      sequence_tis_ratio_cap_high: 5.0
+
+      # "sequence" mask masks out sequences with product of importance ratios outside the cap
+      # "geometric" mask masks out sequences with geometric mean of importance ratios outside the cap
+      rejection_mask_type: "null" # "null", "sequence", "geometric"
+
+      # cap for the rejection mask ratio
+      # values around 0.99-1.01 are recommended for "geometric" rejection_mask_type - MoE models may need larger allowed ranges due to higher mismatch
+      geo_rejection_mask_ratio_cap_high: 1.001
+      geo_rejection_mask_ratio_cap_low: 0.999
+
+      # sequence level rejection mask ratio cap
+      sequence_rejection_mask_ratio_cap_high: 2.0
+      sequence_rejection_mask_ratio_cap_low: 0.5
+
+      # masks out sequences with any token having importance ratio below the cap
+      sequence_mask_low: 1e-4
+      sequence_mask_high: 100
+
     # SAPO parameters (only used when policy_loss_type: "sapo") (https://arxiv.org/pdf/2511.20347)
     sapo:
       tau_pos: 1.0
diff --git a/skyrl-train/skyrl_train/trainer.py b/skyrl-train/skyrl_train/trainer.py
index faa578f65..23745b1d0 100644
--- a/skyrl-train/skyrl_train/trainer.py
+++ b/skyrl-train/skyrl_train/trainer.py
@@ -579,12 +579,23 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis
             loss_masks,
             logprobs,
         )
-        # sanity check for tis
+        # sanity check for tis (legacy)
         if self.cfg.trainer.algorithm.use_tis:
             assert (
                 rollout_logprobs_tensor is not None
             ), "expected non-null rollout logprobs tensor with `trainer.algorithm.use_tis` as `True`"
             assert rollout_logprobs_tensor.shape == loss_masks_tensor.shape, "Logprobs should look like responses"
+
+        # sanity check for rollout_correction
+        rollout_corr = self.cfg.trainer.algorithm.rollout_correction
+        tis_ratio_type = rollout_corr.tis_ratio_type
+        rejection_mask_type = rollout_corr.rejection_mask_type
+        if tis_ratio_type != "null" or rejection_mask_type != "null":
+            assert (
+                rollout_logprobs_tensor is not None
+            ), "expected non-null rollout logprobs tensor when rollout_correction is enabled"
+            assert rollout_logprobs_tensor.shape == loss_masks_tensor.shape, "Logprobs should look
like responses" + training_input = TrainingInputBatch( { "sequences": sequences_tensor, # Full trajectories (padded and concatenated prompts and responses) diff --git a/skyrl-train/skyrl_train/utils/ppo_utils.py b/skyrl-train/skyrl_train/utils/ppo_utils.py index d814aedd7..39da947e9 100644 --- a/skyrl-train/skyrl_train/utils/ppo_utils.py +++ b/skyrl-train/skyrl_train/utils/ppo_utils.py @@ -552,6 +552,152 @@ def _safe_exp_delta(delta: torch.Tensor, clip: float = 20.0, out_dtype=None) -> return y.to(out_dtype or delta.dtype) +def compute_tis_ratio( + old_log_probs: torch.Tensor, + rollout_logprobs: torch.Tensor, + loss_mask: torch.Tensor, + tis_ratio_type: str, + rollout_corr: DictConfig, +) -> torch.Tensor: + """ + Compute truncated importance sampling (TIS) ratio for rollout correction. + + Args: + old_log_probs: Log probabilities from the old policy (before update). + rollout_logprobs: Log probabilities from the rollout policy. + loss_mask: Mask indicating valid tokens. + tis_ratio_type: Type of TIS ratio ("token" or "sequence"). + rollout_corr: Rollout correction config containing cap values. + + Returns: + TIS ratio tensor to multiply with the loss. + + Reference: https://github.com/szrlee/verl/blob/yingru/rollout_correction/docs/advance/rollout_corr_math.md + """ + # Compute token-level importance ratio: pi_old / pi_rollout + # In log space: old_log_probs - rollout_logprobs + token_tis_log_ratio = old_log_probs - rollout_logprobs + token_tis_ratio = _safe_exp_delta(token_tis_log_ratio, clip=20.0, out_dtype=old_log_probs.dtype) + + if tis_ratio_type == "token": + token_tis_ratio_cap = rollout_corr.token_tis_ratio_cap_high + return torch.clamp(token_tis_ratio, max=token_tis_ratio_cap) + elif tis_ratio_type == "sequence": + # Compute sequence-level importance ratio as product of token ratios (sum of log ratios) + seq_tis_log_ratio = (token_tis_log_ratio * loss_mask).sum(dim=-1, keepdim=True) + seq_tis_ratio = _safe_exp_delta(seq_tis_log_ratio, clip=20.0, out_dtype=old_log_probs.dtype) + seq_tis_ratio_cap = rollout_corr.sequence_tis_ratio_cap_high + return torch.clamp(seq_tis_ratio, max=seq_tis_ratio_cap) + else: + raise ValueError(f"Unknown tis_ratio_type: {tis_ratio_type}") + + +def compute_rejection_mask( + old_log_probs: torch.Tensor, + rollout_logprobs: torch.Tensor, + loss_mask: torch.Tensor, + rejection_mask_type: str, + rollout_corr: DictConfig, +) -> torch.Tensor: + """ + Compute rejection mask for rollout correction. + + This masks out sequences with importance ratios that fall outside acceptable bounds, + helping to filter out off-policy samples that may destabilize training. + + Args: + old_log_probs: Log probabilities from the old policy (before update). + rollout_logprobs: Log probabilities from the rollout policy. + loss_mask: Mask indicating valid tokens. + rejection_mask_type: Type of rejection mask ("geometric" or "sequence"). + rollout_corr: Rollout correction config containing cap values. + + Returns: + Rejection mask tensor (float) to multiply with the loss. 
+ + Reference: https://github.com/szrlee/verl/blob/yingru/rollout_correction/docs/advance/rollout_corr_math.md + """ + # Compute token-level importance ratio + token_tis_log_ratio = old_log_probs - rollout_logprobs + token_tis_ratio = _safe_exp_delta(token_tis_log_ratio, clip=20.0, out_dtype=old_log_probs.dtype) + + if rejection_mask_type == "geometric": + # Compute geometric mean of importance ratios per sequence + num_tokens = loss_mask.sum(dim=-1, keepdim=True).clamp(min=1.0) + seq_tis_log_ratio = (token_tis_log_ratio * loss_mask).sum(dim=-1, keepdim=True) + geo_mean_ratio = _safe_exp_delta(seq_tis_log_ratio / num_tokens, clip=20.0, out_dtype=old_log_probs.dtype) + geo_cap_high = rollout_corr.geo_rejection_mask_ratio_cap_high + geo_cap_low = rollout_corr.geo_rejection_mask_ratio_cap_low + geo_rejection_mask = (geo_mean_ratio >= geo_cap_low) & (geo_mean_ratio <= geo_cap_high) + return geo_rejection_mask.float() + elif rejection_mask_type == "sequence": + # Mask out sequences with product of importance ratios outside the cap + seq_tis_log_ratio = (token_tis_log_ratio * loss_mask).sum(dim=-1, keepdim=True) + seq_tis_ratio = _safe_exp_delta(seq_tis_log_ratio, clip=20.0, out_dtype=old_log_probs.dtype) + seq_cap_high = rollout_corr.sequence_rejection_mask_ratio_cap_high + seq_cap_low = rollout_corr.sequence_rejection_mask_ratio_cap_low + # Also check per-token bounds + token_mask_low = rollout_corr.sequence_mask_low + token_mask_high = rollout_corr.sequence_mask_high + token_in_bounds = (token_tis_ratio >= token_mask_low) & (token_tis_ratio <= token_mask_high) + # A sequence is valid if all tokens are in bounds (considering only masked positions) + all_tokens_valid = (token_in_bounds | (loss_mask == 0)).all(dim=-1, keepdim=True) + seq_in_bounds = (seq_tis_ratio >= seq_cap_low) & (seq_tis_ratio <= seq_cap_high) + seq_rejection_mask = seq_in_bounds & all_tokens_valid + return seq_rejection_mask.float() + else: + raise ValueError(f"Unknown rejection_mask_type: {rejection_mask_type}") + + +def apply_rollout_correction( + loss: torch.Tensor, + old_log_probs: torch.Tensor, + rollout_logprobs: torch.Tensor, + loss_mask: torch.Tensor, + rollout_corr: DictConfig, +) -> torch.Tensor: + """ + Apply rollout correction to the loss using TIS ratio and rejection mask. + + This is a convenience function that combines compute_tis_ratio and compute_rejection_mask. + + Args: + loss: The loss tensor to correct. + old_log_probs: Log probabilities from the old policy (before update). + rollout_logprobs: Log probabilities from the rollout policy. + loss_mask: Mask indicating valid tokens. + rollout_corr: Rollout correction config. + + Returns: + Corrected loss tensor. 
+ """ + tis_ratio_type = rollout_corr.tis_ratio_type + rejection_mask_type = rollout_corr.rejection_mask_type + + # Check if TIS ratio correction is enabled + apply_tis = tis_ratio_type is not None and tis_ratio_type != "null" + # Check if rejection mask is enabled + apply_rejection = rejection_mask_type is not None and rejection_mask_type != "null" + + # Early return if no correction needed + if not apply_tis and not apply_rejection: + return loss + + # Apply TIS ratio if enabled + if apply_tis: + tis_ratio = compute_tis_ratio(old_log_probs, rollout_logprobs, loss_mask, tis_ratio_type, rollout_corr) + loss = loss * tis_ratio + + # Apply rejection mask if enabled + if apply_rejection: + rejection_mask = compute_rejection_mask( + old_log_probs, rollout_logprobs, loss_mask, rejection_mask_type, rollout_corr + ) + loss = loss * rejection_mask + + return loss + + @register_policy_loss(PolicyLossType.REGULAR) @register_policy_loss(PolicyLossType.DUAL_CLIP) def ppo_policy_loss( @@ -581,6 +727,7 @@ def ppo_policy_loss( clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1) loss = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1) + # Legacy TIS support (deprecated) if config.use_tis: from loguru import logger as logger_ @@ -590,6 +737,11 @@ def ppo_policy_loss( tis_imp_ratio = torch.clamp(tis_imp_ratio, max=config.tis_imp_ratio_cap) loss = loss * tis_imp_ratio + # New rollout correction support + rollout_corr = config.get("rollout_correction", None) + if rollout_corr is not None and rollout_logprobs is not None: + loss = apply_rollout_correction(loss, old_log_probs, rollout_logprobs, loss_mask, rollout_corr) + loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len) return loss, clip_ratio diff --git a/skyrl-train/skyrl_train/utils/utils.py b/skyrl-train/skyrl_train/utils/utils.py index 00d184c92..94bda59cf 100644 --- a/skyrl-train/skyrl_train/utils/utils.py +++ b/skyrl-train/skyrl_train/utils/utils.py @@ -295,7 +295,11 @@ def validate_cfg(cfg: DictConfig): "please set `offload_after_step` to `true` for both policy and critic" ) + # Legacy TIS validation (deprecated) if cfg.trainer.algorithm.use_tis: + logger.warning( + "`trainer.algorithm.use_tis` is deprecated. Please use `trainer.algorithm.rollout_correction` instead." 
+        )
         if cfg.trainer.algorithm.tis_imp_ratio_cap <= 0:
             raise ValueError(
                 f"If `trainer.algorithm.use_tis` is `True` then `cfg.trainer.algorithm.tis_imp_ratio_cap` "
@@ -316,6 +320,47 @@ def validate_cfg(cfg: DictConfig):
             "dual_clip",
         ], "TIS is only implemented for regular and dual_clip policy loss types"
 
+    # New rollout_correction validation
+    rollout_corr = cfg.trainer.algorithm.get("rollout_correction", None)
+    if rollout_corr is not None:
+        tis_ratio_type = rollout_corr.get("tis_ratio_type", "null")
+        rejection_mask_type = rollout_corr.get("rejection_mask_type", "null")
+
+        uses_rollout_correction = tis_ratio_type != "null" or rejection_mask_type != "null"
+
+        if uses_rollout_correction:
+            # Validate tis_ratio_type
+            assert tis_ratio_type in [
+                "null",
+                "token",
+                "sequence",
+            ], f"`tis_ratio_type` must be 'null', 'token', or 'sequence', got {tis_ratio_type}"
+
+            # Validate rejection_mask_type
+            assert rejection_mask_type in [
+                "null",
+                "sequence",
+                "geometric",
+            ], f"`rejection_mask_type` must be 'null', 'sequence', or 'geometric', got {rejection_mask_type}"
+
+            # Ensure logprobs are enabled for rollout correction
+            if cfg.generator.sampling_params.logprobs is None:
+                logger.warning(
+                    "`generator.sampling_params.logprobs` is `None` but rollout_correction is enabled."
+                    " Setting `logprobs` to `0`."
+                )
+                cfg.generator.sampling_params.logprobs = 0
+
+            if cfg.generator.backend == "sglang":
+                raise NotImplementedError(
+                    "`trainer.algorithm.rollout_correction` doesn't support Sglang backend, please use vLLM"
+                )
+
+            assert cfg.trainer.algorithm.policy_loss_type in [
+                "regular",
+                "dual_clip",
+            ], "rollout_correction is only implemented for regular and dual_clip policy loss types"
+
     if cfg.trainer.policy.model.lora.rank > 0:
         # LoRA enabled
         # Right now: assert generator backend must be vllm, training backend must be fsdp/fsdp2

From 3f3b75942ab23687001cf2a2a322d68d3ae7129b Mon Sep 17 00:00:00 2001
From: Eric Tang
Date: Wed, 7 Jan 2026 00:50:10 +0000
Subject: [PATCH 02/23] x

---
 skyrl-train/docs/examples/flash_rl.rst                       | 5 +++--
 .../examples/flash_rl/run_dapo_repro_flashrl_32b_int8.sh     | 6 +++---
 skyrl-train/examples/fully_async/async_run_gsm8k.sh          | 6 +++---
 .../examples/megatron/run_megatron_dapo_qwen3_30b_a3b.sh     | 6 +++---
 .../megatron/run_megatron_dapo_qwen3_30b_a3b_lora.sh         | 6 +++---
 skyrl-train/examples/megatron/run_megatron_dapo_qwen3_4b.sh  | 6 +++---
 .../examples/megatron/run_megatron_dapo_qwen3_4b_lora.sh     | 5 +++--
 .../examples/megatron/run_megatron_lora_qwen3-30b-a3b.sh     | 6 +++---
 skyrl-train/examples/search/run_search.sh                    | 6 +++---
 .../examples/search/run_search_conversation_format.sh        | 6 +++---
 .../examples/text_to_sql/run_skyrl_sql_megatron_lora.sh      | 6 +++---
 skyrl-train/examples/tis_correction/run_dapo_tis.sh          | 6 +++---
 12 files changed, 36 insertions(+), 34 deletions(-)

diff --git a/skyrl-train/docs/examples/flash_rl.rst b/skyrl-train/docs/examples/flash_rl.rst
index a34c18ef3..b75292dad 100644
--- a/skyrl-train/docs/examples/flash_rl.rst
+++ b/skyrl-train/docs/examples/flash_rl.rst
@@ -60,12 +60,13 @@ We highlight some important training parameters configured for FlashRL from our
     DATA_DIR="$HOME/data/dapo"
 
     # TIS parameters
-    USE_TIS=true
+    TIS_TYPE=token
     TIS_IMP_RATIO_CAP=8.0
 
     uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.int8 -m examples.flash_rl.main_dapo_flashrl \
       ...
- trainer.algorithm.use_tis=$USE_TIS \ + trainer.algorithm.rollout_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.rollout_correction.token_tis_ratio_cap_high=$TIS_IMP_RATIO_CAP \ trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \ generator.sampling_params.logprobs=0 \ ... diff --git a/skyrl-train/examples/flash_rl/run_dapo_repro_flashrl_32b_int8.sh b/skyrl-train/examples/flash_rl/run_dapo_repro_flashrl_32b_int8.sh index bf4a033f4..7b0122c6e 100644 --- a/skyrl-train/examples/flash_rl/run_dapo_repro_flashrl_32b_int8.sh +++ b/skyrl-train/examples/flash_rl/run_dapo_repro_flashrl_32b_int8.sh @@ -33,7 +33,7 @@ MAX_RESPONSE_LENGTH=20480 MAX_PROMPT_LENGTH=2048 TIS_IMP_RATIO_CAP=8.0 -USE_TIS=true +TIS_TYPE=token LOGPROBS=0 CKPT_PATH="$HOME/ckpts/dapo_32b_ckpt" @@ -57,8 +57,8 @@ uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.int8 --with generator.eval_sampling_params.top_p=$EVAL_TOP_P \ trainer.algorithm.use_kl_loss=$USE_KL_LOSS \ trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \ - trainer.algorithm.use_tis=$USE_TIS \ - trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.rollout_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.rollout_correction.token_tis_ratio_cap_high=$TIS_IMP_RATIO_CAP \ trainer.policy.model.path="Qwen/Qwen2.5-32B" \ trainer.placement.colocate_all=true \ trainer.strategy=fsdp2 \ diff --git a/skyrl-train/examples/fully_async/async_run_gsm8k.sh b/skyrl-train/examples/fully_async/async_run_gsm8k.sh index 5a4404b3b..a26e4dd87 100644 --- a/skyrl-train/examples/fully_async/async_run_gsm8k.sh +++ b/skyrl-train/examples/fully_async/async_run_gsm8k.sh @@ -28,7 +28,7 @@ set -x : "${MAX_STALENESS_STEPS:=4}" : "${NUM_PARALLEL_GENERATION_WORKERS:=$(( MINI_BATCH_SIZE * (MAX_STALENESS_STEPS + 1) ))}" -USE_TIS=true +TIS_TYPE=token TIS_IMP_RATIO_CAP=2.0 RUN_NAME=gsm8k-async-qwen2.5_1.5B-useTIS_${USE_TIS}-maxStale${MAX_STALENESS_STEPS}-numCon${NUM_PARALLEL_GENERATION_WORKERS}-${NUM_POLICY_GPUS}train${NUM_INFERENCE_GPUS}gen @@ -39,8 +39,8 @@ uv run --isolated --extra $INFERENCE_BACKEND -m examples.fully_async.main_async trainer.fully_async.max_staleness_steps=${MAX_STALENESS_STEPS} \ trainer.fully_async.num_parallel_generation_workers=${NUM_PARALLEL_GENERATION_WORKERS} \ trainer.algorithm.advantage_estimator="grpo" \ - trainer.algorithm.use_tis=$USE_TIS \ - trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.rollout_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.rollout_correction.token_tis_ratio_cap_high=$TIS_IMP_RATIO_CAP \ trainer.policy.model.path="Qwen/Qwen2.5-1.5B-Instruct" \ trainer.placement.colocate_all=false \ trainer.strategy=fsdp2 \ diff --git a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_30b_a3b.sh b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_30b_a3b.sh index d2af35b5b..584234160 100644 --- a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_30b_a3b.sh +++ b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_30b_a3b.sh @@ -53,7 +53,7 @@ MEGATRON_ETP=1 # TIS parameters TIS_IMP_RATIO_CAP=2.0 -USE_TIS=true +TIS_TYPE=token uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \ data.train_data="['$TRAIN_FILE']" \ @@ -83,8 +83,8 @@ uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \ trainer.policy.megatron_config.context_parallel_size=$MEGATRON_CP \ trainer.policy.megatron_config.expert_model_parallel_size=$MEGATRON_EP \ trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \ - 
trainer.algorithm.use_tis=$USE_TIS \ - trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.rollout_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.rollout_correction.token_tis_ratio_cap_high=$TIS_IMP_RATIO_CAP \ trainer.epochs=20 \ trainer.algorithm.eps_clip_low=$CLIP_RATIO_LOW \ trainer.algorithm.eps_clip_high=$CLIP_RATIO_HIGH \ diff --git a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_30b_a3b_lora.sh b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_30b_a3b_lora.sh index 7ae14bc4d..904343576 100644 --- a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_30b_a3b_lora.sh +++ b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_30b_a3b_lora.sh @@ -56,7 +56,7 @@ LORA_ALPHA=64 # TIS parameters TIS_IMP_RATIO_CAP=2.0 -USE_TIS=true +TIS_TYPE=token uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \ data.train_data="['$TRAIN_FILE']" \ @@ -88,8 +88,8 @@ uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \ trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \ trainer.policy.model.lora.rank=$LORA_RANK \ trainer.policy.model.lora.alpha=$LORA_ALPHA \ - trainer.algorithm.use_tis=$USE_TIS \ - trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.rollout_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.rollout_correction.token_tis_ratio_cap_high=$TIS_IMP_RATIO_CAP \ trainer.epochs=20 \ trainer.algorithm.eps_clip_low=$CLIP_RATIO_LOW \ trainer.algorithm.eps_clip_high=$CLIP_RATIO_HIGH \ diff --git a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_4b.sh b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_4b.sh index 014ee567f..8e9393445 100644 --- a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_4b.sh +++ b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_4b.sh @@ -50,7 +50,7 @@ MEGATRON_ETP=null # TIS parameters TIS_IMP_RATIO_CAP=2.0 -USE_TIS=true +TIS_TYPE=token uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \ data.train_data="['$TRAIN_FILE']" \ @@ -80,8 +80,8 @@ uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \ trainer.policy.megatron_config.context_parallel_size=$MEGATRON_CP \ trainer.policy.megatron_config.expert_model_parallel_size=$MEGATRON_EP \ trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \ - trainer.algorithm.use_tis=$USE_TIS \ - trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.rollout_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.rollout_correction.token_tis_ratio_cap_high=$TIS_IMP_RATIO_CAP \ trainer.epochs=20 \ trainer.algorithm.eps_clip_low=$CLIP_RATIO_LOW \ trainer.algorithm.eps_clip_high=$CLIP_RATIO_HIGH \ diff --git a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_4b_lora.sh b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_4b_lora.sh index 25dc8e6d3..f690305e6 100644 --- a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_4b_lora.sh +++ b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_4b_lora.sh @@ -55,7 +55,7 @@ LORA_A_INIT_METHOD="kaiming" # TIS parameters TIS_IMP_RATIO_CAP=2.0 -USE_TIS=true +TIS_TYPE=token uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \ data.train_data="['$TRAIN_FILE']" \ @@ -85,7 +85,8 @@ uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \ trainer.policy.megatron_config.context_parallel_size=$MEGATRON_CP \ trainer.policy.megatron_config.expert_model_parallel_size=$MEGATRON_EP \ 
trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \ - trainer.algorithm.use_tis=$USE_TIS \ + trainer.algorithm.rollout_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.rollout_correction.token_tis_ratio_cap_high=2.0 \ trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \ trainer.policy.model.lora.rank=$LORA_RANK \ trainer.policy.model.lora.alpha=$LORA_ALPHA \ diff --git a/skyrl-train/examples/megatron/run_megatron_lora_qwen3-30b-a3b.sh b/skyrl-train/examples/megatron/run_megatron_lora_qwen3-30b-a3b.sh index 6c2f3a899..b27f8f869 100644 --- a/skyrl-train/examples/megatron/run_megatron_lora_qwen3-30b-a3b.sh +++ b/skyrl-train/examples/megatron/run_megatron_lora_qwen3-30b-a3b.sh @@ -38,7 +38,7 @@ LORA_A_INIT_METHOD="kaiming" # TIS parameters TIS_IMP_RATIO_CAP=2.0 -USE_TIS=true +TIS_TYPE=token uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_base \ data.train_data="['$DATA_DIR/train.parquet']" \ @@ -63,8 +63,8 @@ uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_base \ trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \ trainer.ref.megatron_config.expert_model_parallel_size=$MEGATRON_EP \ trainer.ref.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \ - trainer.algorithm.use_tis=$USE_TIS \ - trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.rollout_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.rollout_correction.token_tis_ratio_cap_high=$TIS_IMP_RATIO_CAP \ trainer.policy.model.lora.rank=$LORA_RANK \ trainer.policy.model.lora.alpha=$LORA_ALPHA \ trainer.policy.model.lora.init_method=$LORA_A_INIT_METHOD \ diff --git a/skyrl-train/examples/search/run_search.sh b/skyrl-train/examples/search/run_search.sh index 1703e455e..3203525b1 100755 --- a/skyrl-train/examples/search/run_search.sh +++ b/skyrl-train/examples/search/run_search.sh @@ -11,7 +11,7 @@ DATA_DIR="$HOME/data/searchR1" RUN_NAME="skyrl-search_4turns_maxgeneratelen_500-multiturn-sync-TIS_2.0" -USE_TIS=true +TIS_TYPE=token TIS_IMP_RATIO_CAP=2.0 uv run --isolated --frozen --extra vllm -m skyrl_train.entrypoints.main_base \ @@ -23,8 +23,8 @@ uv run --isolated --frozen --extra vllm -m skyrl_train.entrypoints.main_base \ trainer.policy.optimizer_config.num_warmup_steps=94 \ trainer.algorithm.use_kl_loss=true \ trainer.algorithm.kl_loss_coef=0.001 \ - trainer.algorithm.use_tis=$USE_TIS \ - trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.rollout_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.rollout_correction.token_tis_ratio_cap_high=$TIS_IMP_RATIO_CAP \ trainer.policy.model.path="Qwen/Qwen2.5-3B-Instruct" \ trainer.placement.colocate_all=true \ trainer.strategy=fsdp2 \ diff --git a/skyrl-train/examples/search/run_search_conversation_format.sh b/skyrl-train/examples/search/run_search_conversation_format.sh index 9346efc13..3679edb1e 100755 --- a/skyrl-train/examples/search/run_search_conversation_format.sh +++ b/skyrl-train/examples/search/run_search_conversation_format.sh @@ -18,7 +18,7 @@ DATA_DIR="$HOME/data/searchR1" RUN_NAME="skyrl-search_4turns_maxgeneratelen_500" -USE_TIS=true +TIS_TYPE=token TIS_IMP_RATIO_CAP=2.0 uv run --isolated --frozen --extra vllm -m skyrl_train.entrypoints.main_base \ @@ -30,8 +30,8 @@ uv run --isolated --frozen --extra vllm -m skyrl_train.entrypoints.main_base \ trainer.policy.optimizer_config.num_warmup_steps=94 \ trainer.algorithm.use_kl_loss=true \ trainer.algorithm.kl_loss_coef=0.001 \ - trainer.algorithm.use_tis=$USE_TIS \ - 
trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.rollout_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.rollout_correction.token_tis_ratio_cap_high=$TIS_IMP_RATIO_CAP \ trainer.policy.model.path="Qwen/Qwen2.5-3B-Instruct" \ trainer.placement.colocate_all=true \ trainer.strategy=fsdp2 \ diff --git a/skyrl-train/examples/text_to_sql/run_skyrl_sql_megatron_lora.sh b/skyrl-train/examples/text_to_sql/run_skyrl_sql_megatron_lora.sh index f33cb7e65..eaf4dd4fe 100644 --- a/skyrl-train/examples/text_to_sql/run_skyrl_sql_megatron_lora.sh +++ b/skyrl-train/examples/text_to_sql/run_skyrl_sql_megatron_lora.sh @@ -32,7 +32,7 @@ MEGATRON_ETP=null # TIS parameters TIS_IMP_RATIO_CAP=2.0 -USE_TIS=true +TIS_TYPE=token uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_base \ trainer.algorithm.advantage_estimator="grpo" \ @@ -63,8 +63,8 @@ uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_base \ trainer.policy.optimizer_config.lr=3.0e-5 \ trainer.policy_mini_batch_size=256 \ trainer.algorithm.use_kl_loss=false \ - trainer.algorithm.use_tis=$USE_TIS \ - trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.rollout_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.rollout_correction.token_tis_ratio_cap_high=$TIS_IMP_RATIO_CAP \ trainer.ckpt_interval=60 \ trainer.hf_save_interval=30 \ trainer.dump_data_batch=true \ diff --git a/skyrl-train/examples/tis_correction/run_dapo_tis.sh b/skyrl-train/examples/tis_correction/run_dapo_tis.sh index a0995dd3c..04bc4cd8a 100644 --- a/skyrl-train/examples/tis_correction/run_dapo_tis.sh +++ b/skyrl-train/examples/tis_correction/run_dapo_tis.sh @@ -12,7 +12,7 @@ LOGGER="wandb" # change to "console" to print to stdout # TIS parameters TIS_IMP_RATIO_CAP=2.0 -USE_TIS=true +TIS_TYPE=token # returns rollout logprobs for the generated tokens; required for TIS LOGPROBS=0 @@ -55,8 +55,8 @@ uv run --isolated --extra vllm -m examples.tis_correction.main_tis_dapo \ generator.eval_sampling_params.top_p=$EVAL_TOP_P \ trainer.algorithm.use_kl_loss=$USE_KL_LOSS \ trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \ - trainer.algorithm.use_tis=$USE_TIS \ - trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.rollout_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.rollout_correction.token_tis_ratio_cap_high=$TIS_IMP_RATIO_CAP \ trainer.policy.model.path="Qwen/Qwen2.5-1.5B-Instruct" \ trainer.placement.colocate_all=true \ trainer.strategy=fsdp2 \ From 29efd6f2d67c81cf52b765389320daa948bd31c0 Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Wed, 7 Jan 2026 01:09:05 +0000 Subject: [PATCH 03/23] x --- .../skyrl_train/config/ppo_base_config.yaml | 2 +- .../tests/cpu/algorithms/test_losses.py | 470 +++++++++++++++++- 2 files changed, 470 insertions(+), 2 deletions(-) diff --git a/skyrl-train/skyrl_train/config/ppo_base_config.yaml b/skyrl-train/skyrl_train/config/ppo_base_config.yaml index e5cfc74d4..3d501f00e 100644 --- a/skyrl-train/skyrl_train/config/ppo_base_config.yaml +++ b/skyrl-train/skyrl_train/config/ppo_base_config.yaml @@ -128,7 +128,7 @@ trainer: # reference: https://github.com/szrlee/verl/blob/yingru/rollout_correction/docs/advance/rollout_corr_math.md rollout_correction: # type of importance ratio to use for ppo loss correction - # here importance ratio refers to logprobs_{rollout} - logprobs_{policy_old} + # here importance ratio refers to logprobs_{policy_old} - logprobs_{rollout_policy} tis_ratio_type: "null" # "null", "token", "sequence" # cap for the 
importance ratio diff --git a/skyrl-train/tests/cpu/algorithms/test_losses.py b/skyrl-train/tests/cpu/algorithms/test_losses.py index f5904b595..3db96b702 100644 --- a/skyrl-train/tests/cpu/algorithms/test_losses.py +++ b/skyrl-train/tests/cpu/algorithms/test_losses.py @@ -8,7 +8,13 @@ import torch from omegaconf import DictConfig -from skyrl_train.utils.ppo_utils import PolicyLossRegistry, masked_mean +from skyrl_train.utils.ppo_utils import ( + PolicyLossRegistry, + masked_mean, + compute_tis_ratio, + compute_rejection_mask, + apply_rollout_correction, +) # Adapted a good test from NeMO-RL @@ -621,3 +627,465 @@ def gate_function(x, tau): # SAPO should always report clip_ratio = 0.0 assert actual_clip_ratio == 0.0 + + +# ============================================================================ +# Rollout Correction Tests +# ============================================================================ + + +def test_compute_tis_ratio_token_level(): + """Tests token-level TIS ratio computation with capping.""" + device = "cpu" + + # old_log_probs - rollout_logprobs gives the log importance ratio + # Token ratios: exp([0.5, -0.5, 1.0]) = [1.6487, 0.6065, 2.7183] + old_log_probs = torch.tensor([[-1.0, -1.5, -0.5]], device=device) + rollout_logprobs = torch.tensor([[-1.5, -1.0, -1.5]], device=device) + loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device) + + config = DictConfig( + { + "tis_ratio_type": "token", + "token_tis_ratio_cap_high": 2.0, + } + ) + + tis_ratio = compute_tis_ratio(old_log_probs, rollout_logprobs, loss_mask, "token", config) + + # Expected: [1.6487, 0.6065, 2.0] (third token capped at 2.0) + expected = torch.tensor([[1.6487, 0.6065, 2.0]], device=device) + torch.testing.assert_close(tis_ratio, expected, rtol=1e-3, atol=1e-4) + + +def test_compute_tis_ratio_sequence_level(): + """Tests sequence-level TIS ratio computation with capping.""" + device = "cpu" + + # Token log ratios: [0.5, -0.5, 1.0] + # Sequence log ratio (sum of masked): 0.5 + (-0.5) + 1.0 = 1.0 + # Sequence ratio: exp(1.0) = 2.7183 + old_log_probs = torch.tensor([[-1.0, -1.5, -0.5]], device=device) + rollout_logprobs = torch.tensor([[-1.5, -1.0, -1.5]], device=device) + loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device) + + config = DictConfig( + { + "tis_ratio_type": "sequence", + "sequence_tis_ratio_cap_high": 5.0, + } + ) + + tis_ratio = compute_tis_ratio(old_log_probs, rollout_logprobs, loss_mask, "sequence", config) + + # Expected: exp(1.0) = 2.7183, shape [batch, 1] for sequence-level + expected = torch.tensor([[2.7183]], device=device) + torch.testing.assert_close(tis_ratio, expected, rtol=1e-3, atol=1e-4) + + +def test_compute_tis_ratio_sequence_level_with_cap(): + """Tests sequence-level TIS ratio capping.""" + device = "cpu" + + # Token log ratios: [1.0, 1.0, 1.0] + # Sequence log ratio: 3.0 + # Sequence ratio: exp(3.0) = 20.09, should be capped at 5.0 + old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + rollout_logprobs = torch.tensor([[-2.0, -2.0, -2.0]], device=device) + loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device) + + config = DictConfig( + { + "tis_ratio_type": "sequence", + "sequence_tis_ratio_cap_high": 5.0, + } + ) + + tis_ratio = compute_tis_ratio(old_log_probs, rollout_logprobs, loss_mask, "sequence", config) + + # Expected: capped at 5.0, shape [batch, 1] for sequence-level + expected = torch.tensor([[5.0]], device=device) + torch.testing.assert_close(tis_ratio, expected, rtol=1e-3, atol=1e-4) + + +def test_compute_tis_ratio_with_mask(): + 
"""Tests that loss_mask correctly excludes tokens from sequence-level computation.""" + device = "cpu" + + # Token log ratios: [0.5, -0.5, 1.0] + # With mask [1, 0, 1], sequence log ratio = 0.5 + 1.0 = 1.5 + old_log_probs = torch.tensor([[-1.0, -1.5, -0.5]], device=device) + rollout_logprobs = torch.tensor([[-1.5, -1.0, -1.5]], device=device) + loss_mask = torch.tensor([[1.0, 0.0, 1.0]], device=device) + + config = DictConfig( + { + "tis_ratio_type": "sequence", + "sequence_tis_ratio_cap_high": 10.0, + } + ) + + tis_ratio = compute_tis_ratio(old_log_probs, rollout_logprobs, loss_mask, "sequence", config) + + # Expected: exp(1.5) = 4.4817, shape [batch, 1] for sequence-level + expected_val = torch.exp(torch.tensor(1.5)) + expected = expected_val.reshape(1, 1) + torch.testing.assert_close(tis_ratio, expected, rtol=1e-3, atol=1e-4) + + +def test_compute_rejection_mask_geometric(): + """Tests geometric rejection mask computation.""" + device = "cpu" + + # Token log ratios: [0.1, -0.1, 0.0] -> sum = 0.0, geometric mean = exp(0/3) = 1.0 + old_log_probs = torch.tensor([[-1.0, -1.1, -1.0]], device=device) + rollout_logprobs = torch.tensor([[-1.1, -1.0, -1.0]], device=device) + loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device) + + config = DictConfig( + { + "rejection_mask_type": "geometric", + "geo_rejection_mask_ratio_cap_high": 1.1, + "geo_rejection_mask_ratio_cap_low": 0.9, + } + ) + + rejection_mask = compute_rejection_mask(old_log_probs, rollout_logprobs, loss_mask, "geometric", config) + + # Geometric mean ≈ 1.0, which is within [0.9, 1.1], so mask should be 1.0 + # Shape is [batch, 1] for sequence-level mask + expected = torch.tensor([[1.0]], device=device) + torch.testing.assert_close(rejection_mask, expected, rtol=1e-3, atol=1e-4) + + +def test_compute_rejection_mask_geometric_rejects(): + """Tests geometric rejection mask correctly rejects sequences outside bounds.""" + device = "cpu" + + # Token log ratios: [0.5, 0.5, 0.5] -> sum = 1.5, geometric mean = exp(1.5/3) = exp(0.5) ≈ 1.6487 + old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + rollout_logprobs = torch.tensor([[-1.5, -1.5, -1.5]], device=device) + loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device) + + config = DictConfig( + { + "rejection_mask_type": "geometric", + "geo_rejection_mask_ratio_cap_high": 1.1, + "geo_rejection_mask_ratio_cap_low": 0.9, + } + ) + + rejection_mask = compute_rejection_mask(old_log_probs, rollout_logprobs, loss_mask, "geometric", config) + + # Geometric mean ≈ 1.6487, which is outside [0.9, 1.1], so mask should be 0.0 + # Shape is [batch, 1] for sequence-level mask + expected = torch.tensor([[0.0]], device=device) + torch.testing.assert_close(rejection_mask, expected, rtol=1e-3, atol=1e-4) + + +def test_compute_rejection_mask_sequence(): + """Tests sequence rejection mask computation.""" + device = "cpu" + + # Token log ratios: [0.2, 0.1, 0.0] -> sum = 0.3, seq ratio = exp(0.3) ≈ 1.35 + old_log_probs = torch.tensor([[-1.0, -1.1, -1.2]], device=device) + rollout_logprobs = torch.tensor([[-1.2, -1.2, -1.2]], device=device) + loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device) + + config = DictConfig( + { + "rejection_mask_type": "sequence", + "sequence_rejection_mask_ratio_cap_high": 2.0, + "sequence_rejection_mask_ratio_cap_low": 0.5, + "sequence_mask_low": 1e-4, + "sequence_mask_high": 100.0, + } + ) + + rejection_mask = compute_rejection_mask(old_log_probs, rollout_logprobs, loss_mask, "sequence", config) + + # Sequence ratio ≈ 1.35, which is within [0.5, 2.0], 
and all token ratios are in bounds + # Shape is [batch, 1] for sequence-level mask + expected = torch.tensor([[1.0]], device=device) + torch.testing.assert_close(rejection_mask, expected, rtol=1e-3, atol=1e-4) + + +def test_compute_rejection_mask_sequence_rejects_by_seq_ratio(): + """Tests sequence rejection mask rejects when sequence ratio is out of bounds.""" + device = "cpu" + + # Token log ratios: [1.0, 1.0, 1.0] -> sum = 3.0, seq ratio = exp(3.0) ≈ 20.09 + old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + rollout_logprobs = torch.tensor([[-2.0, -2.0, -2.0]], device=device) + loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device) + + config = DictConfig( + { + "rejection_mask_type": "sequence", + "sequence_rejection_mask_ratio_cap_high": 2.0, + "sequence_rejection_mask_ratio_cap_low": 0.5, + "sequence_mask_low": 1e-4, + "sequence_mask_high": 100.0, + } + ) + + rejection_mask = compute_rejection_mask(old_log_probs, rollout_logprobs, loss_mask, "sequence", config) + + # Sequence ratio ≈ 20.09, which is outside [0.5, 2.0], so mask should be 0.0 + # Shape is [batch, 1] for sequence-level mask + expected = torch.tensor([[0.0]], device=device) + torch.testing.assert_close(rejection_mask, expected, rtol=1e-3, atol=1e-4) + + +def test_compute_rejection_mask_sequence_rejects_by_token_bounds(): + """Tests sequence rejection mask rejects when a token ratio is out of bounds.""" + device = "cpu" + + # Token log ratios: [0.0, 0.0, 5.0] -> token ratios = [1.0, 1.0, 148.4] + # Sequence ratio = exp(5.0) ≈ 148.4 (in bounds if cap is high enough) + # But third token ratio 148.4 > 100.0, so should reject + old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + rollout_logprobs = torch.tensor([[-1.0, -1.0, -6.0]], device=device) + loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device) + + config = DictConfig( + { + "rejection_mask_type": "sequence", + "sequence_rejection_mask_ratio_cap_high": 200.0, # High enough to not reject by seq ratio + "sequence_rejection_mask_ratio_cap_low": 0.01, + "sequence_mask_low": 1e-4, + "sequence_mask_high": 100.0, # This should cause rejection + } + ) + + rejection_mask = compute_rejection_mask(old_log_probs, rollout_logprobs, loss_mask, "sequence", config) + + # Token ratio 148.4 > 100.0, so mask should be 0.0 + # Shape is [batch, 1] for sequence-level mask + expected = torch.tensor([[0.0]], device=device) + torch.testing.assert_close(rejection_mask, expected, rtol=1e-3, atol=1e-4) + + +def test_apply_rollout_correction_null_configs(): + """Tests that apply_rollout_correction returns loss unchanged when both configs are null.""" + device = "cpu" + + loss = torch.tensor([[1.0, 2.0, 3.0]], device=device) + old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + rollout_logprobs = torch.tensor([[-2.0, -2.0, -2.0]], device=device) + loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device) + + config = DictConfig( + { + "tis_ratio_type": "null", + "rejection_mask_type": "null", + } + ) + + corrected_loss = apply_rollout_correction(loss, old_log_probs, rollout_logprobs, loss_mask, config) + + # Should return the same tensor (early return) + assert corrected_loss is loss + + +def test_apply_rollout_correction_tis_only(): + """Tests apply_rollout_correction with only TIS enabled.""" + device = "cpu" + + loss = torch.tensor([[1.0, 1.0, 1.0]], device=device) + # Token log ratios: [0.5, 0.5, 0.5] -> token ratios = [1.6487, 1.6487, 1.6487] + old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + rollout_logprobs = 
torch.tensor([[-1.5, -1.5, -1.5]], device=device) + loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device) + + config = DictConfig( + { + "tis_ratio_type": "token", + "token_tis_ratio_cap_high": 2.0, + "rejection_mask_type": "null", + } + ) + + corrected_loss = apply_rollout_correction(loss, old_log_probs, rollout_logprobs, loss_mask, config) + + # Expected: loss * 1.6487 (no capping needed) + expected = loss * torch.exp(torch.tensor(0.5)) + torch.testing.assert_close(corrected_loss, expected, rtol=1e-3, atol=1e-4) + + +def test_apply_rollout_correction_rejection_only(): + """Tests apply_rollout_correction with only rejection mask enabled.""" + device = "cpu" + + loss = torch.tensor([[1.0, 2.0, 3.0]], device=device) + # Token log ratios: [0.0, 0.0, 0.0] -> geometric mean = 1.0 + old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + rollout_logprobs = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device) + + config = DictConfig( + { + "tis_ratio_type": "null", + "rejection_mask_type": "geometric", + "geo_rejection_mask_ratio_cap_high": 1.1, + "geo_rejection_mask_ratio_cap_low": 0.9, + } + ) + + corrected_loss = apply_rollout_correction(loss, old_log_probs, rollout_logprobs, loss_mask, config) + + # Geometric mean = 1.0, within bounds, so loss unchanged + torch.testing.assert_close(corrected_loss, loss, rtol=1e-3, atol=1e-4) + + +def test_apply_rollout_correction_both_enabled(): + """Tests apply_rollout_correction with both TIS and rejection mask enabled.""" + device = "cpu" + + loss = torch.tensor([[1.0, 1.0, 1.0]], device=device) + # Token log ratios: [0.1, 0.1, 0.1] -> token ratios ≈ [1.105, 1.105, 1.105] + # Geometric mean ≈ 1.105 + old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + rollout_logprobs = torch.tensor([[-1.1, -1.1, -1.1]], device=device) + loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device) + + config = DictConfig( + { + "tis_ratio_type": "token", + "token_tis_ratio_cap_high": 2.0, + "rejection_mask_type": "geometric", + "geo_rejection_mask_ratio_cap_high": 1.2, + "geo_rejection_mask_ratio_cap_low": 0.8, + } + ) + + corrected_loss = apply_rollout_correction(loss, old_log_probs, rollout_logprobs, loss_mask, config) + + # TIS ratio ≈ 1.105, geometric mean ≈ 1.105 (within bounds, mask=1) + # Expected: loss * 1.105 * 1.0 = loss * 1.105 + expected = loss * torch.exp(torch.tensor(0.1)) + torch.testing.assert_close(corrected_loss, expected, rtol=1e-3, atol=1e-4) + + +def test_apply_rollout_correction_rejection_zeros_loss(): + """Tests that rejection mask can zero out the loss entirely.""" + device = "cpu" + + loss = torch.tensor([[1.0, 2.0, 3.0]], device=device) + # Token log ratios: [1.0, 1.0, 1.0] -> geometric mean = exp(1.0) ≈ 2.718 + old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + rollout_logprobs = torch.tensor([[-2.0, -2.0, -2.0]], device=device) + loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device) + + config = DictConfig( + { + "tis_ratio_type": "null", + "rejection_mask_type": "geometric", + "geo_rejection_mask_ratio_cap_high": 1.1, + "geo_rejection_mask_ratio_cap_low": 0.9, + } + ) + + corrected_loss = apply_rollout_correction(loss, old_log_probs, rollout_logprobs, loss_mask, config) + + # Geometric mean ≈ 2.718, outside [0.9, 1.1], so loss should be zeroed + expected = torch.tensor([[0.0, 0.0, 0.0]], device=device) + torch.testing.assert_close(corrected_loss, expected, rtol=1e-3, atol=1e-4) + + +def test_ppo_policy_loss_with_rollout_correction(): 
+ """Integration test for PPO policy loss with rollout correction enabled.""" + device = "cpu" + + advantages = torch.tensor([[1.0, -1.0, 0.5]], device=device) + old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + log_probs = torch.tensor([[-1.1, -0.9, -1.0]], device=device) + rollout_logprobs = torch.tensor([[-1.05, -1.05, -1.05]], device=device) + loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device) + + config = DictConfig( + { + "eps_clip_low": 0.2, + "eps_clip_high": 0.2, + "policy_loss_type": "regular", + "loss_reduction": "token_mean", + "max_seq_len": 4, + "use_tis": False, + "rollout_correction": { + "tis_ratio_type": "token", + "token_tis_ratio_cap_high": 2.0, + "rejection_mask_type": "null", + }, + } + ) + + loss_fn = PolicyLossRegistry.get("regular") + + # Loss with rollout correction + loss_with_correction, _ = loss_fn( + log_probs=log_probs, + old_log_probs=old_log_probs, + advantages=advantages, + config=config, + loss_mask=loss_mask, + rollout_logprobs=rollout_logprobs, + ) + + # Loss without rollout correction + config_no_correction = DictConfig( + { + "eps_clip_low": 0.2, + "eps_clip_high": 0.2, + "policy_loss_type": "regular", + "loss_reduction": "token_mean", + "max_seq_len": 4, + "use_tis": False, + "rollout_correction": { + "tis_ratio_type": "null", + "rejection_mask_type": "null", + }, + } + ) + + loss_without_correction, _ = loss_fn( + log_probs=log_probs, + old_log_probs=old_log_probs, + advantages=advantages, + config=config_no_correction, + loss_mask=loss_mask, + rollout_logprobs=rollout_logprobs, + ) + + # TIS correction should modify the loss + assert not torch.allclose(loss_with_correction, loss_without_correction, rtol=1e-3), ( + f"Rollout correction should change the loss: " + f"with={loss_with_correction:.6f} vs without={loss_without_correction:.6f}" + ) + + +def test_compute_tis_ratio_invalid_type(): + """Tests that compute_tis_ratio raises error for invalid tis_ratio_type.""" + device = "cpu" + + old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + rollout_logprobs = torch.tensor([[-1.5, -1.5, -1.5]], device=device) + loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device) + + config = DictConfig({"tis_ratio_type": "invalid"}) + + with pytest.raises(ValueError, match="Unknown tis_ratio_type"): + compute_tis_ratio(old_log_probs, rollout_logprobs, loss_mask, "invalid", config) + + +def test_compute_rejection_mask_invalid_type(): + """Tests that compute_rejection_mask raises error for invalid rejection_mask_type.""" + device = "cpu" + + old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + rollout_logprobs = torch.tensor([[-1.5, -1.5, -1.5]], device=device) + loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device) + + config = DictConfig({"rejection_mask_type": "invalid"}) + + with pytest.raises(ValueError, match="Unknown rejection_mask_type"): + compute_rejection_mask(old_log_probs, rollout_logprobs, loss_mask, "invalid", config) From 1520157103b98282551b2823802eb766b52a5fdf Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Wed, 7 Jan 2026 22:10:13 +0000 Subject: [PATCH 04/23] x --- .../examples/megatron/run_megatron_dapo_qwen3_4b_lora.sh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_4b_lora.sh b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_4b_lora.sh index f690305e6..9486a9d9f 100644 --- a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_4b_lora.sh +++ 
b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_4b_lora.sh
@@ -86,8 +86,7 @@ uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \
     trainer.policy.megatron_config.expert_model_parallel_size=$MEGATRON_EP \
     trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \
     trainer.algorithm.rollout_correction.tis_ratio_type=$TIS_TYPE \
-    trainer.algorithm.rollout_correction.token_tis_ratio_cap_high=2.0 \
-    trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \
+    trainer.algorithm.rollout_correction.token_tis_ratio_cap_high=$TIS_IMP_RATIO_CAP \
     trainer.policy.model.lora.rank=$LORA_RANK \
     trainer.policy.model.lora.alpha=$LORA_ALPHA \
     trainer.policy.model.lora.init_method=$LORA_A_INIT_METHOD \

From 45a59c2ae3736eb1644751c5f7ef9d6d3d2ac8dc Mon Sep 17 00:00:00 2001
From: Eric Tang
Date: Wed, 7 Jan 2026 23:31:39 +0000
Subject: [PATCH 05/23] x

---
 .../skyrl_train/config/ppo_base_config.yaml |  6 +-
 skyrl-train/skyrl_train/utils/ppo_utils.py  | 58 ++++++++++++---
 .../tests/cpu/algorithms/test_losses.py     | 74 +++++++++++++++----
 3 files changed, 108 insertions(+), 30 deletions(-)

diff --git a/skyrl-train/skyrl_train/config/ppo_base_config.yaml b/skyrl-train/skyrl_train/config/ppo_base_config.yaml
index 3d501f00e..2234aca5c 100644
--- a/skyrl-train/skyrl_train/config/ppo_base_config.yaml
+++ b/skyrl-train/skyrl_train/config/ppo_base_config.yaml
@@ -150,9 +150,9 @@ trainer:
       sequence_rejection_mask_ratio_cap_high: 2.0
       sequence_rejection_mask_ratio_cap_low: 0.5
 
-      # masks out sequences with any token having importance ratio below the cap
-      sequence_mask_low: 1e-4
-      sequence_mask_high: 100
+      # masks out sequences with any token having importance ratio far outside an acceptable range
+      outlier_token_is_threshold_low: 1e-4
+      outlier_token_is_threshold_high: 100
 
     # SAPO parameters (only used when policy_loss_type: "sapo") (https://arxiv.org/pdf/2511.20347)
     sapo:
diff --git a/skyrl-train/skyrl_train/utils/ppo_utils.py b/skyrl-train/skyrl_train/utils/ppo_utils.py
index 39da947e9..d4c77faaa 100644
--- a/skyrl-train/skyrl_train/utils/ppo_utils.py
+++ b/skyrl-train/skyrl_train/utils/ppo_utils.py
@@ -592,6 +592,44 @@ def compute_tis_ratio(
         raise ValueError(f"Unknown tis_ratio_type: {tis_ratio_type}")
 
 
+def compute_outlier_token_mask(
+    old_log_probs: torch.Tensor,
+    rollout_logprobs: torch.Tensor,
+    loss_mask: torch.Tensor,
+    rollout_corr: DictConfig,
+) -> torch.Tensor:
+    """
+    Compute outlier token mask that rejects sequences with any token having
+    importance ratio outside acceptable bounds.
+
+    This is applied independently of TIS ratio type or rejection mask type,
+    whenever rollout correction is enabled.
+
+    Args:
+        old_log_probs: Log probabilities from the old policy (before update).
+        rollout_logprobs: Log probabilities from the rollout policy.
+        loss_mask: Mask indicating valid tokens.
+        rollout_corr: Rollout correction config containing threshold values.
+
+    Returns:
+        Outlier token mask tensor (float) to multiply with the loss.
+        Shape is [batch, 1] (sequence-level mask).
+ """ + # Compute token-level importance ratio + token_tis_log_ratio = old_log_probs - rollout_logprobs + token_tis_ratio = _safe_exp_delta(token_tis_log_ratio, clip=20.0, out_dtype=old_log_probs.dtype) + + # Check per-token bounds + token_mask_low = rollout_corr.outlier_token_is_threshold_low + token_mask_high = rollout_corr.outlier_token_is_threshold_high + token_in_bounds = (token_tis_ratio >= token_mask_low) & (token_tis_ratio <= token_mask_high) + + # A sequence is valid if all tokens are in bounds (considering only masked positions) + all_tokens_valid = (token_in_bounds | (loss_mask == 0)).all(dim=-1, keepdim=True) + + return all_tokens_valid.float() + + def compute_rejection_mask( old_log_probs: torch.Tensor, rollout_logprobs: torch.Tensor, @@ -619,7 +657,6 @@ def compute_rejection_mask( """ # Compute token-level importance ratio token_tis_log_ratio = old_log_probs - rollout_logprobs - token_tis_ratio = _safe_exp_delta(token_tis_log_ratio, clip=20.0, out_dtype=old_log_probs.dtype) if rejection_mask_type == "geometric": # Compute geometric mean of importance ratios per sequence @@ -636,15 +673,8 @@ def compute_rejection_mask( seq_tis_ratio = _safe_exp_delta(seq_tis_log_ratio, clip=20.0, out_dtype=old_log_probs.dtype) seq_cap_high = rollout_corr.sequence_rejection_mask_ratio_cap_high seq_cap_low = rollout_corr.sequence_rejection_mask_ratio_cap_low - # Also check per-token bounds - token_mask_low = rollout_corr.sequence_mask_low - token_mask_high = rollout_corr.sequence_mask_high - token_in_bounds = (token_tis_ratio >= token_mask_low) & (token_tis_ratio <= token_mask_high) - # A sequence is valid if all tokens are in bounds (considering only masked positions) - all_tokens_valid = (token_in_bounds | (loss_mask == 0)).all(dim=-1, keepdim=True) seq_in_bounds = (seq_tis_ratio >= seq_cap_low) & (seq_tis_ratio <= seq_cap_high) - seq_rejection_mask = seq_in_bounds & all_tokens_valid - return seq_rejection_mask.float() + return seq_in_bounds.float() else: raise ValueError(f"Unknown rejection_mask_type: {rejection_mask_type}") @@ -657,9 +687,10 @@ def apply_rollout_correction( rollout_corr: DictConfig, ) -> torch.Tensor: """ - Apply rollout correction to the loss using TIS ratio and rejection mask. + Apply rollout correction to the loss using TIS ratio, rejection mask, and outlier token mask. - This is a convenience function that combines compute_tis_ratio and compute_rejection_mask. + This is a convenience function that combines compute_tis_ratio, compute_rejection_mask, + and compute_outlier_token_mask. Args: loss: The loss tensor to correct. 
@@ -683,6 +714,11 @@ def apply_rollout_correction( if not apply_tis and not apply_rejection: return loss + # Apply outlier token mask whenever rollout correction is enabled + # This rejects sequences with any token having importance ratio outside acceptable bounds + outlier_mask = compute_outlier_token_mask(old_log_probs, rollout_logprobs, loss_mask, rollout_corr) + loss = loss * outlier_mask + # Apply TIS ratio if enabled if apply_tis: tis_ratio = compute_tis_ratio(old_log_probs, rollout_logprobs, loss_mask, tis_ratio_type, rollout_corr) diff --git a/skyrl-train/tests/cpu/algorithms/test_losses.py b/skyrl-train/tests/cpu/algorithms/test_losses.py index 3db96b702..d82b62af8 100644 --- a/skyrl-train/tests/cpu/algorithms/test_losses.py +++ b/skyrl-train/tests/cpu/algorithms/test_losses.py @@ -13,6 +13,7 @@ masked_mean, compute_tis_ratio, compute_rejection_mask, + compute_outlier_token_mask, apply_rollout_correction, ) @@ -797,14 +798,12 @@ def test_compute_rejection_mask_sequence(): "rejection_mask_type": "sequence", "sequence_rejection_mask_ratio_cap_high": 2.0, "sequence_rejection_mask_ratio_cap_low": 0.5, - "sequence_mask_low": 1e-4, - "sequence_mask_high": 100.0, } ) rejection_mask = compute_rejection_mask(old_log_probs, rollout_logprobs, loss_mask, "sequence", config) - # Sequence ratio ≈ 1.35, which is within [0.5, 2.0], and all token ratios are in bounds + # Sequence ratio ≈ 1.35, which is within [0.5, 2.0] # Shape is [batch, 1] for sequence-level mask expected = torch.tensor([[1.0]], device=device) torch.testing.assert_close(rejection_mask, expected, rtol=1e-3, atol=1e-4) @@ -824,8 +823,6 @@ def test_compute_rejection_mask_sequence_rejects_by_seq_ratio(): "rejection_mask_type": "sequence", "sequence_rejection_mask_ratio_cap_high": 2.0, "sequence_rejection_mask_ratio_cap_low": 0.5, - "sequence_mask_low": 1e-4, - "sequence_mask_high": 100.0, } ) @@ -837,33 +834,78 @@ def test_compute_rejection_mask_sequence_rejects_by_seq_ratio(): torch.testing.assert_close(rejection_mask, expected, rtol=1e-3, atol=1e-4) -def test_compute_rejection_mask_sequence_rejects_by_token_bounds(): - """Tests sequence rejection mask rejects when a token ratio is out of bounds.""" +def test_compute_outlier_token_mask_rejects_by_token_bounds(): + """Tests outlier token mask rejects when a token ratio is out of bounds.""" device = "cpu" # Token log ratios: [0.0, 0.0, 5.0] -> token ratios = [1.0, 1.0, 148.4] - # Sequence ratio = exp(5.0) ≈ 148.4 (in bounds if cap is high enough) - # But third token ratio 148.4 > 100.0, so should reject + # Third token ratio 148.4 > 100.0, so should reject old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device) rollout_logprobs = torch.tensor([[-1.0, -1.0, -6.0]], device=device) loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device) config = DictConfig( { - "rejection_mask_type": "sequence", - "sequence_rejection_mask_ratio_cap_high": 200.0, # High enough to not reject by seq ratio - "sequence_rejection_mask_ratio_cap_low": 0.01, - "sequence_mask_low": 1e-4, - "sequence_mask_high": 100.0, # This should cause rejection + "outlier_token_is_threshold_low": 1e-4, + "outlier_token_is_threshold_high": 100.0, # This should cause rejection } ) - rejection_mask = compute_rejection_mask(old_log_probs, rollout_logprobs, loss_mask, "sequence", config) + outlier_mask = compute_outlier_token_mask(old_log_probs, rollout_logprobs, loss_mask, config) # Token ratio 148.4 > 100.0, so mask should be 0.0 # Shape is [batch, 1] for sequence-level mask expected = torch.tensor([[0.0]], 
device=device) - torch.testing.assert_close(rejection_mask, expected, rtol=1e-3, atol=1e-4) + torch.testing.assert_close(outlier_mask, expected, rtol=1e-3, atol=1e-4) + + +def test_compute_outlier_token_mask_accepts_in_bounds(): + """Tests outlier token mask accepts when all token ratios are in bounds.""" + device = "cpu" + + # Token log ratios: [0.5, -0.5, 0.0] -> token ratios = [1.65, 0.61, 1.0] + # All token ratios within [1e-4, 100.0], so should accept + old_log_probs = torch.tensor([[-1.0, -1.5, -1.0]], device=device) + rollout_logprobs = torch.tensor([[-1.5, -1.0, -1.0]], device=device) + loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device) + + config = DictConfig( + { + "outlier_token_is_threshold_low": 1e-4, + "outlier_token_is_threshold_high": 100.0, + } + ) + + outlier_mask = compute_outlier_token_mask(old_log_probs, rollout_logprobs, loss_mask, config) + + # All token ratios in bounds, so mask should be 1.0 + # Shape is [batch, 1] for sequence-level mask + expected = torch.tensor([[1.0]], device=device) + torch.testing.assert_close(outlier_mask, expected, rtol=1e-3, atol=1e-4) + + +def test_compute_outlier_token_mask_respects_loss_mask(): + """Tests outlier token mask ignores out-of-bounds tokens that are masked.""" + device = "cpu" + + # Token log ratios: [0.0, 0.0, 5.0] -> token ratios = [1.0, 1.0, 148.4] + # Third token ratio 148.4 > 100.0, but it's masked, so should accept + old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + rollout_logprobs = torch.tensor([[-1.0, -1.0, -6.0]], device=device) + loss_mask = torch.tensor([[1.0, 1.0, 0.0]], device=device) # Third token masked + + config = DictConfig( + { + "outlier_token_is_threshold_low": 1e-4, + "outlier_token_is_threshold_high": 100.0, + } + ) + + outlier_mask = compute_outlier_token_mask(old_log_probs, rollout_logprobs, loss_mask, config) + + # Third token is masked, so even though ratio is out of bounds, sequence should be accepted + expected = torch.tensor([[1.0]], device=device) + torch.testing.assert_close(outlier_mask, expected, rtol=1e-3, atol=1e-4) def test_apply_rollout_correction_null_configs(): From ce01bb26503c773710b4f6476736c469d342c436 Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Thu, 8 Jan 2026 00:00:35 +0000 Subject: [PATCH 06/23] fix tests and add rollout correction to other loss types --- skyrl-train/docs/examples/flash_rl.rst | 1 - skyrl-train/skyrl_train/utils/ppo_utils.py | 35 +++++++++++++++---- skyrl-train/skyrl_train/utils/utils.py | 4 --- .../tests/cpu/algorithms/test_losses.py | 29 +++++++++++++++ 4 files changed, 58 insertions(+), 11 deletions(-) diff --git a/skyrl-train/docs/examples/flash_rl.rst b/skyrl-train/docs/examples/flash_rl.rst index b75292dad..acec33b38 100644 --- a/skyrl-train/docs/examples/flash_rl.rst +++ b/skyrl-train/docs/examples/flash_rl.rst @@ -67,7 +67,6 @@ We highlight some important training parameters configured for FlashRL from our ... trainer.algorithm.rollout_correction.tis_ratio_type=$TIS_TYPE \ trainer.algorithm.rollout_correction.token_tis_ratio_cap_high=$TIS_IMP_RATIO_CAP \ - trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \ generator.sampling_params.logprobs=0 \ ... diff --git a/skyrl-train/skyrl_train/utils/ppo_utils.py b/skyrl-train/skyrl_train/utils/ppo_utils.py index d4c77faaa..74308e0d1 100644 --- a/skyrl-train/skyrl_train/utils/ppo_utils.py +++ b/skyrl-train/skyrl_train/utils/ppo_utils.py @@ -652,8 +652,6 @@ def compute_rejection_mask( Returns: Rejection mask tensor (float) to multiply with the loss. 
- - Reference: https://github.com/szrlee/verl/blob/yingru/rollout_correction/docs/advance/rollout_corr_math.md """ # Compute token-level importance ratio token_tis_log_ratio = old_log_probs - rollout_logprobs @@ -701,6 +699,10 @@ def apply_rollout_correction( Returns: Corrected loss tensor. + + References: + - https://github.com/szrlee/verl/blob/yingru/rollout_correction/docs/advance/rollout_corr_math.md + - https://fengyao.notion.site/off-policy-rl """ tis_ratio_type = rollout_corr.tis_ratio_type rejection_mask_type = rollout_corr.rejection_mask_type @@ -773,10 +775,9 @@ def ppo_policy_loss( tis_imp_ratio = torch.clamp(tis_imp_ratio, max=config.tis_imp_ratio_cap) loss = loss * tis_imp_ratio - # New rollout correction support - rollout_corr = config.get("rollout_correction", None) - if rollout_corr is not None and rollout_logprobs is not None: - loss = apply_rollout_correction(loss, old_log_probs, rollout_logprobs, loss_mask, rollout_corr) + # apply rollout correction + rollout_corr = config.rollout_correction + loss = apply_rollout_correction(loss, old_log_probs, rollout_logprobs, loss_mask, rollout_corr) loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len) return loss, clip_ratio @@ -841,6 +842,10 @@ def gate_function(x, tau): # compute policy gradient loss loss = -gates * advantages + # apply rollout correction + rollout_corr = config.rollout_correction + loss = apply_rollout_correction(loss, old_log_probs, rollout_logprobs, loss_mask, rollout_corr) + # for SAPO, we need to aggregate the loss at the sequence level (seq-mean-token-mean) loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len) @@ -902,6 +907,11 @@ def gspo_policy_loss( surr2 = ratio.clamp(1 - config.eps_clip_low, 1 + config.eps_clip_high) * advantages loss = -torch.min(surr1, surr2) + # apply rollout correction + rollout_corr = config.rollout_correction + if rollout_corr is not None and rollout_logprobs is not None: + loss = apply_rollout_correction(loss, old_log_probs, rollout_logprobs, loss_mask, rollout_corr) + # Compute clipping ratio for monitoring clip_ratio = masked_mean((-surr2 > -surr1).float(), loss_mask).mean().detach().item() @@ -935,6 +945,10 @@ def compute_policy_loss_cispo( is_clipped = (ratio < 1 - config.cispo.cispo_eps_clip_low) | (ratio > 1 + config.cispo.cispo_eps_clip_high) clip_ratio = masked_mean(is_clipped.float(), loss_mask).mean().detach().item() + # apply rollout correction + rollout_corr = config.rollout_correction + loss = apply_rollout_correction(loss, old_log_probs, rollout_logprobs, loss_mask, rollout_corr) + loss = reduce_loss(loss, loss_mask, config.loss_reduction, config.max_seq_len) return loss, clip_ratio @@ -996,6 +1010,11 @@ def compute_policy_loss_clip_cov( # Apply correction mask to losses pg_losses = torch.maximum(pg_losses1, pg_losses2) * corr + + # apply rollout correction + rollout_corr = config.rollout_correction + pg_losses = apply_rollout_correction(pg_losses, old_log_probs, rollout_logprobs, loss_mask, rollout_corr) + pg_loss = reduce_loss( loss=pg_losses, loss_mask=loss_mask, @@ -1055,6 +1074,10 @@ def compute_policy_loss_kl_cov( large_cov_idxs % advantages.shape[1], ] + # apply rollout correction + rollout_corr = config.rollout_correction + pg_losses = apply_rollout_correction(pg_losses, old_log_probs, rollout_logprobs, loss_mask, rollout_corr) + pg_loss = reduce_loss( loss=pg_losses, loss_mask=loss_mask, diff --git a/skyrl-train/skyrl_train/utils/utils.py b/skyrl-train/skyrl_train/utils/utils.py index a87bf046f..7051d8671 100644 --- 
a/skyrl-train/skyrl_train/utils/utils.py +++ b/skyrl-train/skyrl_train/utils/utils.py @@ -301,10 +301,6 @@ def validate_cfg(cfg: DictConfig): if cfg.generator.backend == "sglang": raise NotImplementedError("`trainer.algorithm.use_tis` doesn't support Sglang backend, please use vLLM") - assert cfg.trainer.algorithm.policy_loss_type in [ - "regular", - "dual_clip", - ], "TIS is only implemented for regular and dual_clip policy loss types" # New rollout_correction validation rollout_corr = cfg.trainer.algorithm.get("rollout_correction", None) diff --git a/skyrl-train/tests/cpu/algorithms/test_losses.py b/skyrl-train/tests/cpu/algorithms/test_losses.py index d82b62af8..f026bc5aa 100644 --- a/skyrl-train/tests/cpu/algorithms/test_losses.py +++ b/skyrl-train/tests/cpu/algorithms/test_losses.py @@ -17,6 +17,12 @@ apply_rollout_correction, ) +NULL_ROLLOUT_CORR = { + "tis_ratio_type": "null", + "rejection_mask_type": "null", + "outlier_token_is_threshold_low": 1e-4, + "outlier_token_is_threshold_high": 100.0, +} # Adapted a good test from NeMO-RL def test_policy_loss_dual_clip(): @@ -41,6 +47,7 @@ def test_policy_loss_dual_clip(): "loss_reduction": "token_mean", "max_seq_len": 4, "use_tis": False, + "rollout_correction": NULL_ROLLOUT_CORR, } ) @@ -94,6 +101,7 @@ def test_policy_loss_cispo(): "loss_reduction": "token_mean", "max_seq_len": 4, "use_tis": False, + "rollout_correction": NULL_ROLLOUT_CORR, } ) @@ -172,6 +180,7 @@ def test_policy_loss_reduction_modes(): "loss_reduction": "token_mean", "max_seq_len": 4, "use_tis": False, + "rollout_correction": NULL_ROLLOUT_CORR, } ) @@ -184,6 +193,7 @@ def test_policy_loss_reduction_modes(): "loss_reduction": "sequence_mean", "max_seq_len": 4, "use_tis": False, + "rollout_correction": NULL_ROLLOUT_CORR, } ) @@ -257,6 +267,7 @@ def test_policy_loss_reduction_edge_cases(): "loss_reduction": "token_mean", "max_seq_len": 4, "use_tis": False, + "rollout_correction": NULL_ROLLOUT_CORR, } ) @@ -269,6 +280,7 @@ def test_policy_loss_reduction_edge_cases(): "loss_reduction": "sequence_mean", "max_seq_len": 4, "use_tis": False, + "rollout_correction": NULL_ROLLOUT_CORR, } ) @@ -355,6 +367,7 @@ def test_gspo_importance_sampling_levels(): "loss_reduction": "token_mean", "max_seq_len": 4, "use_tis": False, + "rollout_correction": NULL_ROLLOUT_CORR, } ) ppo_loss_fn = PolicyLossRegistry.get("regular") @@ -370,6 +383,7 @@ def test_gspo_importance_sampling_levels(): "loss_reduction": "sequence_mean", # GSPO recommended reduction "max_seq_len": 4, "use_tis": False, + "rollout_correction": NULL_ROLLOUT_CORR, } ) gspo_loss_fn = PolicyLossRegistry.get("gspo") @@ -476,6 +490,7 @@ def test_clip_cov_policy_loss(): "loss_reduction": "token_mean", "max_seq_len": 4, "clip_cov": {"clip_ratio": 0.5, "clip_cov_lb": -5.0, "clip_cov_ub": 5.0}, # Large ratio for testing + "rollout_correction": NULL_ROLLOUT_CORR, } ) @@ -498,6 +513,7 @@ def test_clip_cov_policy_loss(): "loss_reduction": "token_mean", "max_seq_len": 4, "use_tis": False, + "rollout_correction": NULL_ROLLOUT_CORR, } ) @@ -538,6 +554,7 @@ def test_kl_cov_policy_loss(): "loss_reduction": "token_mean", "max_seq_len": 4, "kl_cov": {"kl_cov_frac": 0.5, "ppo_kl_coef": 1.0}, # Apply KL to 50% of tokens + "rollout_correction": NULL_ROLLOUT_CORR, } ) @@ -560,6 +577,7 @@ def test_kl_cov_policy_loss(): "loss_reduction": "token_mean", "max_seq_len": 4, "use_tis": False, + "rollout_correction": NULL_ROLLOUT_CORR, } ) @@ -592,6 +610,7 @@ def test_sapo_policy_loss_basic(): "loss_reduction": "sequence_mean", "max_seq_len": 4, "sapo": 
{"tau_pos": 1.0, "tau_neg": 2.0}, + "rollout_correction": NULL_ROLLOUT_CORR, } ) @@ -945,6 +964,8 @@ def test_apply_rollout_correction_tis_only(): "tis_ratio_type": "token", "token_tis_ratio_cap_high": 2.0, "rejection_mask_type": "null", + "outlier_token_is_threshold_low": 1e-4, + "outlier_token_is_threshold_high": 100.0, } ) @@ -971,6 +992,8 @@ def test_apply_rollout_correction_rejection_only(): "rejection_mask_type": "geometric", "geo_rejection_mask_ratio_cap_high": 1.1, "geo_rejection_mask_ratio_cap_low": 0.9, + "outlier_token_is_threshold_low": 1e-4, + "outlier_token_is_threshold_high": 100.0, } ) @@ -998,6 +1021,8 @@ def test_apply_rollout_correction_both_enabled(): "rejection_mask_type": "geometric", "geo_rejection_mask_ratio_cap_high": 1.2, "geo_rejection_mask_ratio_cap_low": 0.8, + "outlier_token_is_threshold_low": 1e-4, + "outlier_token_is_threshold_high": 100.0, } ) @@ -1025,6 +1050,8 @@ def test_apply_rollout_correction_rejection_zeros_loss(): "rejection_mask_type": "geometric", "geo_rejection_mask_ratio_cap_high": 1.1, "geo_rejection_mask_ratio_cap_low": 0.9, + "outlier_token_is_threshold_low": 1e-4, + "outlier_token_is_threshold_high": 100.0, } ) @@ -1057,6 +1084,8 @@ def test_ppo_policy_loss_with_rollout_correction(): "tis_ratio_type": "token", "token_tis_ratio_cap_high": 2.0, "rejection_mask_type": "null", + "outlier_token_is_threshold_low": 1e-4, + "outlier_token_is_threshold_high": 100.0, }, } ) From abac80077dcff9d8559cfcefd12d9aee3187e43a Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Thu, 8 Jan 2026 01:41:32 +0000 Subject: [PATCH 07/23] add metrics --- skyrl-train/docs/configuration/config.rst | 4 +- .../examples/gsm8k/run_gsm8k_tis_geo_rs.sh | 0 skyrl-train/skyrl_train/utils/ppo_utils.py | 127 +++++++++++++----- .../megatron/megatron_model_wrapper.py | 4 +- skyrl-train/skyrl_train/workers/worker.py | 4 +- .../tests/cpu/algorithms/test_losses.py | 94 ++++++++++--- 6 files changed, 174 insertions(+), 59 deletions(-) create mode 100644 skyrl-train/examples/gsm8k/run_gsm8k_tis_geo_rs.sh diff --git a/skyrl-train/docs/configuration/config.rst b/skyrl-train/docs/configuration/config.rst index 1869c1c8e..c98eff625 100644 --- a/skyrl-train/docs/configuration/config.rst +++ b/skyrl-train/docs/configuration/config.rst @@ -502,7 +502,7 @@ It can be helpful to understand the final loss formulation to see how the differ advantages: torch.Tensor, config: DictConfig, # trainer.algorithm config loss_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, LossMetrics]: ratio = (log_probs - old_log_probs).exp() surr1 = ratio * advantages @@ -515,7 +515,7 @@ It can be helpful to understand the final loss formulation to see how the differ clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1) loss = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1) loss = reduce_loss(loss, loss_mask, config.loss_reduction) - return loss, clip_ratio + return loss, LossMetrics(clip_ratio=clip_ratio) Generator Configuration diff --git a/skyrl-train/examples/gsm8k/run_gsm8k_tis_geo_rs.sh b/skyrl-train/examples/gsm8k/run_gsm8k_tis_geo_rs.sh new file mode 100644 index 000000000..e69de29bb diff --git a/skyrl-train/skyrl_train/utils/ppo_utils.py b/skyrl-train/skyrl_train/utils/ppo_utils.py index 74308e0d1..a23e7fded 100644 --- a/skyrl-train/skyrl_train/utils/ppo_utils.py +++ b/skyrl-train/skyrl_train/utils/ppo_utils.py @@ -19,7 +19,7 @@ from collections import defaultdict from enum import StrEnum from functools import wraps -from typing import Callable, List, Literal, 
Optional, Tuple, Union +from typing import Callable, List, Literal, Optional, Tuple, Union, TypedDict import numpy as np import ray @@ -196,6 +196,10 @@ def ppo_critic_loss( return 0.5 * loss, clipfrac +class LossMetrics(TypedDict): + clip_ratio: float + + # Shared registry actor class for both policy loss and advantage estimator registries @ray.remote class RegistryActor: @@ -579,15 +583,24 @@ def compute_tis_ratio( token_tis_log_ratio = old_log_probs - rollout_logprobs token_tis_ratio = _safe_exp_delta(token_tis_log_ratio, clip=20.0, out_dtype=old_log_probs.dtype) + metrics = {} if tis_ratio_type == "token": token_tis_ratio_cap = rollout_corr.token_tis_ratio_cap_high - return torch.clamp(token_tis_ratio, max=token_tis_ratio_cap) + # Compute proportion of tokens capped + tokens_capped = (token_tis_ratio > token_tis_ratio_cap) & (loss_mask > 0) + total_tokens = (loss_mask > 0).sum() + metrics["tis_token_capped_frac"] = (tokens_capped.sum() / total_tokens.clamp(min=1)).detach().item() + return torch.clamp(token_tis_ratio, max=token_tis_ratio_cap), metrics elif tis_ratio_type == "sequence": # Compute sequence-level importance ratio as product of token ratios (sum of log ratios) seq_tis_log_ratio = (token_tis_log_ratio * loss_mask).sum(dim=-1, keepdim=True) seq_tis_ratio = _safe_exp_delta(seq_tis_log_ratio, clip=20.0, out_dtype=old_log_probs.dtype) seq_tis_ratio_cap = rollout_corr.sequence_tis_ratio_cap_high - return torch.clamp(seq_tis_ratio, max=seq_tis_ratio_cap) + # Compute proportion of sequences capped + num_sequences = seq_tis_ratio.shape[0] + seqs_capped = (seq_tis_ratio > seq_tis_ratio_cap).sum() + metrics["tis_seq_capped_frac"] = (seqs_capped / num_sequences).detach().item() + return torch.clamp(seq_tis_ratio, max=seq_tis_ratio_cap), metrics else: raise ValueError(f"Unknown tis_ratio_type: {tis_ratio_type}") @@ -612,9 +625,11 @@ def compute_outlier_token_mask( rollout_corr: Rollout correction config containing threshold values. Returns: - Outlier token mask tensor (float) to multiply with the loss. - Shape is [batch, 1] (sequence-level mask). 
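For reference, a self-contained sketch of the token- vs sequence-level capping implemented above (illustrative only; the cap argument stands in for token_tis_ratio_cap_high / sequence_tis_ratio_cap_high):

import torch

def tis_ratio_sketch(old_logp, rollout_logp, loss_mask, level, cap):
    log_ratio = old_logp - rollout_logp  # per-token log(pi_old / pi_rollout)
    if level == "token":
        # Clamp each token ratio from above; shape [batch, response_len]
        return log_ratio.exp().clamp(max=cap)
    # "sequence": product of token ratios over valid tokens, clamped; shape [batch, 1]
    return (log_ratio * loss_mask).sum(dim=-1, keepdim=True).exp().clamp(max=cap)

old = torch.tensor([[-1.0, -1.5, -0.3]])
roll = torch.tensor([[-1.5, -1.0, -1.0]])
mask = torch.ones(1, 3)
print(tis_ratio_sketch(old, roll, mask, "token", 2.0))     # ~[[1.6487, 0.6065, 2.0]], third token capped
print(tis_ratio_sketch(old, roll, mask, "sequence", 5.0))  # ~[[2.0138]], exp(0.7) stays below the cap

The capped-fraction metrics added above simply count how many valid tokens (or sequences) hit the respective cap.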
+ Tuple of (outlier_mask, metrics): + - outlier_mask: Tensor (float) to multiply with the loss, shape [batch, 1] + - metrics: Dict with masking statistics """ + metrics = {} # Compute token-level importance ratio token_tis_log_ratio = old_log_probs - rollout_logprobs token_tis_ratio = _safe_exp_delta(token_tis_log_ratio, clip=20.0, out_dtype=old_log_probs.dtype) @@ -622,12 +637,25 @@ def compute_outlier_token_mask( # Check per-token bounds token_mask_low = rollout_corr.outlier_token_is_threshold_low token_mask_high = rollout_corr.outlier_token_is_threshold_high - token_in_bounds = (token_tis_ratio >= token_mask_low) & (token_tis_ratio <= token_mask_high) + token_over_high = (token_tis_ratio > token_mask_high) & (loss_mask > 0) + token_under_low = (token_tis_ratio < token_mask_low) & (loss_mask > 0) + token_in_bounds = ~token_over_high & ~token_under_low # A sequence is valid if all tokens are in bounds (considering only masked positions) all_tokens_valid = (token_in_bounds | (loss_mask == 0)).all(dim=-1, keepdim=True) - return all_tokens_valid.float() + # Compute metrics + num_sequences = float(all_tokens_valid.shape[0]) + # Sequence has any token over high threshold + seq_has_over_high = token_over_high.any(dim=-1) + # Sequence has any token under low threshold + seq_has_under_low = token_under_low.any(dim=-1) + + metrics["outlier_seq_masked_frac"] = ((~all_tokens_valid.squeeze(-1)).sum() / num_sequences).detach().item() + metrics["outlier_seq_over_high_frac"] = (seq_has_over_high.sum() / num_sequences).detach().item() + metrics["outlier_seq_under_low_frac"] = (seq_has_under_low.sum() / num_sequences).detach().item() + + return all_tokens_valid.float(), metrics def compute_rejection_mask( @@ -636,7 +664,7 @@ def compute_rejection_mask( loss_mask: torch.Tensor, rejection_mask_type: str, rollout_corr: DictConfig, -) -> torch.Tensor: +) -> Tuple[torch.Tensor, dict]: """ Compute rejection mask for rollout correction. @@ -651,10 +679,13 @@ def compute_rejection_mask( rollout_corr: Rollout correction config containing cap values. Returns: - Rejection mask tensor (float) to multiply with the loss. 
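A standalone sketch of the per-token outlier check: a sequence survives only if every unmasked token ratio stays inside the configured thresholds (the defaults below mirror outlier_token_is_threshold_low/high):

import torch

def outlier_token_mask_sketch(old_logp, rollout_logp, loss_mask, low=1e-4, high=100.0):
    ratio = (old_logp - rollout_logp).exp()          # per-token pi_old / pi_rollout
    in_bounds = (ratio >= low) & (ratio <= high)
    # Tokens outside the loss mask never trigger a rejection
    valid = in_bounds | (loss_mask == 0)
    return valid.all(dim=-1, keepdim=True).float()   # shape [batch, 1]

old = torch.tensor([[-1.0, -1.0, -1.0]])
roll = torch.tensor([[-1.0, -1.0, -6.0]])            # third token ratio exp(5) ~= 148 > 100
print(outlier_token_mask_sketch(old, roll, torch.ones(1, 3)))                 # tensor([[0.]]) -> rejected
print(outlier_token_mask_sketch(old, roll, torch.tensor([[1.0, 1.0, 0.0]])))  # tensor([[1.]]) -> masked token ignored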
+ Tuple of (rejection_mask, metrics): + - rejection_mask: Tensor (float) to multiply with the loss + - metrics: Dict with masking statistics """ # Compute token-level importance ratio token_tis_log_ratio = old_log_probs - rollout_logprobs + metrics = {} if rejection_mask_type == "geometric": # Compute geometric mean of importance ratios per sequence @@ -663,16 +694,32 @@ def compute_rejection_mask( geo_mean_ratio = _safe_exp_delta(seq_tis_log_ratio / num_tokens, clip=20.0, out_dtype=old_log_probs.dtype) geo_cap_high = rollout_corr.geo_rejection_mask_ratio_cap_high geo_cap_low = rollout_corr.geo_rejection_mask_ratio_cap_low - geo_rejection_mask = (geo_mean_ratio >= geo_cap_low) & (geo_mean_ratio <= geo_cap_high) - return geo_rejection_mask.float() + seq_over_high = geo_mean_ratio > geo_cap_high + seq_under_low = geo_mean_ratio < geo_cap_low + geo_rejection_mask = ~seq_over_high & ~seq_under_low + + num_sequences = float(geo_mean_ratio.shape[0]) + metrics["rejection_seq_masked_frac"] = ((~geo_rejection_mask).sum() / num_sequences).detach().item() + metrics["rejection_seq_over_high_frac"] = (seq_over_high.sum() / num_sequences).detach().item() + metrics["rejection_seq_under_low_frac"] = (seq_under_low.sum() / num_sequences).detach().item() + + return geo_rejection_mask.float(), metrics elif rejection_mask_type == "sequence": # Mask out sequences with product of importance ratios outside the cap seq_tis_log_ratio = (token_tis_log_ratio * loss_mask).sum(dim=-1, keepdim=True) seq_tis_ratio = _safe_exp_delta(seq_tis_log_ratio, clip=20.0, out_dtype=old_log_probs.dtype) seq_cap_high = rollout_corr.sequence_rejection_mask_ratio_cap_high seq_cap_low = rollout_corr.sequence_rejection_mask_ratio_cap_low - seq_in_bounds = (seq_tis_ratio >= seq_cap_low) & (seq_tis_ratio <= seq_cap_high) - return seq_in_bounds.float() + seq_over_high = seq_tis_ratio > seq_cap_high + seq_under_low = seq_tis_ratio < seq_cap_low + seq_in_bounds = ~seq_over_high & ~seq_under_low + + num_sequences = float(seq_tis_ratio.shape[0]) + metrics["rejection_seq_masked_frac"] = ((~seq_in_bounds).sum() / num_sequences).detach().item() + metrics["rejection_seq_over_high_frac"] = (seq_over_high.sum() / num_sequences).detach().item() + metrics["rejection_seq_under_low_frac"] = (seq_under_low.sum() / num_sequences).detach().item() + + return seq_in_bounds.float(), metrics else: raise ValueError(f"Unknown rejection_mask_type: {rejection_mask_type}") @@ -714,26 +761,36 @@ def apply_rollout_correction( # Early return if no correction needed if not apply_tis and not apply_rejection: - return loss + return loss, {} + + is_ratio = _safe_exp_delta(old_log_probs - rollout_logprobs, clip=20.0, out_dtype=old_log_probs.dtype) + metrics = {} + metrics["is_ratio_mean"] = masked_mean(is_ratio, loss_mask).mean().detach().item() + metrics["is_ratio_std"] = (is_ratio * loss_mask).std().detach().item() + metrics["is_ratio_max"] = (is_ratio * loss_mask).max().detach().item() + metrics["is_ratio_min"] = (is_ratio * loss_mask).min().detach().item() # Apply outlier token mask whenever rollout correction is enabled # This rejects sequences with any token having importance ratio outside acceptable bounds - outlier_mask = compute_outlier_token_mask(old_log_probs, rollout_logprobs, loss_mask, rollout_corr) + outlier_mask, outlier_metrics = compute_outlier_token_mask(old_log_probs, rollout_logprobs, loss_mask, rollout_corr) loss = loss * outlier_mask + metrics.update(outlier_metrics) # Apply TIS ratio if enabled if apply_tis: - tis_ratio = 
compute_tis_ratio(old_log_probs, rollout_logprobs, loss_mask, tis_ratio_type, rollout_corr) + tis_ratio, tis_metrics = compute_tis_ratio(old_log_probs, rollout_logprobs, loss_mask, tis_ratio_type, rollout_corr) loss = loss * tis_ratio + metrics.update(tis_metrics) # Apply rejection mask if enabled if apply_rejection: - rejection_mask = compute_rejection_mask( + rejection_mask, rejection_metrics = compute_rejection_mask( old_log_probs, rollout_logprobs, loss_mask, rejection_mask_type, rollout_corr ) loss = loss * rejection_mask + metrics.update(rejection_metrics) - return loss + return loss, metrics @register_policy_loss(PolicyLossType.REGULAR) @@ -775,12 +832,16 @@ def ppo_policy_loss( tis_imp_ratio = torch.clamp(tis_imp_ratio, max=config.tis_imp_ratio_cap) loss = loss * tis_imp_ratio + loss_metrics = LossMetrics(clip_ratio=clip_ratio) + # apply rollout correction rollout_corr = config.rollout_correction - loss = apply_rollout_correction(loss, old_log_probs, rollout_logprobs, loss_mask, rollout_corr) + if rollout_corr is not None and rollout_logprobs is not None: + loss, rollout_correction_metrics = apply_rollout_correction(loss, old_log_probs, rollout_logprobs, loss_mask, rollout_corr) + loss_metrics.update(rollout_correction_metrics) loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len) - return loss, clip_ratio + return loss, loss_metrics @register_policy_loss(PolicyLossType.SAPO) @@ -844,15 +905,14 @@ def gate_function(x, tau): # apply rollout correction rollout_corr = config.rollout_correction - loss = apply_rollout_correction(loss, old_log_probs, rollout_logprobs, loss_mask, rollout_corr) - + loss_metrics = LossMetrics(clip_ratio=0.0) + if rollout_corr is not None and rollout_logprobs is not None: + loss, rollout_correction_metrics = apply_rollout_correction(loss, old_log_probs, rollout_logprobs, loss_mask, rollout_corr) + loss_metrics.update(rollout_correction_metrics) # for SAPO, we need to aggregate the loss at the sequence level (seq-mean-token-mean) loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len) - # SAPO does not use clipping, so we set clip_ratio to 0.0 for compatibility - clip_ratio = 0.0 - - return loss, clip_ratio + return loss, loss_metrics @register_policy_loss(PolicyLossType.GSPO) @@ -917,7 +977,7 @@ def gspo_policy_loss( loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len) - return loss, clip_ratio + return loss, LossMetrics(clip_ratio=clip_ratio) @register_policy_loss(PolicyLossType.CISPO) @@ -947,10 +1007,11 @@ def compute_policy_loss_cispo( # apply rollout correction rollout_corr = config.rollout_correction - loss = apply_rollout_correction(loss, old_log_probs, rollout_logprobs, loss_mask, rollout_corr) + if rollout_corr is not None and rollout_logprobs is not None: + loss = apply_rollout_correction(loss, old_log_probs, rollout_logprobs, loss_mask, rollout_corr) loss = reduce_loss(loss, loss_mask, config.loss_reduction, config.max_seq_len) - return loss, clip_ratio + return loss, LossMetrics(clip_ratio=clip_ratio) @register_policy_loss(PolicyLossType.CLIP_COV) @@ -1013,7 +1074,8 @@ def compute_policy_loss_clip_cov( # apply rollout correction rollout_corr = config.rollout_correction - pg_losses = apply_rollout_correction(pg_losses, old_log_probs, rollout_logprobs, loss_mask, rollout_corr) + if rollout_corr is not None and rollout_logprobs is not None: + pg_losses = apply_rollout_correction(pg_losses, old_log_probs, rollout_logprobs, loss_mask, rollout_corr) pg_loss = reduce_loss( loss=pg_losses, @@ -1022,7 
+1084,7 @@ def compute_policy_loss_clip_cov( max_seq_len=config.max_seq_len, ) - return pg_loss, clip_frac.item() + return pg_loss, LossMetrics(clip_frac=clip_frac.item()) @register_policy_loss(PolicyLossType.KL_COV) @@ -1076,7 +1138,8 @@ def compute_policy_loss_kl_cov( # apply rollout correction rollout_corr = config.rollout_correction - pg_losses = apply_rollout_correction(pg_losses, old_log_probs, rollout_logprobs, loss_mask, rollout_corr) + if rollout_corr is not None and rollout_logprobs is not None: + pg_losses = apply_rollout_correction(pg_losses, old_log_probs, rollout_logprobs, loss_mask, rollout_corr) pg_loss = reduce_loss( loss=pg_losses, diff --git a/skyrl-train/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl-train/skyrl_train/workers/megatron/megatron_model_wrapper.py index d84673e66..26ebf839a 100644 --- a/skyrl-train/skyrl_train/workers/megatron/megatron_model_wrapper.py +++ b/skyrl-train/skyrl_train/workers/megatron/megatron_model_wrapper.py @@ -219,7 +219,7 @@ def loss_func(logits, data): action_log_probs = token_logprobs[:, -num_actions:] # policy loss should be calculated based on the selected token logprobs - policy_loss, clip_ratio = self.policy_loss_fn( + policy_loss, loss_metrics = self.policy_loss_fn( action_log_probs, old_action_log_probs, advantages, @@ -256,7 +256,7 @@ def loss_func(logits, data): "final_loss": loss.detach().item(), "policy_loss": policy_loss.detach().item(), "policy_entropy": entropy.detach().item(), - "ppo_clip_ratio": clip_ratio, + "ppo_clip_ratio": loss_metrics["clip_ratio"], "policy_kl": kl_loss.detach().item(), } return loss, metrics diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py index 361a0777e..4335ae779 100644 --- a/skyrl-train/skyrl_train/workers/worker.py +++ b/skyrl-train/skyrl_train/workers/worker.py @@ -663,7 +663,7 @@ def forward_backward(self, experience: Experience, microbatch_weight: float) -> ) # loss function # TODO: recompute advantages - policy_loss, clip_ratio = self.policy_loss_fn( + policy_loss, loss_metrics = self.policy_loss_fn( action_log_probs, old_action_log_probs, advantages, @@ -704,7 +704,7 @@ def forward_backward(self, experience: Experience, microbatch_weight: float) -> status = { "final_loss": loss.item(), "policy_loss": policy_loss.item(), - "ppo_clip_ratio": clip_ratio, + "ppo_clip_ratio": loss_metrics["clip_ratio"], "policy_entropy": entropy.item(), "response_length": num_actions, } diff --git a/skyrl-train/tests/cpu/algorithms/test_losses.py b/skyrl-train/tests/cpu/algorithms/test_losses.py index f026bc5aa..b400a858f 100644 --- a/skyrl-train/tests/cpu/algorithms/test_losses.py +++ b/skyrl-train/tests/cpu/algorithms/test_losses.py @@ -24,6 +24,7 @@ "outlier_token_is_threshold_high": 100.0, } + # Adapted a good test from NeMO-RL def test_policy_loss_dual_clip(): """Tests dual clipping in PolicyLoss function.""" @@ -498,7 +499,8 @@ def test_clip_cov_policy_loss(): clip_cov_fn = PolicyLossRegistry.get("clip_cov") # Calculate loss - loss, clip_frac = clip_cov_fn(log_probs, old_log_probs, advantages, config, loss_mask) + loss, loss_metrics = clip_cov_fn(log_probs, old_log_probs, advantages, config, loss_mask) + clip_frac = loss_metrics["clip_frac"] # Basic sanity checks assert torch.isfinite(loss), "Loss should be finite" @@ -518,7 +520,7 @@ def test_clip_cov_policy_loss(): ) regular_fn = PolicyLossRegistry.get("regular") - regular_loss, regular_clip_frac = regular_fn(log_probs, old_log_probs, advantages, regular_config, loss_mask) + regular_loss, 
regular_loss_metrics = regular_fn(log_probs, old_log_probs, advantages, regular_config, loss_mask) # Clip-Cov should give different results due to covariance-based correction assert not torch.allclose( @@ -617,7 +619,7 @@ def test_sapo_policy_loss_basic(): loss_fn = PolicyLossRegistry.get("sapo") # Actual SAPO loss - actual_loss, actual_clip_ratio = loss_fn( + actual_loss, loss_metrics = loss_fn( log_probs=log_probs, old_log_probs=old_log_probs, advantages=advantages, @@ -646,7 +648,7 @@ def gate_function(x, tau): torch.testing.assert_close(actual_loss, expected_loss, rtol=1e-5, atol=1e-8) # SAPO should always report clip_ratio = 0.0 - assert actual_clip_ratio == 0.0 + assert loss_metrics["clip_ratio"] == 0.0 # ============================================================================ @@ -671,11 +673,14 @@ def test_compute_tis_ratio_token_level(): } ) - tis_ratio = compute_tis_ratio(old_log_probs, rollout_logprobs, loss_mask, "token", config) + tis_ratio, metrics = compute_tis_ratio(old_log_probs, rollout_logprobs, loss_mask, "token", config) # Expected: [1.6487, 0.6065, 2.0] (third token capped at 2.0) expected = torch.tensor([[1.6487, 0.6065, 2.0]], device=device) torch.testing.assert_close(tis_ratio, expected, rtol=1e-3, atol=1e-4) + # One token out of 3 was capped + assert "tis_token_capped_frac" in metrics + assert abs(metrics["tis_token_capped_frac"] - 1/3) < 0.01 def test_compute_tis_ratio_sequence_level(): @@ -696,11 +701,14 @@ def test_compute_tis_ratio_sequence_level(): } ) - tis_ratio = compute_tis_ratio(old_log_probs, rollout_logprobs, loss_mask, "sequence", config) + tis_ratio, metrics = compute_tis_ratio(old_log_probs, rollout_logprobs, loss_mask, "sequence", config) # Expected: exp(1.0) = 2.7183, shape [batch, 1] for sequence-level expected = torch.tensor([[2.7183]], device=device) torch.testing.assert_close(tis_ratio, expected, rtol=1e-3, atol=1e-4) + # No sequence was capped (2.7183 < 5.0) + assert "tis_seq_capped_frac" in metrics + assert metrics["tis_seq_capped_frac"] == 0.0 def test_compute_tis_ratio_sequence_level_with_cap(): @@ -721,11 +729,14 @@ def test_compute_tis_ratio_sequence_level_with_cap(): } ) - tis_ratio = compute_tis_ratio(old_log_probs, rollout_logprobs, loss_mask, "sequence", config) + tis_ratio, metrics = compute_tis_ratio(old_log_probs, rollout_logprobs, loss_mask, "sequence", config) # Expected: capped at 5.0, shape [batch, 1] for sequence-level expected = torch.tensor([[5.0]], device=device) torch.testing.assert_close(tis_ratio, expected, rtol=1e-3, atol=1e-4) + # One sequence out of 1 was capped + assert "tis_seq_capped_frac" in metrics + assert metrics["tis_seq_capped_frac"] == 1.0 def test_compute_tis_ratio_with_mask(): @@ -745,12 +756,15 @@ def test_compute_tis_ratio_with_mask(): } ) - tis_ratio = compute_tis_ratio(old_log_probs, rollout_logprobs, loss_mask, "sequence", config) + tis_ratio, metrics = compute_tis_ratio(old_log_probs, rollout_logprobs, loss_mask, "sequence", config) # Expected: exp(1.5) = 4.4817, shape [batch, 1] for sequence-level expected_val = torch.exp(torch.tensor(1.5)) expected = expected_val.reshape(1, 1) torch.testing.assert_close(tis_ratio, expected, rtol=1e-3, atol=1e-4) + # No sequence was capped (4.4817 < 10.0) + assert "tis_seq_capped_frac" in metrics + assert metrics["tis_seq_capped_frac"] == 0.0 def test_compute_rejection_mask_geometric(): @@ -770,12 +784,16 @@ def test_compute_rejection_mask_geometric(): } ) - rejection_mask = compute_rejection_mask(old_log_probs, rollout_logprobs, loss_mask, "geometric", config) 
+ rejection_mask, metrics = compute_rejection_mask(old_log_probs, rollout_logprobs, loss_mask, "geometric", config) # Geometric mean ≈ 1.0, which is within [0.9, 1.1], so mask should be 1.0 # Shape is [batch, 1] for sequence-level mask expected = torch.tensor([[1.0]], device=device) torch.testing.assert_close(rejection_mask, expected, rtol=1e-3, atol=1e-4) + # No sequence was masked + assert metrics["rejection_seq_masked_frac"] == 0.0 + assert metrics["rejection_seq_over_high_frac"] == 0.0 + assert metrics["rejection_seq_under_low_frac"] == 0.0 def test_compute_rejection_mask_geometric_rejects(): @@ -795,12 +813,16 @@ def test_compute_rejection_mask_geometric_rejects(): } ) - rejection_mask = compute_rejection_mask(old_log_probs, rollout_logprobs, loss_mask, "geometric", config) + rejection_mask, metrics = compute_rejection_mask(old_log_probs, rollout_logprobs, loss_mask, "geometric", config) # Geometric mean ≈ 1.6487, which is outside [0.9, 1.1], so mask should be 0.0 # Shape is [batch, 1] for sequence-level mask expected = torch.tensor([[0.0]], device=device) torch.testing.assert_close(rejection_mask, expected, rtol=1e-3, atol=1e-4) + # One sequence masked, over high cap + assert metrics["rejection_seq_masked_frac"] == 1.0 + assert metrics["rejection_seq_over_high_frac"] == 1.0 + assert metrics["rejection_seq_under_low_frac"] == 0.0 def test_compute_rejection_mask_sequence(): @@ -820,12 +842,16 @@ def test_compute_rejection_mask_sequence(): } ) - rejection_mask = compute_rejection_mask(old_log_probs, rollout_logprobs, loss_mask, "sequence", config) + rejection_mask, metrics = compute_rejection_mask(old_log_probs, rollout_logprobs, loss_mask, "sequence", config) # Sequence ratio ≈ 1.35, which is within [0.5, 2.0] # Shape is [batch, 1] for sequence-level mask expected = torch.tensor([[1.0]], device=device) torch.testing.assert_close(rejection_mask, expected, rtol=1e-3, atol=1e-4) + # No sequence was masked + assert metrics["rejection_seq_masked_frac"] == 0.0 + assert metrics["rejection_seq_over_high_frac"] == 0.0 + assert metrics["rejection_seq_under_low_frac"] == 0.0 def test_compute_rejection_mask_sequence_rejects_by_seq_ratio(): @@ -845,12 +871,16 @@ def test_compute_rejection_mask_sequence_rejects_by_seq_ratio(): } ) - rejection_mask = compute_rejection_mask(old_log_probs, rollout_logprobs, loss_mask, "sequence", config) + rejection_mask, metrics = compute_rejection_mask(old_log_probs, rollout_logprobs, loss_mask, "sequence", config) # Sequence ratio ≈ 20.09, which is outside [0.5, 2.0], so mask should be 0.0 # Shape is [batch, 1] for sequence-level mask expected = torch.tensor([[0.0]], device=device) torch.testing.assert_close(rejection_mask, expected, rtol=1e-3, atol=1e-4) + # One sequence masked, over high cap + assert metrics["rejection_seq_masked_frac"] == 1.0 + assert metrics["rejection_seq_over_high_frac"] == 1.0 + assert metrics["rejection_seq_under_low_frac"] == 0.0 def test_compute_outlier_token_mask_rejects_by_token_bounds(): @@ -870,12 +900,16 @@ def test_compute_outlier_token_mask_rejects_by_token_bounds(): } ) - outlier_mask = compute_outlier_token_mask(old_log_probs, rollout_logprobs, loss_mask, config) + outlier_mask, metrics = compute_outlier_token_mask(old_log_probs, rollout_logprobs, loss_mask, config) # Token ratio 148.4 > 100.0, so mask should be 0.0 # Shape is [batch, 1] for sequence-level mask expected = torch.tensor([[0.0]], device=device) torch.testing.assert_close(outlier_mask, expected, rtol=1e-3, atol=1e-4) + # One sequence masked, has token over high 
threshold + assert metrics["outlier_seq_masked_frac"] == 1.0 + assert metrics["outlier_seq_over_high_frac"] == 1.0 + assert metrics["outlier_seq_under_low_frac"] == 0.0 def test_compute_outlier_token_mask_accepts_in_bounds(): @@ -895,12 +929,16 @@ def test_compute_outlier_token_mask_accepts_in_bounds(): } ) - outlier_mask = compute_outlier_token_mask(old_log_probs, rollout_logprobs, loss_mask, config) + outlier_mask, metrics = compute_outlier_token_mask(old_log_probs, rollout_logprobs, loss_mask, config) # All token ratios in bounds, so mask should be 1.0 # Shape is [batch, 1] for sequence-level mask expected = torch.tensor([[1.0]], device=device) torch.testing.assert_close(outlier_mask, expected, rtol=1e-3, atol=1e-4) + # No sequence was masked + assert metrics["outlier_seq_masked_frac"] == 0.0 + assert metrics["outlier_seq_over_high_frac"] == 0.0 + assert metrics["outlier_seq_under_low_frac"] == 0.0 def test_compute_outlier_token_mask_respects_loss_mask(): @@ -920,11 +958,13 @@ def test_compute_outlier_token_mask_respects_loss_mask(): } ) - outlier_mask = compute_outlier_token_mask(old_log_probs, rollout_logprobs, loss_mask, config) + outlier_mask, metrics = compute_outlier_token_mask(old_log_probs, rollout_logprobs, loss_mask, config) # Third token is masked, so even though ratio is out of bounds, sequence should be accepted expected = torch.tensor([[1.0]], device=device) torch.testing.assert_close(outlier_mask, expected, rtol=1e-3, atol=1e-4) + # No sequence was masked (the out-of-bounds token was in a masked position) + assert metrics["outlier_seq_masked_frac"] == 0.0 def test_apply_rollout_correction_null_configs(): @@ -943,10 +983,11 @@ def test_apply_rollout_correction_null_configs(): } ) - corrected_loss = apply_rollout_correction(loss, old_log_probs, rollout_logprobs, loss_mask, config) + corrected_loss, metrics = apply_rollout_correction(loss, old_log_probs, rollout_logprobs, loss_mask, config) - # Should return the same tensor (early return) + # Should return the same tensor (early return) and empty metrics assert corrected_loss is loss + assert metrics == {} def test_apply_rollout_correction_tis_only(): @@ -969,11 +1010,14 @@ def test_apply_rollout_correction_tis_only(): } ) - corrected_loss = apply_rollout_correction(loss, old_log_probs, rollout_logprobs, loss_mask, config) + corrected_loss, metrics = apply_rollout_correction(loss, old_log_probs, rollout_logprobs, loss_mask, config) # Expected: loss * 1.6487 (no capping needed) expected = loss * torch.exp(torch.tensor(0.5)) torch.testing.assert_close(corrected_loss, expected, rtol=1e-3, atol=1e-4) + # Check metrics are populated + assert "is_ratio_mean" in metrics + assert "tis_token_capped_frac" in metrics def test_apply_rollout_correction_rejection_only(): @@ -997,10 +1041,13 @@ def test_apply_rollout_correction_rejection_only(): } ) - corrected_loss = apply_rollout_correction(loss, old_log_probs, rollout_logprobs, loss_mask, config) + corrected_loss, metrics = apply_rollout_correction(loss, old_log_probs, rollout_logprobs, loss_mask, config) # Geometric mean = 1.0, within bounds, so loss unchanged torch.testing.assert_close(corrected_loss, loss, rtol=1e-3, atol=1e-4) + # Check metrics are populated + assert "is_ratio_mean" in metrics + assert "rejection_seq_masked_frac" in metrics def test_apply_rollout_correction_both_enabled(): @@ -1026,12 +1073,15 @@ def test_apply_rollout_correction_both_enabled(): } ) - corrected_loss = apply_rollout_correction(loss, old_log_probs, rollout_logprobs, loss_mask, config) + 
corrected_loss, metrics = apply_rollout_correction(loss, old_log_probs, rollout_logprobs, loss_mask, config) # TIS ratio ≈ 1.105, geometric mean ≈ 1.105 (within bounds, mask=1) # Expected: loss * 1.105 * 1.0 = loss * 1.105 expected = loss * torch.exp(torch.tensor(0.1)) torch.testing.assert_close(corrected_loss, expected, rtol=1e-3, atol=1e-4) + # Check metrics from both TIS and rejection are populated + assert "tis_token_capped_frac" in metrics + assert "rejection_seq_masked_frac" in metrics def test_apply_rollout_correction_rejection_zeros_loss(): @@ -1055,11 +1105,13 @@ def test_apply_rollout_correction_rejection_zeros_loss(): } ) - corrected_loss = apply_rollout_correction(loss, old_log_probs, rollout_logprobs, loss_mask, config) + corrected_loss, metrics = apply_rollout_correction(loss, old_log_probs, rollout_logprobs, loss_mask, config) # Geometric mean ≈ 2.718, outside [0.9, 1.1], so loss should be zeroed expected = torch.tensor([[0.0, 0.0, 0.0]], device=device) torch.testing.assert_close(corrected_loss, expected, rtol=1e-3, atol=1e-4) + # Check that the rejection metrics show rejection happened + assert metrics["rejection_seq_masked_frac"] == 1.0 def test_ppo_policy_loss_with_rollout_correction(): From 2dc73641a1de9fa80d61f271bab655815251bc5c Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Thu, 8 Jan 2026 07:28:38 +0000 Subject: [PATCH 08/23] propagate metrics up and refactor how we do metric reductions for max and min --- skyrl-train/examples/gsm8k/run_gsm8k.sh | 13 ++++++- .../run_megatron_dapo_qwen3_30b_a3b_lora.sh | 24 ++++++++----- .../skyrl_train/distributed/strategy.py | 17 +++++++-- skyrl-train/skyrl_train/utils/ppo_utils.py | 36 +++++++++++++------ skyrl-train/skyrl_train/utils/utils.py | 14 ++++---- skyrl-train/skyrl_train/workers/worker.py | 15 ++++++-- .../skyrl_train/workers/worker_utils.py | 7 +++- .../tests/cpu/algorithms/test_losses.py | 24 +++++++++---- .../tests/gpu/gpu_ci/test_ppo_train.py | 11 +++++- skyrl-train/tests/gpu/utils.py | 1 + 10 files changed, 122 insertions(+), 40 deletions(-) diff --git a/skyrl-train/examples/gsm8k/run_gsm8k.sh b/skyrl-train/examples/gsm8k/run_gsm8k.sh index 6b558da62..95c54710e 100755 --- a/skyrl-train/examples/gsm8k/run_gsm8k.sh +++ b/skyrl-train/examples/gsm8k/run_gsm8k.sh @@ -10,17 +10,28 @@ set -x # You can override the default values with e.g.: `NUM_GPUS=1 bash examples/gsm8k/run_gsm8k.sh`. 
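As a quick sanity check of the composition exercised in the apply_rollout_correction tests above (standalone arithmetic only, not the library call):

import torch

# Per-token log ratio 0.1, as in test_apply_rollout_correction_both_enabled:
# token TIS ratio exp(0.1) ~= 1.105 stays under the 2.0 cap, and the geometric
# mean exp(0.1) ~= 1.105 lies inside [0.8, 1.2], so nothing is rejected.
loss = torch.ones(1, 3)
log_ratio = torch.full((1, 3), 0.1)
tis_ratio = log_ratio.exp().clamp(max=2.0)
geo_mean = log_ratio.mean(dim=-1, keepdim=True).exp()
keep = ((geo_mean >= 0.8) & (geo_mean <= 1.2)).float()
corrected = loss * tis_ratio * keep                  # ~[[1.105, 1.105, 1.105]]
torch.testing.assert_close(corrected, loss * torch.tensor(0.1).exp())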
-: "${DATA_DIR:="$HOME/data/gsm8k"}" +: "${DATA_DIR:="/mnt/cluster_storage/data/gsm8k"}" : "${NUM_GPUS:=4}" : "${LOGGER:=wandb}" # change to "console" to print to stdout : "${INFERENCE_BACKEND:=vllm}" # : "${INFERENCE_BACKEND:=sglang}" + +# rollout correction parameters +TIS_RATIO_TYPE="sequence" +REJECTION_MASK_TYPE="geometric" +GEO_REJECTION_MASK_RATIO_CAP_HIGH=1.01 +GEO_REJECTION_MASK_RATIO_CAP_LOW=0.99 + uv run --isolated --extra $INFERENCE_BACKEND -m skyrl_train.entrypoints.main_base \ data.train_data="['$DATA_DIR/train.parquet']" \ data.val_data="['$DATA_DIR/validation.parquet']" \ trainer.algorithm.advantage_estimator="grpo" \ + trainer.algorithm.rollout_correction.tis_ratio_type=$TIS_RATIO_TYPE \ + trainer.algorithm.rollout_correction.rejection_mask_type=$REJECTION_MASK_TYPE \ + trainer.algorithm.rollout_correction.geo_rejection_mask_ratio_cap_high=$GEO_REJECTION_MASK_RATIO_CAP_HIGH \ + trainer.algorithm.rollout_correction.geo_rejection_mask_ratio_cap_low=$GEO_REJECTION_MASK_RATIO_CAP_LOW \ trainer.policy.model.path="Qwen/Qwen2.5-1.5B-Instruct" \ trainer.placement.colocate_all=true \ trainer.strategy=fsdp2 \ diff --git a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_30b_a3b_lora.sh b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_30b_a3b_lora.sh index 904343576..61514f9b4 100644 --- a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_30b_a3b_lora.sh +++ b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_30b_a3b_lora.sh @@ -7,7 +7,7 @@ set -x # bash examples/megatron/run_megatron_dapo_qwen3_30b_a3b_lora.sh MODEL_NAME="Qwen/Qwen3-30B-A3B-Base" -DATA_DIR="$HOME/data/dapo" +DATA_DIR="/mnt/cluster_storage/data/dapo" TRAIN_FILE="$DATA_DIR/dapo-math-17k-cleaned.parquet" TEST_FILE="$DATA_DIR/aime-2024-cleaned.parquet" NUM_NODES=2 @@ -55,8 +55,13 @@ LORA_RANK=32 LORA_ALPHA=64 # TIS parameters -TIS_IMP_RATIO_CAP=2.0 -TIS_TYPE=token +TIS_IMP_RATIO_CAP=3.0 + +# rollout correction parameters +TIS_RATIO_TYPE="sequence" +REJECTION_MASK_TYPE="geometric" +GEO_REJECTION_MASK_RATIO_CAP_HIGH=1.05 +GEO_REJECTION_MASK_RATIO_CAP_LOW=0.95 uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \ data.train_data="['$TRAIN_FILE']" \ @@ -88,8 +93,11 @@ uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \ trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \ trainer.policy.model.lora.rank=$LORA_RANK \ trainer.policy.model.lora.alpha=$LORA_ALPHA \ - trainer.algorithm.rollout_correction.tis_ratio_type=$TIS_TYPE \ - trainer.algorithm.rollout_correction.token_tis_ratio_cap_high=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.rollout_correction.tis_ratio_type=$TIS_RATIO_TYPE \ + trainer.algorithm.rollout_correction.sequence_tis_ratio_cap_high=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.rollout_correction.rejection_mask_type=$REJECTION_MASK_TYPE \ + trainer.algorithm.rollout_correction.geo_rejection_mask_ratio_cap_high=$GEO_REJECTION_MASK_RATIO_CAP_HIGH \ + trainer.algorithm.rollout_correction.geo_rejection_mask_ratio_cap_low=$GEO_REJECTION_MASK_RATIO_CAP_LOW \ trainer.epochs=20 \ trainer.algorithm.eps_clip_low=$CLIP_RATIO_LOW \ trainer.algorithm.eps_clip_high=$CLIP_RATIO_HIGH \ @@ -119,10 +127,10 @@ uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \ generator.gpu_memory_utilization=0.7 \ trainer.logger="$LOGGER" \ trainer.project_name="dapo_aime" \ - trainer.run_name="dapo_qwen3_30b_a3b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_ep${MEGATRON_EP}_etp${MEGATRON_ETP}_lora_rank${LORA_RANK}_alpha${LORA_ALPHA}" 
\ - trainer.export_path="$HOME/exports/dapo_qwen3_30b_a3b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_ep${MEGATRON_EP}_etp${MEGATRON_ETP}_lora_rank${LORA_RANK}_alpha${LORA_ALPHA}" \ + trainer.run_name="dapo_qwen3_30b_a3b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_ep${MEGATRON_EP}_etp${MEGATRON_ETP}_lora_rank${LORA_RANK}_alpha${LORA_ALPHA}_seq_is_${TIS_RATIO_TYPE}_${TIS_IMP_RATIO_CAP}_rej_${REJECTION_MASK_TYPE}_${GEO_REJECTION_MASK_RATIO_CAP_HIGH}_${GEO_REJECTION_MASK_RATIO_CAP_LOW}" \ + trainer.export_path="$HOME/exports/dapo_qwen3_30b_a3b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_ep${MEGATRON_EP}_etp${MEGATRON_ETP}_lora_rank${LORA_RANK}_alpha${LORA_ALPHA}_seq_is_${TIS_RATIO_TYPE}_${TIS_IMP_RATIO_CAP}_rej_${REJECTION_MASK_TYPE}_${GEO_REJECTION_MASK_RATIO_CAP_HIGH}_${GEO_REJECTION_MASK_RATIO_CAP_LOW}" \ trainer.hf_save_interval=300 \ trainer.resume_mode=latest \ trainer.max_ckpts_to_keep=3 \ - trainer.ckpt_path="$HOME/ckpts/dapo_qwen3_30b_a3b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_ep${MEGATRON_EP}_etp${MEGATRON_ETP}_lora_rank${LORA_RANK}_alpha${LORA_ALPHA}" \ + trainer.ckpt_path="$HOME/ckpts/dapo_qwen3_30b_a3b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_ep${MEGATRON_EP}_etp${MEGATRON_ETP}_lora_rank${LORA_RANK}_alpha${LORA_ALPHA}_seq_is_${TIS_RATIO_TYPE}_${TIS_IMP_RATIO_CAP}_rej_${REJECTION_MASK_TYPE}_${GEO_REJECTION_MASK_RATIO_CAP_HIGH}_${GEO_REJECTION_MASK_RATIO_CAP_LOW}" \ $@ \ No newline at end of file diff --git a/skyrl-train/skyrl_train/distributed/strategy.py b/skyrl-train/skyrl_train/distributed/strategy.py index acceccb45..785722c93 100644 --- a/skyrl-train/skyrl_train/distributed/strategy.py +++ b/skyrl-train/skyrl_train/distributed/strategy.py @@ -69,10 +69,15 @@ def get_rank(self) -> int: def all_reduce(self, data: DataT, op="mean") -> DataT: """Perform all_reduce across all processes""" - assert op in ("mean", "max", "sum") + assert op in ("mean", "max", "sum", "min") if isinstance(data, dict): ret = {} for k, v in data.items(): + options = ["min", "max", "mean"] + for op in options: + if op in k: + op = op + break ret[k] = self.all_reduce(v, op) return ret else: @@ -86,7 +91,15 @@ def all_reduce(self, data: DataT, op="mean") -> DataT: data = data.to(torch.cuda.current_device()) if op == "mean": data /= self.world_size - dist.all_reduce(data, op=dist.ReduceOp.MAX if op == "max" else dist.ReduceOp.SUM) + dist.all_reduce(data, op=dist.ReduceOp.SUM) + elif op == "max": + data = torch.max(data) + dist.all_reduce(data, op=dist.ReduceOp.MAX) + elif op == "min": + data = torch.min(data) + dist.all_reduce(data, op=dist.ReduceOp.MIN) + elif op == "sum": + dist.all_reduce(data, op=dist.ReduceOp.SUM) if is_cpu_tensor: data = data.cpu() return data.item() if not is_tensor else data diff --git a/skyrl-train/skyrl_train/utils/ppo_utils.py b/skyrl-train/skyrl_train/utils/ppo_utils.py index a23e7fded..9d663c866 100644 --- a/skyrl-train/skyrl_train/utils/ppo_utils.py +++ b/skyrl-train/skyrl_train/utils/ppo_utils.py @@ -761,7 +761,7 @@ def apply_rollout_correction( # Early return if no correction needed if not apply_tis and not apply_rejection: - return loss, {} + return loss, {}, loss_mask is_ratio = _safe_exp_delta(old_log_probs - rollout_logprobs, clip=20.0, out_dtype=old_log_probs.dtype) metrics = {} @@ -773,12 +773,14 @@ def apply_rollout_correction( # Apply outlier token mask whenever rollout correction is enabled # This rejects sequences with any token having importance ratio 
outside acceptable bounds outlier_mask, outlier_metrics = compute_outlier_token_mask(old_log_probs, rollout_logprobs, loss_mask, rollout_corr) - loss = loss * outlier_mask + loss_mask = loss_mask * outlier_mask metrics.update(outlier_metrics) # Apply TIS ratio if enabled if apply_tis: - tis_ratio, tis_metrics = compute_tis_ratio(old_log_probs, rollout_logprobs, loss_mask, tis_ratio_type, rollout_corr) + tis_ratio, tis_metrics = compute_tis_ratio( + old_log_probs, rollout_logprobs, loss_mask, tis_ratio_type, rollout_corr + ) loss = loss * tis_ratio metrics.update(tis_metrics) @@ -787,10 +789,10 @@ def apply_rollout_correction( rejection_mask, rejection_metrics = compute_rejection_mask( old_log_probs, rollout_logprobs, loss_mask, rejection_mask_type, rollout_corr ) - loss = loss * rejection_mask + loss_mask = loss_mask * rejection_mask metrics.update(rejection_metrics) - return loss, metrics + return loss, metrics, loss_mask @register_policy_loss(PolicyLossType.REGULAR) @@ -837,7 +839,9 @@ def ppo_policy_loss( # apply rollout correction rollout_corr = config.rollout_correction if rollout_corr is not None and rollout_logprobs is not None: - loss, rollout_correction_metrics = apply_rollout_correction(loss, old_log_probs, rollout_logprobs, loss_mask, rollout_corr) + loss, rollout_correction_metrics, loss_mask = apply_rollout_correction( + loss, old_log_probs, rollout_logprobs, loss_mask, rollout_corr + ) loss_metrics.update(rollout_correction_metrics) loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len) @@ -907,7 +911,9 @@ def gate_function(x, tau): rollout_corr = config.rollout_correction loss_metrics = LossMetrics(clip_ratio=0.0) if rollout_corr is not None and rollout_logprobs is not None: - loss, rollout_correction_metrics = apply_rollout_correction(loss, old_log_probs, rollout_logprobs, loss_mask, rollout_corr) + loss, rollout_correction_metrics, loss_mask = apply_rollout_correction( + loss, old_log_probs, rollout_logprobs, loss_mask, rollout_corr + ) loss_metrics.update(rollout_correction_metrics) # for SAPO, we need to aggregate the loss at the sequence level (seq-mean-token-mean) loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len) @@ -970,7 +976,9 @@ def gspo_policy_loss( # apply rollout correction rollout_corr = config.rollout_correction if rollout_corr is not None and rollout_logprobs is not None: - loss = apply_rollout_correction(loss, old_log_probs, rollout_logprobs, loss_mask, rollout_corr) + loss, loss_metrics, loss_mask = apply_rollout_correction( + loss, old_log_probs, rollout_logprobs, loss_mask, rollout_corr + ) # Compute clipping ratio for monitoring clip_ratio = masked_mean((-surr2 > -surr1).float(), loss_mask).mean().detach().item() @@ -1008,7 +1016,9 @@ def compute_policy_loss_cispo( # apply rollout correction rollout_corr = config.rollout_correction if rollout_corr is not None and rollout_logprobs is not None: - loss = apply_rollout_correction(loss, old_log_probs, rollout_logprobs, loss_mask, rollout_corr) + loss, loss_metrics, loss_mask = apply_rollout_correction( + loss, old_log_probs, rollout_logprobs, loss_mask, rollout_corr + ) loss = reduce_loss(loss, loss_mask, config.loss_reduction, config.max_seq_len) return loss, LossMetrics(clip_ratio=clip_ratio) @@ -1075,7 +1085,9 @@ def compute_policy_loss_clip_cov( # apply rollout correction rollout_corr = config.rollout_correction if rollout_corr is not None and rollout_logprobs is not None: - pg_losses = apply_rollout_correction(pg_losses, old_log_probs, rollout_logprobs, loss_mask, 
rollout_corr) + pg_losses, loss_metrics, loss_mask = apply_rollout_correction( + pg_losses, old_log_probs, rollout_logprobs, loss_mask, rollout_corr + ) pg_loss = reduce_loss( loss=pg_losses, @@ -1139,7 +1151,9 @@ def compute_policy_loss_kl_cov( # apply rollout correction rollout_corr = config.rollout_correction if rollout_corr is not None and rollout_logprobs is not None: - pg_losses = apply_rollout_correction(pg_losses, old_log_probs, rollout_logprobs, loss_mask, rollout_corr) + pg_losses, loss_metrics, loss_mask = apply_rollout_correction( + pg_losses, old_log_probs, rollout_logprobs, loss_mask, rollout_corr + ) pg_loss = reduce_loss( loss=pg_losses, diff --git a/skyrl-train/skyrl_train/utils/utils.py b/skyrl-train/skyrl_train/utils/utils.py index 7051d8671..47677ff1c 100644 --- a/skyrl-train/skyrl_train/utils/utils.py +++ b/skyrl-train/skyrl_train/utils/utils.py @@ -305,25 +305,25 @@ def validate_cfg(cfg: DictConfig): # New rollout_correction validation rollout_corr = cfg.trainer.algorithm.get("rollout_correction", None) if rollout_corr is not None: - tis_ratio_type = rollout_corr.get("tis_ratio_type", "null") - rejection_mask_type = rollout_corr.get("rejection_mask_type", "null") + tis_ratio_type = rollout_corr.tis_ratio_type + rejection_mask_type = rollout_corr.rejection_mask_type - uses_rollout_correction = tis_ratio_type != "null" or rejection_mask_type != "null" + uses_rollout_correction = tis_ratio_type is not None or rejection_mask_type is not None if uses_rollout_correction: # Validate tis_ratio_type assert tis_ratio_type in [ - "null", + None, "token", "sequence", - ], f"`tis_ratio_type` must be 'null', 'token', or 'sequence', got {tis_ratio_type}" + ], f"`tis_ratio_type` must be 'None', 'token', or 'sequence', got {tis_ratio_type}" # Validate rejection_mask_type assert rejection_mask_type in [ - "null", + None, "sequence", "geometric", - ], f"`rejection_mask_type` must be 'null', 'sequence', or 'geometric', got {rejection_mask_type}" + ], f"`rejection_mask_type` must be 'sequence', or 'geometric', got {rejection_mask_type}" # Ensure logprobs are enabled for rollout correction if cfg.generator.sampling_params.logprobs is None: diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py index 4335ae779..898fa8cc7 100644 --- a/skyrl-train/skyrl_train/workers/worker.py +++ b/skyrl-train/skyrl_train/workers/worker.py @@ -708,6 +708,9 @@ def forward_backward(self, experience: Experience, microbatch_weight: float) -> "policy_entropy": entropy.item(), "response_length": num_actions, } + for k, v in loss_metrics.items(): + if k != "clip_ratio": # we separately name the clip ratio metric + status["loss_metrics/" + k] = v if self.cfg.trainer.algorithm.use_kl_loss: status["policy_kl"] = kl_loss.item() @@ -741,7 +744,16 @@ def record_status(status: Dict[str, float]): # for DP # TODO (sumanthrh): this assumes all workers are data parallel. # We assume that outputs are replicated within tp or sp group, otherwise this is not correct. 
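The hunks below reduce metrics according to their name suffix (_min/_max keep extremes, everything else is averaged); a minimal plain-Python sketch of that convention, mirroring the reduce_metrics change in worker_utils.py:

def reduce_metrics_sketch(metrics):
    # metrics: dict mapping metric name -> list of per-microbatch values
    reduced = {}
    for name, values in metrics.items():
        if name.endswith("_max"):
            reduced[name] = max(values)
        elif name.endswith("_min"):
            reduced[name] = min(values)
        else:
            reduced[name] = sum(values) / len(values)
    return reduced

print(reduce_metrics_sketch({
    "loss_metrics/is_ratio_max": [1.8, 2.3],   # -> 2.3
    "loss_metrics/is_ratio_min": [0.4, 0.6],   # -> 0.4
    "policy_loss": [0.10, 0.20],               # -> 0.15
}))

The same suffix convention is used to pick the all_reduce op (min, max, or mean) when aggregating across data-parallel workers.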
- status = self.strategy.all_reduce(status) + min_metrics = {k: v for k, v in status.items() if k.endswith("_min")} + max_metrics = {k: v for k, v in status.items() if k.endswith("_max")} + mean_metrics = {k: v for k, v in status.items() if k not in min_metrics and k not in max_metrics} + + status_mean = self.strategy.all_reduce(mean_metrics, op="mean") + status_min = self.strategy.all_reduce(min_metrics, op="min") + status_max = self.strategy.all_reduce(max_metrics, op="max") + status_mean.update(status_min) + status_mean.update(status_max) + status = status_mean # weighted mean for kl # TODO (sumanthrh): this weighted mean is no longer correct since we use the max response length in the batch. @@ -817,7 +829,6 @@ def record_status(status: Dict[str, float]): torch.distributed.barrier() # not needed beyond status logging all_metrics.pop("response_length", None) - status_mean = reduce_metrics(all_metrics) status_mean["policy_update_steps"] = num_minibatches * self.cfg.trainer.update_epochs_per_batch diff --git a/skyrl-train/skyrl_train/workers/worker_utils.py b/skyrl-train/skyrl_train/workers/worker_utils.py index 897d032ea..28dbb5888 100644 --- a/skyrl-train/skyrl_train/workers/worker_utils.py +++ b/skyrl-train/skyrl_train/workers/worker_utils.py @@ -12,7 +12,12 @@ def reduce_metrics(metrics: Dict[str, List[float]]) -> Dict[str, float]: for k, v in metrics.items(): assert len(v) > 0, f"No metrics for key {k}" assert all(isinstance(x, (int, float)) for x in v), f"Metrics for key {k} are not all numbers" - reduced_metrics[k] = sum(v) / len(v) + if k.endswith("_max"): + reduced_metrics[k] = max(v) + elif k.endswith("_min"): + reduced_metrics[k] = min(v) + else: + reduced_metrics[k] = sum(v) / len(v) return reduced_metrics diff --git a/skyrl-train/tests/cpu/algorithms/test_losses.py b/skyrl-train/tests/cpu/algorithms/test_losses.py index b400a858f..ad52d74d8 100644 --- a/skyrl-train/tests/cpu/algorithms/test_losses.py +++ b/skyrl-train/tests/cpu/algorithms/test_losses.py @@ -680,7 +680,7 @@ def test_compute_tis_ratio_token_level(): torch.testing.assert_close(tis_ratio, expected, rtol=1e-3, atol=1e-4) # One token out of 3 was capped assert "tis_token_capped_frac" in metrics - assert abs(metrics["tis_token_capped_frac"] - 1/3) < 0.01 + assert abs(metrics["tis_token_capped_frac"] - 1 / 3) < 0.01 def test_compute_tis_ratio_sequence_level(): @@ -983,7 +983,9 @@ def test_apply_rollout_correction_null_configs(): } ) - corrected_loss, metrics = apply_rollout_correction(loss, old_log_probs, rollout_logprobs, loss_mask, config) + corrected_loss, metrics, loss_mask = apply_rollout_correction( + loss, old_log_probs, rollout_logprobs, loss_mask, config + ) # Should return the same tensor (early return) and empty metrics assert corrected_loss is loss @@ -1010,7 +1012,9 @@ def test_apply_rollout_correction_tis_only(): } ) - corrected_loss, metrics = apply_rollout_correction(loss, old_log_probs, rollout_logprobs, loss_mask, config) + corrected_loss, metrics, loss_mask = apply_rollout_correction( + loss, old_log_probs, rollout_logprobs, loss_mask, config + ) # Expected: loss * 1.6487 (no capping needed) expected = loss * torch.exp(torch.tensor(0.5)) @@ -1041,7 +1045,9 @@ def test_apply_rollout_correction_rejection_only(): } ) - corrected_loss, metrics = apply_rollout_correction(loss, old_log_probs, rollout_logprobs, loss_mask, config) + corrected_loss, metrics, loss_mask = apply_rollout_correction( + loss, old_log_probs, rollout_logprobs, loss_mask, config + ) # Geometric mean = 1.0, within bounds, so loss 
unchanged torch.testing.assert_close(corrected_loss, loss, rtol=1e-3, atol=1e-4) @@ -1073,7 +1079,9 @@ def test_apply_rollout_correction_both_enabled(): } ) - corrected_loss, metrics = apply_rollout_correction(loss, old_log_probs, rollout_logprobs, loss_mask, config) + corrected_loss, metrics, loss_mask = apply_rollout_correction( + loss, old_log_probs, rollout_logprobs, loss_mask, config + ) # TIS ratio ≈ 1.105, geometric mean ≈ 1.105 (within bounds, mask=1) # Expected: loss * 1.105 * 1.0 = loss * 1.105 @@ -1105,11 +1113,13 @@ def test_apply_rollout_correction_rejection_zeros_loss(): } ) - corrected_loss, metrics = apply_rollout_correction(loss, old_log_probs, rollout_logprobs, loss_mask, config) + corrected_loss, metrics, loss_mask = apply_rollout_correction( + loss, old_log_probs, rollout_logprobs, loss_mask, config + ) # Geometric mean ≈ 2.718, outside [0.9, 1.1], so loss should be zeroed expected = torch.tensor([[0.0, 0.0, 0.0]], device=device) - torch.testing.assert_close(corrected_loss, expected, rtol=1e-3, atol=1e-4) + torch.testing.assert_close(corrected_loss * loss_mask, expected, rtol=1e-3, atol=1e-4) # Check that the rejection metrics show rejection happened assert metrics["rejection_seq_masked_frac"] == 1.0 diff --git a/skyrl-train/tests/gpu/gpu_ci/test_ppo_train.py b/skyrl-train/tests/gpu/gpu_ci/test_ppo_train.py index c9880a017..ff5343320 100644 --- a/skyrl-train/tests/gpu/gpu_ci/test_ppo_train.py +++ b/skyrl-train/tests/gpu/gpu_ci/test_ppo_train.py @@ -29,7 +29,8 @@ def cfg() -> DictConfig: return cfg -@pytest.mark.parametrize("use_entropy_loss, use_kl_loss", [(False, False), (True, True), (True, False), (False, True)]) +# @pytest.mark.parametrize("use_entropy_loss, use_kl_loss", [(False, False), (True, True), (True, False), (False, True)]) +@pytest.mark.parametrize("use_entropy_loss, use_kl_loss", [(False, False)]) def test_ppo_train_basic_execution(ray_init_fixture, cfg, use_entropy_loss, use_kl_loss): """ Test that ppo_train runs and returns correct structure. 
@@ -48,6 +49,12 @@ def test_ppo_train_basic_execution(ray_init_fixture, cfg, use_entropy_loss, use_ cfg.trainer.algorithm.use_kl_loss = True cfg.trainer.algorithm.kl_loss_coef = 0.001 + cfg.trainer.algorithm.rollout_correction.tis_ratio_type = "sequence" + + cfg.trainer.algorithm.rollout_correction.rejection_mask_type = "geometric" + cfg.trainer.algorithm.rollout_correction.geo_rejection_mask_ratio_cap_high = 1.02 + cfg.trainer.algorithm.rollout_correction.geo_rejection_mask_ratio_cap_low = 0.98 + actor_group = init_worker_with_type( "policy", shared_pg=None, @@ -83,6 +90,8 @@ def test_ppo_train_basic_execution(ray_init_fixture, cfg, use_entropy_loss, use_ assert metric in train_status, f"Should have {metric} in train_status" assert isinstance(train_status[metric], (int, float)), f"{metric} should be numeric" + print(train_status) + # Simple check for metric values assert train_status["policy_update_steps"] > 0, "Should have completed at least one update step" assert train_status["policy_lr"] > 0, "Should have positive learning rate" diff --git a/skyrl-train/tests/gpu/utils.py b/skyrl-train/tests/gpu/utils.py index d819604f7..0c1c3097d 100644 --- a/skyrl-train/tests/gpu/utils.py +++ b/skyrl-train/tests/gpu/utils.py @@ -77,6 +77,7 @@ def make_dummy_training_batch(batch_size=2, seq_len=10, num_actions=4) -> Traini "advantages": 0.6 * torch.ones((batch_size, num_actions), device="cpu"), "loss_mask": torch.ones((batch_size, num_actions), dtype=int, device="cpu"), "response_mask": torch.ones((batch_size, num_actions), dtype=int, device="cpu"), + "rollout_logprobs": 0.2 * torch.ones((batch_size, num_actions), device="cpu"), } ) data.metadata = {"response_length": num_actions} From 349369db44c6c4059d87864772565f56261530da Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Thu, 8 Jan 2026 18:57:07 +0000 Subject: [PATCH 09/23] make default null and propagate megatron metrics --- .../skyrl_train/config/ppo_base_config.yaml | 4 +-- .../megatron/megatron_model_wrapper.py | 3 +++ .../workers/megatron/megatron_worker.py | 26 +++++++++++++------ 3 files changed, 23 insertions(+), 10 deletions(-) diff --git a/skyrl-train/skyrl_train/config/ppo_base_config.yaml b/skyrl-train/skyrl_train/config/ppo_base_config.yaml index 2234aca5c..26e018550 100644 --- a/skyrl-train/skyrl_train/config/ppo_base_config.yaml +++ b/skyrl-train/skyrl_train/config/ppo_base_config.yaml @@ -129,7 +129,7 @@ trainer: rollout_correction: # type of importance ratio to use for ppo loss correction # here importance ratio refers to logprobs_{policy_old} - logprobs_{rollout_policy} - tis_ratio_type: "null" # "null", "token", "sequence" + tis_ratio_type: null # null, "token", "sequence" # cap for the importance ratio # 1.5-5.0 is recommended for "token" tis_ratio_type @@ -139,7 +139,7 @@ trainer: # "sequence" mask masks out sequences with product of importance ratios outside the cap # "geometric" mask masks out sequences with geometric mean of importance ratios outside the cap - rejection_mask_type: "null" # "null", "sequence", "geometric" + rejection_mask_type: null # null, "sequence", "geometric" # cap for the rejection mask ratio # values around 0.99-1.01 are recommended for "geometric" rejection_mask_type - MoE models may need larger allowed ranges due to higher mismatch diff --git a/skyrl-train/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl-train/skyrl_train/workers/megatron/megatron_model_wrapper.py index 26ebf839a..76d7e284f 100644 --- a/skyrl-train/skyrl_train/workers/megatron/megatron_model_wrapper.py +++ 
b/skyrl-train/skyrl_train/workers/megatron/megatron_model_wrapper.py @@ -259,6 +259,9 @@ def loss_func(logits, data): "ppo_clip_ratio": loss_metrics["clip_ratio"], "policy_kl": kl_loss.detach().item(), } + for k, v in loss_metrics.items(): + if k != "clip_ratio": + metrics["loss_metrics/" + k] = v return loss, metrics def forward_step(batch_iter, model): diff --git a/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py b/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py index 080b9bb42..cc5e1e10d 100644 --- a/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py @@ -583,13 +583,11 @@ def ppo_train(self, train_data) -> "TrainingOutputBatch": # within a DP group, metrics are already the same across all workers - we then just all reduce across # the whole world size to get the metrics for the global micro batch for i, metrics in enumerate(metrics_list): - status = { - "final_loss": metrics["final_loss"], - "policy_loss": metrics["policy_loss"], - "policy_lr": self.optimizer.param_groups[0]["lr"], - "ppo_clip_ratio": metrics["ppo_clip_ratio"], - "policy_entropy": metrics["policy_entropy"], - } + status = {} + for k, v in metrics.items(): + status[k] = v + + status["policy_lr"] = self.optimizer.param_groups[0]["lr"] if self.cfg.trainer.algorithm.use_kl_loss: status["policy_kl"] = metrics["policy_kl"] @@ -600,7 +598,19 @@ def ppo_train(self, train_data) -> "TrainingOutputBatch": # attach response_length status["response_length"] = micro_buffer[i]["num_actions"] - status = self.strategy.all_reduce(status) + min_metrics = {k: v for k, v in status.items() if k.endswith("_min")} + max_metrics = {k: v for k, v in status.items() if k.endswith("_max")} + mean_metrics = { + k: v for k, v in status.items() if k not in min_metrics and k not in max_metrics + } + + status_mean = self.strategy.all_reduce(mean_metrics, op="mean") + status_min = self.strategy.all_reduce(min_metrics, op="min") + status_max = self.strategy.all_reduce(max_metrics, op="max") + status_mean.update(status_min) + status_mean.update(status_max) + status = status_mean + status_list.append(status) for k, v in status.items(): all_metrics[k].append(v) From f3f7054afee799908c0a882c9d3bfd8f86769656 Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Thu, 8 Jan 2026 19:06:44 +0000 Subject: [PATCH 10/23] x: --- .../skyrl_train/config/ppo_base_config.yaml | 3 +- skyrl-train/skyrl_train/utils/ppo_utils.py | 10 --- skyrl-train/skyrl_train/utils/utils.py | 75 ++++++++----------- 3 files changed, 33 insertions(+), 55 deletions(-) diff --git a/skyrl-train/skyrl_train/config/ppo_base_config.yaml b/skyrl-train/skyrl_train/config/ppo_base_config.yaml index 2234aca5c..a58cfbbdf 100644 --- a/skyrl-train/skyrl_train/config/ppo_base_config.yaml +++ b/skyrl-train/skyrl_train/config/ppo_base_config.yaml @@ -121,7 +121,8 @@ trainer: # dual clip parameters clip_ratio_c: 3.0 - # mark for deprecation + # To be deprecated in favor of rollout_correction.tis_ratio_type = "token" + # and "token_tis_ratio_cap_high" tis_imp_ratio_cap: -1.0 use_tis: false diff --git a/skyrl-train/skyrl_train/utils/ppo_utils.py b/skyrl-train/skyrl_train/utils/ppo_utils.py index 9d663c866..ed50480d9 100644 --- a/skyrl-train/skyrl_train/utils/ppo_utils.py +++ b/skyrl-train/skyrl_train/utils/ppo_utils.py @@ -824,16 +824,6 @@ def ppo_policy_loss( clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1) loss = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1) - # Legacy TIS support (deprecated) - 
if config.use_tis: - from loguru import logger as logger_ - - logger_.debug(f"Using TIS with dtype: {rollout_logprobs.dtype}") - # Apply truncated importance sampling -> https://fengyao.notion.site/off-policy-rl - tis_imp_ratio = _safe_exp_delta(old_log_probs - rollout_logprobs, clip=20.0, out_dtype=log_probs.dtype) - tis_imp_ratio = torch.clamp(tis_imp_ratio, max=config.tis_imp_ratio_cap) - loss = loss * tis_imp_ratio - loss_metrics = LossMetrics(clip_ratio=clip_ratio) # apply rollout correction diff --git a/skyrl-train/skyrl_train/utils/utils.py b/skyrl-train/skyrl_train/utils/utils.py index 47677ff1c..1e9b841be 100644 --- a/skyrl-train/skyrl_train/utils/utils.py +++ b/skyrl-train/skyrl_train/utils/utils.py @@ -284,64 +284,51 @@ def validate_cfg(cfg: DictConfig): # Legacy TIS validation (deprecated) if cfg.trainer.algorithm.use_tis: logger.warning( - "`trainer.algorithm.use_tis` is deprecated. Please use `trainer.algorithm.rollout_correction` instead." + f"`trainer.algorithm.use_tis` is deprecated. Setting `trainer.algorithm.rollout_correction` to `token` instead." + f"with `token_tis_ratio_cap_high`={cfg.trainer.algorithm.tis_imp_ratio_cap}" ) - if cfg.trainer.algorithm.tis_imp_ratio_cap <= 0: - raise ValueError( - f"If `trainer.algorithm.use_tis` is `True` then `cfg.trainer.algorithm.tis_imp_ratio_cap` " - f"should be > 0, got {cfg.trainer.algorithm.tis_imp_ratio_cap }" - ) - if cfg.generator.sampling_params.logprobs is None: - logger.warning( - "`generator.sampling_params.logprobs` is `None` but `trainer.algorithm.use_tis` is `True`." - " Setting `logprobs` to `True`." - ) - # just set to 0 for better user exp - cfg.generator.sampling_params.logprobs = 0 - - if cfg.generator.backend == "sglang": - raise NotImplementedError("`trainer.algorithm.use_tis` doesn't support Sglang backend, please use vLLM") + cfg.trainer.algorithm.rollout_correction.tis_ratio_type = "token" + cfg.trainer.algorithm.rollout_correction.token_tis_ratio_cap_high = cfg.trainer.algorithm.tis_imp_ratio_cap - # New rollout_correction validation - rollout_corr = cfg.trainer.algorithm.get("rollout_correction", None) - if rollout_corr is not None: - tis_ratio_type = rollout_corr.tis_ratio_type - rejection_mask_type = rollout_corr.rejection_mask_type + # rollout_correction config validation + rollout_corr = cfg.trainer.algorithm.rollout_correction + tis_ratio_type = rollout_corr.tis_ratio_type + rejection_mask_type = rollout_corr.rejection_mask_type - uses_rollout_correction = tis_ratio_type is not None or rejection_mask_type is not None + uses_rollout_correction = tis_ratio_type is not None or rejection_mask_type is not None - if uses_rollout_correction: - # Validate tis_ratio_type + if uses_rollout_correction: + # Validate tis_ratio_type + if tis_ratio_type: assert tis_ratio_type in [ - None, "token", "sequence", ], f"`tis_ratio_type` must be 'None', 'token', or 'sequence', got {tis_ratio_type}" - # Validate rejection_mask_type + # Validate rejection_mask_type + if rejection_mask_type: assert rejection_mask_type in [ - None, "sequence", "geometric", ], f"`rejection_mask_type` must be 'sequence', or 'geometric', got {rejection_mask_type}" - # Ensure logprobs are enabled for rollout correction - if cfg.generator.sampling_params.logprobs is None: - logger.warning( - "`generator.sampling_params.logprobs` is `None` but rollout_correction is enabled." - " Setting `logprobs` to `True`." 
- ) - cfg.generator.sampling_params.logprobs = 0 - - if cfg.generator.backend == "sglang": - raise NotImplementedError( - "`trainer.algorithm.rollout_correction` doesn't support Sglang backend, please use vLLM" - ) - - assert cfg.trainer.algorithm.policy_loss_type in [ - "regular", - "dual_clip", - ], "rollout_correction is only implemented for regular and dual_clip policy loss types" + # Ensure logprobs are enabled for rollout correction + if cfg.generator.sampling_params.logprobs is None: + logger.warning( + "`generator.sampling_params.logprobs` is `None` but rollout_correction is enabled." + " Setting `logprobs` to `True`." + ) + cfg.generator.sampling_params.logprobs = 0 + + if cfg.generator.backend == "sglang": + raise NotImplementedError( + "`trainer.algorithm.rollout_correction` doesn't support Sglang backend, please use vLLM" + ) + + assert cfg.trainer.algorithm.policy_loss_type in [ + "regular", + "dual_clip", + ], "rollout_correction is only implemented for regular and dual_clip policy loss types" if cfg.trainer.policy.model.lora.rank > 0: # LoRA enabled From 63d38c513e957036443693a0eb65d9559d581385 Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Thu, 8 Jan 2026 19:56:32 +0000 Subject: [PATCH 11/23] big cleanup - remove clip_ratio return (fix custom algorithms stuff), unite metrics under loss_metrics, other clean up --- .../docs/algorithms/custom_algorithms.rst | 4 ++-- skyrl-train/docs/configuration/config.rst | 11 +++++++++ .../main_custom_policy_loss.py | 4 ++-- .../run_dapo_gsm8k_flashrl_0.5b_fp8.sh | 5 +++- .../run_dapo_gsm8k_flashrl_0.5b_int8.sh | 6 ++++- .../run_dapo_gsm8k_flashrl_32b_int8.sh | 7 ++++-- .../run_dapo_repro_flashrl_0.5b_int8.sh | 6 +++-- .../examples/fully_async/async_run_gsm8k.sh | 2 +- .../main_on_policy_distill.py | 3 ++- .../skyrl_train/config/ppo_base_config.yaml | 23 +++++++++++-------- skyrl-train/skyrl_train/trainer.py | 6 ----- skyrl-train/skyrl_train/utils/utils.py | 5 ---- .../megatron/megatron_model_wrapper.py | 4 +--- skyrl-train/skyrl_train/workers/worker.py | 4 +--- .../tests/cpu/algorithms/test_losses.py | 12 ---------- skyrl-train/tests/cpu/test_trainer.py | 7 +++++- skyrl-train/tests/cpu/utils/test_ppo_utils.py | 23 ++++++++++--------- .../tests/gpu/gpu_ci/test_megatron_worker.py | 22 ++++++++++++++---- .../tests/gpu/gpu_ci/test_ppo_train.py | 2 +- .../tests/gpu/gpu_ci/test_training_step.py | 2 +- 20 files changed, 89 insertions(+), 69 deletions(-) diff --git a/skyrl-train/docs/algorithms/custom_algorithms.rst b/skyrl-train/docs/algorithms/custom_algorithms.rst index 09030ef79..f7495b2d0 100644 --- a/skyrl-train/docs/algorithms/custom_algorithms.rst +++ b/skyrl-train/docs/algorithms/custom_algorithms.rst @@ -48,14 +48,14 @@ Similarly, you can register custom policy loss functions: .. 
code-block:: python - from skyrl_train.utils.ppo_utils import register_policy_loss, PolicyLossRegistry + from skyrl_train.utils.ppo_utils import register_policy_loss, PolicyLossRegistry, LossMetrics @register_policy_loss("reinforce") def compute_reinforce_policy_loss(log_probs, old_log_probs, advantages, config, loss_mask=None, rollout_log_probs=None): # Your custom policy loss implementation (like REINFORCE) loss = (-log_probs * advantages).mean() # return loss and clip ratio - return loss, 0.0 + return loss, LossMetrics(clip_ratio=0.0) Registry Ray Distribution ~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/skyrl-train/docs/configuration/config.rst b/skyrl-train/docs/configuration/config.rst index c98eff625..0c8e7baf8 100644 --- a/skyrl-train/docs/configuration/config.rst +++ b/skyrl-train/docs/configuration/config.rst @@ -489,6 +489,17 @@ Algorithm Configuration - ``tau_pos``: Temperature for gating function for tokens with positive advantages. - ``tau_neg``: Temperature for gating function for tokens with negative (or zero) advantages. +Rollout Correction Configuration +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +- ``algorithm.rollout_correction``: Rollout correction configuration. + - ``algorithm.rollout_correction.tis_ratio_type``: Type of importance ratio to use for rollout correction. Options include: ``token``, ``sequence``, or ``null``. + - ``algorithm.rollout_correction.token_tis_ratio_cap_high``: Cap for the importance ratio for ``token`` tis_ratio_type. + - ``algorithm.rollout_correction.sequence_tis_ratio_cap_high``: Cap for the importance ratio for ``sequence`` tis_ratio_type. + - ``algorithm.rollout_correction.rejection_mask_type``: Type of rejection mask to use. Options include: ``sequence``, ``geometric``, or ``null``. + - ``algorithm.rollout_correction.geo_rejection_mask_ratio_cap_high``: Cap for the rejection mask ratio for ``geometric`` rejection_mask_type. + - ``algorithm.rollout_correction.geo_rejection_mask_ratio_cap_low``: Cap for the rejection mask ratio for ``geometric`` rejection_mask_type. + - ``algorithm.rollout_correction.sequence_rejection_mask_ratio_cap_high``: Cap for the rejection mask ratio for ``sequence`` rejection_mask_type. 
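For the ``geometric`` rejection mask described above, a sequence is kept only when the geometric mean of its token-level importance ratios exp(logprobs_old - logprobs_rollout) falls inside the configured low/high caps. A minimal sketch of that check, assuming padded tensors and a 0/1 loss mask (function and variable names are illustrative, not the project's exact API):

    import torch

    def geometric_rejection_keep_mask(old_log_probs, rollout_log_probs, loss_mask, cap_low=0.99, cap_high=1.01):
        # Geometric mean of per-token ratios == exp(mean of per-token log ratios).
        log_ratio = (old_log_probs - rollout_log_probs) * loss_mask
        num_tokens = loss_mask.sum(dim=-1).clamp(min=1)
        geo_mean = torch.exp(log_ratio.sum(dim=-1) / num_tokens)
        keep = (geo_mean >= cap_low) & (geo_mean <= cap_high)
        return keep.float().unsqueeze(-1)  # broadcastable over the token dimension

    # Example: a well-matched sequence is kept, a drifted one is rejected (masked to 0).
    old_lp = torch.tensor([[-1.00, -1.00], [-0.50, -0.50]])
    roll_lp = torch.tensor([[-1.00, -1.00], [-1.00, -1.00]])
    mask = torch.ones_like(old_lp)
    print(geometric_rejection_keep_mask(old_lp, roll_lp, mask))  # tensor([[1.], [0.]])
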
+ Policy Loss Formulation ~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/skyrl-train/examples/algorithms/custom_policy_loss/main_custom_policy_loss.py b/skyrl-train/examples/algorithms/custom_policy_loss/main_custom_policy_loss.py index 6b9be8b13..b4a172034 100644 --- a/skyrl-train/examples/algorithms/custom_policy_loss/main_custom_policy_loss.py +++ b/skyrl-train/examples/algorithms/custom_policy_loss/main_custom_policy_loss.py @@ -9,7 +9,7 @@ from omegaconf import DictConfig from skyrl_train.utils import initialize_ray from skyrl_train.entrypoints.main_base import BasePPOExp, config_dir, validate_cfg -from skyrl_train.utils.ppo_utils import PolicyLossRegistry +from skyrl_train.utils.ppo_utils import PolicyLossRegistry, LossMetrics # Example of custom policy loss: "reinforce" @@ -27,7 +27,7 @@ def compute_reinforce_policy_loss( loss = (-log_probs * advantages).mean() # Return loss and dummy clip_ratio (no clipping in REINFORCE) - return loss, 0.0 + return loss, LossMetrics(clip_ratio=0.0) # Register the custom policy loss diff --git a/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_fp8.sh b/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_fp8.sh index d0aff21ab..c9a3b500b 100644 --- a/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_fp8.sh +++ b/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_fp8.sh @@ -33,6 +33,8 @@ MAX_RESPONSE_LENGTH=1024 CKPT_PATH="$HOME/ckpts/gsm8k_0.5B_ckpt" +TIS_TYPE=token +TIS_IMP_RATIO_CAP=2.0 uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.fp8 --with vllm@https://github.com/NovaSky-AI/SkyRL/releases/download/skyrl_train-v0.1.0/vllm-0.1.dev7509+gcc487699a.d20250821-cp312-cp312-linux_x86_64.whl --with transformers==4.53.3 -- python -m examples.flash_rl.main_dapo_flashrl \ data.train_data="['$DATA_DIR/train.parquet']" \ @@ -53,7 +55,8 @@ uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.fp8 --with v generator.eval_sampling_params.top_p=$EVAL_TOP_P \ trainer.algorithm.use_kl_loss=$USE_KL_LOSS \ trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \ - trainer.algorithm.use_tis=true \ + trainer.algorithm.rollout_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.rollout_correction.token_tis_ratio_cap_high=$TIS_IMP_RATIO_CAP \ trainer.algorithm.tis_imp_ratio_cap=2.0 \ trainer.policy.model.path="Qwen/Qwen2.5-0.5B-Instruct" \ trainer.placement.colocate_all=true \ diff --git a/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_int8.sh b/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_int8.sh index f547788f5..f5d90cfc9 100644 --- a/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_int8.sh +++ b/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_int8.sh @@ -33,6 +33,9 @@ MAX_RESPONSE_LENGTH=1024 CKPT_PATH="$HOME/ckpts/gsm8k_0.5B_ckpt" +TIS_TYPE=token +TIS_IMP_RATIO_CAP=2.0 + uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.int8 --with vllm@https://github.com/NovaSky-AI/SkyRL/releases/download/skyrl_train-v0.1.0/vllm-0.1.dev7509+gcc487699a.d20250821-cp312-cp312-linux_x86_64.whl --with transformers==4.53.3 -- python -m examples.flash_rl.main_dapo_flashrl \ data.train_data="['$DATA_DIR/train.parquet']" \ data.val_data="['$DATA_DIR/validation.parquet']" \ @@ -52,7 +55,8 @@ uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.int8 --with trainer.algorithm.dynamic_sampling.max_sample_batches=$DYNAMIC_SAMPLING_MAX_SAMPLE_BATCHES \ trainer.algorithm.use_kl_loss=$USE_KL_LOSS \ trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \ - 
trainer.algorithm.use_tis=true \ + trainer.algorithm.rollout_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.rollout_correction.token_tis_ratio_cap_high=$TIS_IMP_RATIO_CAP \ trainer.algorithm.tis_imp_ratio_cap=2.0 \ trainer.policy.model.path="Qwen/Qwen2.5-0.5B-Instruct" \ trainer.placement.colocate_all=true \ diff --git a/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_32b_int8.sh b/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_32b_int8.sh index 7c6c7fb4a..ca9375b0e 100644 --- a/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_32b_int8.sh +++ b/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_32b_int8.sh @@ -33,6 +33,9 @@ MAX_RESPONSE_LENGTH=1024 CKPT_PATH="$HOME/ckpts/gsm8k_32B_ckpt" +TIS_TYPE=token +TIS_IMP_RATIO_CAP=2.0 + uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.int8 --with vllm@https://github.com/NovaSky-AI/SkyRL/releases/download/skyrl_train-v0.1.0/vllm-0.1.dev7509+gcc487699a.d20250821-cp312-cp312-linux_x86_64.whl --with transformers==4.53.3 -m examples.flash_rl.main_dapo_flashrl \ data.train_data="['$DATA_DIR/train.parquet']" \ data.val_data="['$DATA_DIR/validation.parquet']" \ @@ -52,8 +55,8 @@ uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.int8 --with generator.eval_sampling_params.top_p=$EVAL_TOP_P \ trainer.algorithm.use_kl_loss=$USE_KL_LOSS \ trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \ - trainer.algorithm.use_tis=true \ - trainer.algorithm.tis_imp_ratio_cap=2.0 \ + trainer.algorithm.rollout_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.rollout_correction.token_tis_ratio_cap_high=$TIS_IMP_RATIO_CAP \ trainer.policy.model.path="Qwen/Qwen2.5-32B" \ trainer.placement.colocate_all=true \ trainer.strategy=fsdp2 \ diff --git a/skyrl-train/examples/flash_rl/run_dapo_repro_flashrl_0.5b_int8.sh b/skyrl-train/examples/flash_rl/run_dapo_repro_flashrl_0.5b_int8.sh index 97ca62fde..a44845086 100644 --- a/skyrl-train/examples/flash_rl/run_dapo_repro_flashrl_0.5b_int8.sh +++ b/skyrl-train/examples/flash_rl/run_dapo_repro_flashrl_0.5b_int8.sh @@ -31,6 +31,8 @@ EVAL_TOP_P=0.7 CLIP_RATIO_C=10.0 MAX_RESPONSE_LENGTH=4096 +TIS_TYPE=token +TIS_IMP_RATIO_CAP=2.0 uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.0.5b_int8 --with vllm@https://github.com/NovaSky-AI/SkyRL/releases/download/skyrl_train-v0.1.0/vllm-0.1.dev7509+gcc487699a.d20250821-cp312-cp312-linux_x86_64.whl --with transformers==4.53.3 -- python -m examples.flash_rl.main_dapo_flashrl \ data.train_data="['$DATA_DIR/dapo-math-17k-cleaned.parquet']" \ data.val_data="['$DATA_DIR/aime-2024-cleaned.parquet']" \ @@ -50,8 +52,8 @@ uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.0.5b_int8 -- trainer.algorithm.dynamic_sampling.max_sample_batches=$DYNAMIC_SAMPLING_MAX_SAMPLE_BATCHES \ trainer.algorithm.use_kl_loss=$USE_KL_LOSS \ trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \ - trainer.algorithm.use_tis=true \ - trainer.algorithm.tis_imp_ratio_cap=2.0 \ + trainer.algorithm.rollout_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.rollout_correction.token_tis_ratio_cap_high=$TIS_IMP_RATIO_CAP \ trainer.policy.model.path="Qwen/Qwen2.5-0.5B-Instruct" \ trainer.placement.colocate_all=true \ trainer.strategy=fsdp2 \ diff --git a/skyrl-train/examples/fully_async/async_run_gsm8k.sh b/skyrl-train/examples/fully_async/async_run_gsm8k.sh index a26e4dd87..8cc7ff9f7 100644 --- a/skyrl-train/examples/fully_async/async_run_gsm8k.sh +++ b/skyrl-train/examples/fully_async/async_run_gsm8k.sh @@ -31,7 +31,7 @@ set -x TIS_TYPE=token 
TIS_IMP_RATIO_CAP=2.0 -RUN_NAME=gsm8k-async-qwen2.5_1.5B-useTIS_${USE_TIS}-maxStale${MAX_STALENESS_STEPS}-numCon${NUM_PARALLEL_GENERATION_WORKERS}-${NUM_POLICY_GPUS}train${NUM_INFERENCE_GPUS}gen +RUN_NAME=gsm8k-async-qwen2.5_1.5B-TIS_TYPE_${TIS_TYPE}-maxStale${MAX_STALENESS_STEPS}-numCon${NUM_PARALLEL_GENERATION_WORKERS}-${NUM_POLICY_GPUS}train${NUM_INFERENCE_GPUS}gen uv run --isolated --extra $INFERENCE_BACKEND -m examples.fully_async.main_async \ data.train_data="['$DATA_DIR/train.parquet']" \ diff --git a/skyrl-train/examples/on_policy_distillation/main_on_policy_distill.py b/skyrl-train/examples/on_policy_distillation/main_on_policy_distill.py index c7030cbcc..4afa470af 100644 --- a/skyrl-train/examples/on_policy_distillation/main_on_policy_distill.py +++ b/skyrl-train/examples/on_policy_distillation/main_on_policy_distill.py @@ -10,6 +10,7 @@ register_advantage_estimator, register_policy_loss, reduce_loss, + LossMetrics, ) from skyrl_train.training_batch import TrainingInputBatch @@ -53,7 +54,7 @@ def compute_importance_sampling_policy_loss( loss = reduce_loss(loss, loss_mask, "seq_mean_token_sum_norm", config.max_seq_len) # return loss and a dummy clip ratio value as we aren't clipping here - return loss, 0.0 + return loss, LossMetrics(clip_ratio=0.0) class OnPolicyDistillationExp(BasePPOExp): diff --git a/skyrl-train/skyrl_train/config/ppo_base_config.yaml b/skyrl-train/skyrl_train/config/ppo_base_config.yaml index e36266bf2..dfbd04eae 100644 --- a/skyrl-train/skyrl_train/config/ppo_base_config.yaml +++ b/skyrl-train/skyrl_train/config/ppo_base_config.yaml @@ -128,30 +128,33 @@ trainer: # reference: https://github.com/szrlee/verl/blob/yingru/rollout_correction/docs/advance/rollout_corr_math.md rollout_correction: - # type of importance ratio to use for ppo loss correction - # here importance ratio refers to logprobs_{policy_old} - logprobs_{rollout_policy} + # type of importance sampling ratio to use for ppo loss correction + # here importance sampling ratio refers to exp(logprobs_{policy_old} - logprobs_{rollout_policy}) tis_ratio_type: null # null, "token", "sequence" - # cap for the importance ratio - # 1.5-5.0 is recommended for "token" tis_ratio_type + # used if tis_ratio_type = "token", 1.5-5.0 is recommended for "token" tis_ratio_type token_tis_ratio_cap_high: 2.0 - # 2.0-10.0 is recommended for "sequence" tis_ratio_type + # used if tis_ratio_type = "sequence", 2.0-10.0 is recommended for "sequence" tis_ratio_type sequence_tis_ratio_cap_high: 5.0 + # method of masking out sequences with cumulative importance sampling ratios outside the cap # "sequence" mask masks out sequences with product of importance ratios outside the cap # "geometric" mask masks out sequences with geometric mean of importance ratios outside the cap rejection_mask_type: null # null, "sequence", "geometric" - # cap for the rejection mask ratio + # used if rejection_mask_type = "geometric" # values around 0.99-1.01 are recommended for "geometric" rejection_mask_type - MoE models may need larger allowed ranges due to higher mismatch - geo_rejection_mask_ratio_cap_high: 1.001 - geo_rejection_mask_ratio_cap_low: 0.999 + geo_rejection_mask_ratio_cap_high: 1.01 + geo_rejection_mask_ratio_cap_low: 0.99 - # sequence level rejection mask ratio cap + # used if rejection_mask_type = "sequence" + # values around 0.5-2.0 are recommended for "sequence" rejection_mask_type sequence_rejection_mask_ratio_cap_high: 2.0 sequence_rejection_mask_ratio_cap_low: 0.5 - # masks out sequences with any token having importance ration 
far outside an acceptable range + # separate from rejection_mask and tis_ratio_type + # if either is enabled, masks out sequences with any token having importance ratio + # far outside an acceptable range (low and high thresholds) outlier_token_is_threshold_low: 1e-4 outlier_token_is_threshold_high: 100 diff --git a/skyrl-train/skyrl_train/trainer.py b/skyrl-train/skyrl_train/trainer.py index 23745b1d0..d43fa8cb3 100644 --- a/skyrl-train/skyrl_train/trainer.py +++ b/skyrl-train/skyrl_train/trainer.py @@ -579,12 +579,6 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis loss_masks, logprobs, ) - # sanity check for tis (legacy) - if self.cfg.trainer.algorithm.use_tis: - assert ( - rollout_logprobs_tensor is not None - ), "expected non-null rollout logprobs tensor with `trainer.algorithm.use_tis` as `True`" - assert rollout_logprobs_tensor.shape == loss_masks_tensor.shape, "Logprobs should look like responses" # sanity check for rollout_correction rollout_corr = self.cfg.trainer.algorithm.rollout_correction diff --git a/skyrl-train/skyrl_train/utils/utils.py b/skyrl-train/skyrl_train/utils/utils.py index 1e9b841be..b4cbabba5 100644 --- a/skyrl-train/skyrl_train/utils/utils.py +++ b/skyrl-train/skyrl_train/utils/utils.py @@ -325,11 +325,6 @@ def validate_cfg(cfg: DictConfig): "`trainer.algorithm.rollout_correction` doesn't support Sglang backend, please use vLLM" ) - assert cfg.trainer.algorithm.policy_loss_type in [ - "regular", - "dual_clip", - ], "rollout_correction is only implemented for regular and dual_clip policy loss types" - if cfg.trainer.policy.model.lora.rank > 0: # LoRA enabled # Right now: assert generator backend must be vllm, training backend must be fsdp/fsdp2 diff --git a/skyrl-train/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl-train/skyrl_train/workers/megatron/megatron_model_wrapper.py index 76d7e284f..b3d14c65f 100644 --- a/skyrl-train/skyrl_train/workers/megatron/megatron_model_wrapper.py +++ b/skyrl-train/skyrl_train/workers/megatron/megatron_model_wrapper.py @@ -256,12 +256,10 @@ def loss_func(logits, data): "final_loss": loss.detach().item(), "policy_loss": policy_loss.detach().item(), "policy_entropy": entropy.detach().item(), - "ppo_clip_ratio": loss_metrics["clip_ratio"], "policy_kl": kl_loss.detach().item(), } for k, v in loss_metrics.items(): - if k != "clip_ratio": - metrics["loss_metrics/" + k] = v + metrics["loss_metrics/" + k] = v return loss, metrics def forward_step(batch_iter, model): diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py index 898fa8cc7..df73bdf33 100644 --- a/skyrl-train/skyrl_train/workers/worker.py +++ b/skyrl-train/skyrl_train/workers/worker.py @@ -704,13 +704,11 @@ def forward_backward(self, experience: Experience, microbatch_weight: float) -> status = { "final_loss": loss.item(), "policy_loss": policy_loss.item(), - "ppo_clip_ratio": loss_metrics["clip_ratio"], "policy_entropy": entropy.item(), "response_length": num_actions, } for k, v in loss_metrics.items(): - if k != "clip_ratio": # we separately name the clip ratio metric - status["loss_metrics/" + k] = v + status["loss_metrics/" + k] = v if self.cfg.trainer.algorithm.use_kl_loss: status["policy_kl"] = kl_loss.item() diff --git a/skyrl-train/tests/cpu/algorithms/test_losses.py b/skyrl-train/tests/cpu/algorithms/test_losses.py index ad52d74d8..6a0491ccb 100644 --- a/skyrl-train/tests/cpu/algorithms/test_losses.py +++ b/skyrl-train/tests/cpu/algorithms/test_losses.py @@ -47,7 +47,6 @@ 
def test_policy_loss_dual_clip(): "policy_loss_type": "dual_clip", "loss_reduction": "token_mean", "max_seq_len": 4, - "use_tis": False, "rollout_correction": NULL_ROLLOUT_CORR, } ) @@ -101,7 +100,6 @@ def test_policy_loss_cispo(): "policy_loss_type": "cispo", "loss_reduction": "token_mean", "max_seq_len": 4, - "use_tis": False, "rollout_correction": NULL_ROLLOUT_CORR, } ) @@ -180,7 +178,6 @@ def test_policy_loss_reduction_modes(): "policy_loss_type": "regular", "loss_reduction": "token_mean", "max_seq_len": 4, - "use_tis": False, "rollout_correction": NULL_ROLLOUT_CORR, } ) @@ -193,7 +190,6 @@ def test_policy_loss_reduction_modes(): "policy_loss_type": "regular", "loss_reduction": "sequence_mean", "max_seq_len": 4, - "use_tis": False, "rollout_correction": NULL_ROLLOUT_CORR, } ) @@ -267,7 +263,6 @@ def test_policy_loss_reduction_edge_cases(): "policy_loss_type": "regular", "loss_reduction": "token_mean", "max_seq_len": 4, - "use_tis": False, "rollout_correction": NULL_ROLLOUT_CORR, } ) @@ -280,7 +275,6 @@ def test_policy_loss_reduction_edge_cases(): "policy_loss_type": "regular", "loss_reduction": "sequence_mean", "max_seq_len": 4, - "use_tis": False, "rollout_correction": NULL_ROLLOUT_CORR, } ) @@ -367,7 +361,6 @@ def test_gspo_importance_sampling_levels(): "policy_loss_type": "regular", "loss_reduction": "token_mean", "max_seq_len": 4, - "use_tis": False, "rollout_correction": NULL_ROLLOUT_CORR, } ) @@ -383,7 +376,6 @@ def test_gspo_importance_sampling_levels(): "policy_loss_type": "gspo", "loss_reduction": "sequence_mean", # GSPO recommended reduction "max_seq_len": 4, - "use_tis": False, "rollout_correction": NULL_ROLLOUT_CORR, } ) @@ -514,7 +506,6 @@ def test_clip_cov_policy_loss(): "policy_loss_type": "regular", "loss_reduction": "token_mean", "max_seq_len": 4, - "use_tis": False, "rollout_correction": NULL_ROLLOUT_CORR, } ) @@ -578,7 +569,6 @@ def test_kl_cov_policy_loss(): "policy_loss_type": "regular", "loss_reduction": "token_mean", "max_seq_len": 4, - "use_tis": False, "rollout_correction": NULL_ROLLOUT_CORR, } ) @@ -1141,7 +1131,6 @@ def test_ppo_policy_loss_with_rollout_correction(): "policy_loss_type": "regular", "loss_reduction": "token_mean", "max_seq_len": 4, - "use_tis": False, "rollout_correction": { "tis_ratio_type": "token", "token_tis_ratio_cap_high": 2.0, @@ -1172,7 +1161,6 @@ def test_ppo_policy_loss_with_rollout_correction(): "policy_loss_type": "regular", "loss_reduction": "token_mean", "max_seq_len": 4, - "use_tis": False, "rollout_correction": { "tis_ratio_type": "null", "rejection_mask_type": "null", diff --git a/skyrl-train/tests/cpu/test_trainer.py b/skyrl-train/tests/cpu/test_trainer.py index 599f8a3f4..1d8924eb4 100644 --- a/skyrl-train/tests/cpu/test_trainer.py +++ b/skyrl-train/tests/cpu/test_trainer.py @@ -550,7 +550,12 @@ def create_test_worker(worker_class): def mock_policy_forward_backward(experience, microbatch_weight): policy_forward_backward_calls.append({"microbatch_weight": microbatch_weight}) - return {"policy_loss": 0.5, "ppo_clip_ratio": 0.1, "policy_entropy": 2.0, "response_length": response_length} + return { + "policy_loss": 0.5, + "loss_metrics/clip_ratio": 0.1, + "policy_entropy": 2.0, + "response_length": response_length, + } policy_worker.forward_backward = mock_policy_forward_backward policy_worker.optim_step = MagicMock(return_value=None) diff --git a/skyrl-train/tests/cpu/utils/test_ppo_utils.py b/skyrl-train/tests/cpu/utils/test_ppo_utils.py index fb69ce15e..f9744790c 100644 --- a/skyrl-train/tests/cpu/utils/test_ppo_utils.py +++ 
b/skyrl-train/tests/cpu/utils/test_ppo_utils.py @@ -17,6 +17,7 @@ AdvantageEstimatorRegistry, register_advantage_estimator, PolicyLossRegistry, + LossMetrics, register_policy_loss, compute_reinforce_plus_plus_outcome_advantage, compute_rloo_outcome_advantage, @@ -376,7 +377,7 @@ def test_policy_loss_registry_specific(): @register_policy_loss("test_policy_decorator") def decorated_policy_loss(log_probs, old_log_probs, advantages, config, loss_mask=None, rollout_log_probs=None): - return torch.tensor(1.5), 0.3 + return torch.tensor(1.5), LossMetrics(clip_ratio=0.3) # Test decorator worked assert "test_policy_decorator" in PolicyLossRegistry.list_available() @@ -385,14 +386,14 @@ def decorated_policy_loss(log_probs, old_log_probs, advantages, config, loss_mas # Test function execution config = DictConfig({"policy_loss_type": "test_policy_decorator"}) - loss, clip_ratio = retrieved( + loss, loss_metrics = retrieved( log_probs=torch.tensor([[0.1]]), old_log_probs=torch.tensor([[0.2]]), advantages=torch.tensor([[1.0]]), config=config, ) assert loss.item() == 1.5 - assert clip_ratio == 0.3 + assert loss_metrics["clip_ratio"] == 0.3 # Test error message includes "Policy loss" with pytest.raises(ValueError, match="Unknown policy loss"): @@ -413,10 +414,10 @@ def test_registry_cross_ray_process(): # Create test functions def test_policy_loss(log_probs, old_log_probs, advantages, config, loss_mask=None): - return torch.tensor(2.0), 0.5 + return torch.tensor(2.0), LossMetrics(clip_ratio=0.5) def test_policy_loss_2(log_probs, old_log_probs, advantages, config, loss_mask=None): - return torch.tensor(3.0), 0.6 + return torch.tensor(3.0), LossMetrics(clip_ratio=0.6) def test_advantage_estimator(**kwargs): rewards = kwargs["token_level_rewards"] @@ -432,7 +433,7 @@ def test_ray_registry_access(): policy_loss = PolicyLossRegistry.get("cross_process_test") adv_estimator = AdvantageEstimatorRegistry.get("cross_process_adv_test") - loss, clip_ratio = policy_loss( + loss, loss_metrics = policy_loss( log_probs=torch.tensor([[0.1]]), old_log_probs=torch.tensor([[0.2]]), advantages=torch.tensor([[1.0]]), @@ -444,25 +445,25 @@ def test_ray_registry_access(): response_mask=torch.tensor([[1.0, 1.0]]), index=np.array(["0", "0"]), ) - return loss, clip_ratio, adv, ret + return loss, loss_metrics, adv, ret # Run Ray task - loss, clip_ratio, adv, ret = ray.get(test_ray_registry_access.remote()) + loss, loss_metrics, adv, ret = ray.get(test_ray_registry_access.remote()) assert loss.item() == 2.0 - assert clip_ratio == 0.5 + assert loss_metrics["clip_ratio"] == 0.5 assert adv.shape == torch.Size([1, 2]) assert ret.shape == torch.Size([1, 2]) # test that registration works after ray init as well PolicyLossRegistry.register("cross_process_test_2", test_policy_loss_2) - loss_2, clip_ratio_2 = PolicyLossRegistry.get("cross_process_test_2")( + loss_2, loss_metrics_2 = PolicyLossRegistry.get("cross_process_test_2")( log_probs=torch.tensor([[0.1]]), old_log_probs=torch.tensor([[0.2]]), advantages=torch.tensor([[1.0]]), config=DictConfig({"policy_loss_type": "cross_process_test_2"}), ) assert loss_2.item() == 3.0 - assert clip_ratio_2 == 0.6 + assert loss_metrics_2["clip_ratio"] == 0.6 finally: PolicyLossRegistry.reset() AdvantageEstimatorRegistry.reset() diff --git a/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py b/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py index 592a0518b..7ecd9eb9b 100644 --- a/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py +++ b/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py @@ 
-501,7 +501,7 @@ async def test_megatron_train( assert isinstance(result, dict), "Result should be a dictionary of training stats" assert "policy_loss" in result assert "policy_lr" in result - assert "ppo_clip_ratio" in result + assert "loss_metrics/clip_ratio" in result assert "policy_entropy" in result for k, v in result.items(): assert isinstance(v, (int, float)), f"{k} should be an int or float" @@ -539,7 +539,14 @@ async def test_megatron_train( print("\n\n") print("fsdp results: ", results_fsdp[0]) - keys_to_compare = ["policy_loss", "policy_lr", "ppo_clip_ratio", "policy_entropy", "policy_kl", "final_loss"] + keys_to_compare = [ + "policy_loss", + "policy_lr", + "loss_metrics/clip_ratio", + "policy_entropy", + "policy_kl", + "final_loss", + ] for i, result in enumerate(results_fsdp): for k in keys_to_compare: if k == "policy_entropy": @@ -605,7 +612,7 @@ async def test_megatron_dp(ray_init_fixture, worker_type, tp, pp, gpus_per_node) assert isinstance(result, dict), "Result should be a dictionary of training stats" assert "policy_loss" in result assert "policy_lr" in result - assert "ppo_clip_ratio" in result + assert "loss_metrics/clip_ratio" in result assert "policy_entropy" in result for k, v in result.items(): assert isinstance(v, (int, float)), f"{k} should be an int or float" @@ -641,7 +648,14 @@ async def test_megatron_dp(ray_init_fixture, worker_type, tp, pp, gpus_per_node) print("\n\n") print("megatron results dp: ", results_megatron_dp) - keys_to_compare = ["policy_loss", "policy_lr", "ppo_clip_ratio", "policy_entropy", "policy_kl", "raw_grad_norm"] + keys_to_compare = [ + "policy_loss", + "policy_lr", + "loss_metrics/clip_ratio", + "policy_entropy", + "policy_kl", + "raw_grad_norm", + ] for i, result in enumerate(results_megatron_dp): for k in keys_to_compare: assert isinstance(result[k], (int, float)), f"{k} should be an int or float" diff --git a/skyrl-train/tests/gpu/gpu_ci/test_ppo_train.py b/skyrl-train/tests/gpu/gpu_ci/test_ppo_train.py index ff5343320..5131ae446 100644 --- a/skyrl-train/tests/gpu/gpu_ci/test_ppo_train.py +++ b/skyrl-train/tests/gpu/gpu_ci/test_ppo_train.py @@ -81,7 +81,7 @@ def test_ppo_train_basic_execution(ray_init_fixture, cfg, use_entropy_loss, use_ "policy_loss", "policy_update_steps", "policy_lr", - "ppo_clip_ratio", + "loss_metrics/clip_ratio", "policy_entropy", "final_loss", ] diff --git a/skyrl-train/tests/gpu/gpu_ci/test_training_step.py b/skyrl-train/tests/gpu/gpu_ci/test_training_step.py index e81103434..f72f87dfa 100644 --- a/skyrl-train/tests/gpu/gpu_ci/test_training_step.py +++ b/skyrl-train/tests/gpu/gpu_ci/test_training_step.py @@ -73,7 +73,7 @@ async def test_policy_forward_backward_and_optim_step(ray_init_fixture, cfg, pac for result in results: assert isinstance(result, dict), "Result should be a dictionary of training stats" assert "policy_loss" in result - assert "ppo_clip_ratio" in result + assert "loss_metrics/clip_ratio" in result assert "policy_entropy" in result for k, v in result.items(): assert isinstance(v, (int, float)), f"{k} should be an int or float" From 7e83c1002405df0ad9fa1a8426abf93b3cee3113 Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Thu, 8 Jan 2026 20:02:39 +0000 Subject: [PATCH 12/23] x --- skyrl-train/examples/gsm8k/run_gsm8k.sh | 11 ---------- .../examples/gsm8k/run_gsm8k_tis_geo_rs.sh | 0 .../skyrl_train/config/ppo_base_config.yaml | 4 +++- skyrl-train/skyrl_train/utils/ppo_utils.py | 22 +++++++++---------- 4 files changed, 14 insertions(+), 23 deletions(-) delete mode 100644 
skyrl-train/examples/gsm8k/run_gsm8k_tis_geo_rs.sh diff --git a/skyrl-train/examples/gsm8k/run_gsm8k.sh b/skyrl-train/examples/gsm8k/run_gsm8k.sh index 95c54710e..55328a706 100755 --- a/skyrl-train/examples/gsm8k/run_gsm8k.sh +++ b/skyrl-train/examples/gsm8k/run_gsm8k.sh @@ -17,21 +17,10 @@ set -x : "${INFERENCE_BACKEND:=vllm}" # : "${INFERENCE_BACKEND:=sglang}" - -# rollout correction parameters -TIS_RATIO_TYPE="sequence" -REJECTION_MASK_TYPE="geometric" -GEO_REJECTION_MASK_RATIO_CAP_HIGH=1.01 -GEO_REJECTION_MASK_RATIO_CAP_LOW=0.99 - uv run --isolated --extra $INFERENCE_BACKEND -m skyrl_train.entrypoints.main_base \ data.train_data="['$DATA_DIR/train.parquet']" \ data.val_data="['$DATA_DIR/validation.parquet']" \ trainer.algorithm.advantage_estimator="grpo" \ - trainer.algorithm.rollout_correction.tis_ratio_type=$TIS_RATIO_TYPE \ - trainer.algorithm.rollout_correction.rejection_mask_type=$REJECTION_MASK_TYPE \ - trainer.algorithm.rollout_correction.geo_rejection_mask_ratio_cap_high=$GEO_REJECTION_MASK_RATIO_CAP_HIGH \ - trainer.algorithm.rollout_correction.geo_rejection_mask_ratio_cap_low=$GEO_REJECTION_MASK_RATIO_CAP_LOW \ trainer.policy.model.path="Qwen/Qwen2.5-1.5B-Instruct" \ trainer.placement.colocate_all=true \ trainer.strategy=fsdp2 \ diff --git a/skyrl-train/examples/gsm8k/run_gsm8k_tis_geo_rs.sh b/skyrl-train/examples/gsm8k/run_gsm8k_tis_geo_rs.sh deleted file mode 100644 index e69de29bb..000000000 diff --git a/skyrl-train/skyrl_train/config/ppo_base_config.yaml b/skyrl-train/skyrl_train/config/ppo_base_config.yaml index dfbd04eae..376546d87 100644 --- a/skyrl-train/skyrl_train/config/ppo_base_config.yaml +++ b/skyrl-train/skyrl_train/config/ppo_base_config.yaml @@ -126,7 +126,9 @@ trainer: tis_imp_ratio_cap: -1.0 use_tis: false - # reference: https://github.com/szrlee/verl/blob/yingru/rollout_correction/docs/advance/rollout_corr_math.md + # references + # - https://github.com/szrlee/verl/blob/yingru/rollout_correction/docs/advance/rollout_corr_math.md + # - https://fengyao.notion.site/off-policy-rl rollout_correction: # type of importance sampling ratio to use for ppo loss correction # here importance sampling ratio refers to exp(logprobs_{policy_old} - logprobs_{rollout_policy}) diff --git a/skyrl-train/skyrl_train/utils/ppo_utils.py b/skyrl-train/skyrl_train/utils/ppo_utils.py index ed50480d9..6f6bfefac 100644 --- a/skyrl-train/skyrl_train/utils/ppo_utils.py +++ b/skyrl-train/skyrl_train/utils/ppo_utils.py @@ -589,7 +589,7 @@ def compute_tis_ratio( # Compute proportion of tokens capped tokens_capped = (token_tis_ratio > token_tis_ratio_cap) & (loss_mask > 0) total_tokens = (loss_mask > 0).sum() - metrics["tis_token_capped_frac"] = (tokens_capped.sum() / total_tokens.clamp(min=1)).detach().item() + metrics["tis_token_capped_ratio"] = (tokens_capped.sum() / total_tokens.clamp(min=1)).detach().item() return torch.clamp(token_tis_ratio, max=token_tis_ratio_cap), metrics elif tis_ratio_type == "sequence": # Compute sequence-level importance ratio as product of token ratios (sum of log ratios) @@ -599,7 +599,7 @@ def compute_tis_ratio( # Compute proportion of sequences capped num_sequences = seq_tis_ratio.shape[0] seqs_capped = (seq_tis_ratio > seq_tis_ratio_cap).sum() - metrics["tis_seq_capped_frac"] = (seqs_capped / num_sequences).detach().item() + metrics["tis_seq_capped_ratio"] = (seqs_capped / num_sequences).detach().item() return torch.clamp(seq_tis_ratio, max=seq_tis_ratio_cap), metrics else: raise ValueError(f"Unknown tis_ratio_type: {tis_ratio_type}") @@ -651,9 +651,9 @@ 
def compute_outlier_token_mask( # Sequence has any token under low threshold seq_has_under_low = token_under_low.any(dim=-1) - metrics["outlier_seq_masked_frac"] = ((~all_tokens_valid.squeeze(-1)).sum() / num_sequences).detach().item() - metrics["outlier_seq_over_high_frac"] = (seq_has_over_high.sum() / num_sequences).detach().item() - metrics["outlier_seq_under_low_frac"] = (seq_has_under_low.sum() / num_sequences).detach().item() + metrics["outlier_seq_masked_ratio"] = ((~all_tokens_valid.squeeze(-1)).sum() / num_sequences).detach().item() + metrics["outlier_seq_over_high_ratio"] = (seq_has_over_high.sum() / num_sequences).detach().item() + metrics["outlier_seq_under_low_ratio"] = (seq_has_under_low.sum() / num_sequences).detach().item() return all_tokens_valid.float(), metrics @@ -699,9 +699,9 @@ def compute_rejection_mask( geo_rejection_mask = ~seq_over_high & ~seq_under_low num_sequences = float(geo_mean_ratio.shape[0]) - metrics["rejection_seq_masked_frac"] = ((~geo_rejection_mask).sum() / num_sequences).detach().item() - metrics["rejection_seq_over_high_frac"] = (seq_over_high.sum() / num_sequences).detach().item() - metrics["rejection_seq_under_low_frac"] = (seq_under_low.sum() / num_sequences).detach().item() + metrics["rejection_seq_masked_ratio"] = ((~geo_rejection_mask).sum() / num_sequences).detach().item() + metrics["rejection_seq_over_high_ratio"] = (seq_over_high.sum() / num_sequences).detach().item() + metrics["rejection_seq_under_low_ratio"] = (seq_under_low.sum() / num_sequences).detach().item() return geo_rejection_mask.float(), metrics elif rejection_mask_type == "sequence": @@ -715,9 +715,9 @@ def compute_rejection_mask( seq_in_bounds = ~seq_over_high & ~seq_under_low num_sequences = float(seq_tis_ratio.shape[0]) - metrics["rejection_seq_masked_frac"] = ((~seq_in_bounds).sum() / num_sequences).detach().item() - metrics["rejection_seq_over_high_frac"] = (seq_over_high.sum() / num_sequences).detach().item() - metrics["rejection_seq_under_low_frac"] = (seq_under_low.sum() / num_sequences).detach().item() + metrics["rejection_seq_masked_ratio"] = ((~seq_in_bounds).sum() / num_sequences).detach().item() + metrics["rejection_seq_over_high_ratio"] = (seq_over_high.sum() / num_sequences).detach().item() + metrics["rejection_seq_under_low_ratio"] = (seq_under_low.sum() / num_sequences).detach().item() return seq_in_bounds.float(), metrics else: From cf042fcd42d9704a96807302d27fb7c3d6b46be9 Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Thu, 8 Jan 2026 23:17:44 +0000 Subject: [PATCH 13/23] renaming --- .../docs/algorithms/custom_algorithms.rst | 2 +- skyrl-train/docs/configuration/config.rst | 9 +- skyrl-train/docs/examples/flash_rl.rst | 4 +- .../run_dapo_gsm8k_flashrl_0.5b_fp8.sh | 4 +- .../run_dapo_gsm8k_flashrl_0.5b_int8.sh | 4 +- .../run_dapo_gsm8k_flashrl_32b_int8.sh | 4 +- .../run_dapo_repro_flashrl_0.5b_int8.sh | 4 +- .../run_dapo_repro_flashrl_32b_int8.sh | 4 +- .../examples/fully_async/async_run_gsm8k.sh | 4 +- skyrl-train/examples/gsm8k/run_gsm8k.sh | 2 +- .../run_megatron_dapo_qwen3_30b_a3b.sh | 4 +- .../run_megatron_dapo_qwen3_30b_a3b_lora.sh | 22 +- .../megatron/run_megatron_dapo_qwen3_4b.sh | 4 +- .../run_megatron_dapo_qwen3_4b_lora.sh | 4 +- .../run_megatron_lora_qwen3-30b-a3b.sh | 4 +- skyrl-train/examples/search/run_search.sh | 4 +- .../search/run_search_conversation_format.sh | 4 +- .../run_skyrl_sql_megatron_lora.sh | 4 +- .../examples/tis_correction/run_dapo_tis.sh | 4 +- .../skyrl_train/config/ppo_base_config.yaml | 36 +-- 
skyrl-train/skyrl_train/trainer.py | 12 +- skyrl-train/skyrl_train/utils/ppo_utils.py | 194 +++++++------ skyrl-train/skyrl_train/utils/utils.py | 32 +-- .../tests/cpu/algorithms/test_losses.py | 266 +++++++++--------- .../tests/gpu/gpu_ci/test_ppo_train.py | 8 +- 25 files changed, 323 insertions(+), 320 deletions(-) diff --git a/skyrl-train/docs/algorithms/custom_algorithms.rst b/skyrl-train/docs/algorithms/custom_algorithms.rst index f7495b2d0..174cf3964 100644 --- a/skyrl-train/docs/algorithms/custom_algorithms.rst +++ b/skyrl-train/docs/algorithms/custom_algorithms.rst @@ -54,7 +54,7 @@ Similarly, you can register custom policy loss functions: def compute_reinforce_policy_loss(log_probs, old_log_probs, advantages, config, loss_mask=None, rollout_log_probs=None): # Your custom policy loss implementation (like REINFORCE) loss = (-log_probs * advantages).mean() - # return loss and clip ratio + # return loss and loss metrics return loss, LossMetrics(clip_ratio=0.0) Registry Ray Distribution diff --git a/skyrl-train/docs/configuration/config.rst b/skyrl-train/docs/configuration/config.rst index 0c8e7baf8..c13c1f62e 100644 --- a/skyrl-train/docs/configuration/config.rst +++ b/skyrl-train/docs/configuration/config.rst @@ -491,14 +491,7 @@ Algorithm Configuration Rollout Correction Configuration ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -- ``algorithm.rollout_correction``: Rollout correction configuration. - - ``algorithm.rollout_correction.tis_ratio_type``: Type of importance ratio to use for rollout correction. Options include: ``token``, ``sequence``, or ``null``. - - ``algorithm.rollout_correction.token_tis_ratio_cap_high``: Cap for the importance ratio for ``token`` tis_ratio_type. - - ``algorithm.rollout_correction.sequence_tis_ratio_cap_high``: Cap for the importance ratio for ``sequence`` tis_ratio_type. - - ``algorithm.rollout_correction.rejection_mask_type``: Type of rejection mask to use. Options include: ``sequence``, ``geometric``, or ``null``. - - ``algorithm.rollout_correction.geo_rejection_mask_ratio_cap_high``: Cap for the rejection mask ratio for ``geometric`` rejection_mask_type. - - ``algorithm.rollout_correction.geo_rejection_mask_ratio_cap_low``: Cap for the rejection mask ratio for ``geometric`` rejection_mask_type. - - ``algorithm.rollout_correction.sequence_rejection_mask_ratio_cap_high``: Cap for the rejection mask ratio for ``sequence`` rejection_mask_type. +- ``algorithm.off_policy_correction``: Off policy correction configuration. Policy Loss Formulation ~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/skyrl-train/docs/examples/flash_rl.rst b/skyrl-train/docs/examples/flash_rl.rst index acec33b38..210859f5f 100644 --- a/skyrl-train/docs/examples/flash_rl.rst +++ b/skyrl-train/docs/examples/flash_rl.rst @@ -65,8 +65,8 @@ We highlight some important training parameters configured for FlashRL from our uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.int8 -m examples.flash_rl.main_dapo_flashrl \ ... - trainer.algorithm.rollout_correction.tis_ratio_type=$TIS_TYPE \ - trainer.algorithm.rollout_correction.token_tis_ratio_cap_high=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \ generator.sampling_params.logprobs=0 \ ... 
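The token-level truncated importance sampling weight selected via ``off_policy_correction.tis_ratio_type=token`` and ``token_tis_ratio_clip_high`` amounts to clamping the per-token ratio exp(logprobs_old - logprobs_rollout) from above before it re-weights the loss. A rough, self-contained sketch under those assumptions (not the project's exact function):

    import torch

    def token_tis_weight(old_log_probs, rollout_log_probs, clip_high=2.0):
        ratio = torch.exp(old_log_probs - rollout_log_probs)  # pi_old / pi_rollout per token
        return torch.clamp(ratio, max=clip_high)              # truncation bounds the variance

    old_lp = torch.tensor([[-0.2, -0.5]])
    roll_lp = torch.tensor([[-0.3, -1.8]])
    print(token_tis_weight(old_lp, roll_lp))  # tensor([[1.1052, 2.0000]]) -- second token is clipped at 2.0
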
diff --git a/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_fp8.sh b/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_fp8.sh index c9a3b500b..bd3844397 100644 --- a/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_fp8.sh +++ b/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_fp8.sh @@ -55,8 +55,8 @@ uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.fp8 --with v generator.eval_sampling_params.top_p=$EVAL_TOP_P \ trainer.algorithm.use_kl_loss=$USE_KL_LOSS \ trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \ - trainer.algorithm.rollout_correction.tis_ratio_type=$TIS_TYPE \ - trainer.algorithm.rollout_correction.token_tis_ratio_cap_high=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \ trainer.algorithm.tis_imp_ratio_cap=2.0 \ trainer.policy.model.path="Qwen/Qwen2.5-0.5B-Instruct" \ trainer.placement.colocate_all=true \ diff --git a/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_int8.sh b/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_int8.sh index f5d90cfc9..cdfa54ef9 100644 --- a/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_int8.sh +++ b/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_int8.sh @@ -55,8 +55,8 @@ uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.int8 --with trainer.algorithm.dynamic_sampling.max_sample_batches=$DYNAMIC_SAMPLING_MAX_SAMPLE_BATCHES \ trainer.algorithm.use_kl_loss=$USE_KL_LOSS \ trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \ - trainer.algorithm.rollout_correction.tis_ratio_type=$TIS_TYPE \ - trainer.algorithm.rollout_correction.token_tis_ratio_cap_high=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \ trainer.algorithm.tis_imp_ratio_cap=2.0 \ trainer.policy.model.path="Qwen/Qwen2.5-0.5B-Instruct" \ trainer.placement.colocate_all=true \ diff --git a/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_32b_int8.sh b/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_32b_int8.sh index ca9375b0e..383f87ba8 100644 --- a/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_32b_int8.sh +++ b/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_32b_int8.sh @@ -55,8 +55,8 @@ uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.int8 --with generator.eval_sampling_params.top_p=$EVAL_TOP_P \ trainer.algorithm.use_kl_loss=$USE_KL_LOSS \ trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \ - trainer.algorithm.rollout_correction.tis_ratio_type=$TIS_TYPE \ - trainer.algorithm.rollout_correction.token_tis_ratio_cap_high=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \ trainer.policy.model.path="Qwen/Qwen2.5-32B" \ trainer.placement.colocate_all=true \ trainer.strategy=fsdp2 \ diff --git a/skyrl-train/examples/flash_rl/run_dapo_repro_flashrl_0.5b_int8.sh b/skyrl-train/examples/flash_rl/run_dapo_repro_flashrl_0.5b_int8.sh index a44845086..70db7e13f 100644 --- a/skyrl-train/examples/flash_rl/run_dapo_repro_flashrl_0.5b_int8.sh +++ b/skyrl-train/examples/flash_rl/run_dapo_repro_flashrl_0.5b_int8.sh @@ -52,8 +52,8 @@ uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.0.5b_int8 -- 
trainer.algorithm.dynamic_sampling.max_sample_batches=$DYNAMIC_SAMPLING_MAX_SAMPLE_BATCHES \ trainer.algorithm.use_kl_loss=$USE_KL_LOSS \ trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \ - trainer.algorithm.rollout_correction.tis_ratio_type=$TIS_TYPE \ - trainer.algorithm.rollout_correction.token_tis_ratio_cap_high=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \ trainer.policy.model.path="Qwen/Qwen2.5-0.5B-Instruct" \ trainer.placement.colocate_all=true \ trainer.strategy=fsdp2 \ diff --git a/skyrl-train/examples/flash_rl/run_dapo_repro_flashrl_32b_int8.sh b/skyrl-train/examples/flash_rl/run_dapo_repro_flashrl_32b_int8.sh index 7b0122c6e..83842dfc4 100644 --- a/skyrl-train/examples/flash_rl/run_dapo_repro_flashrl_32b_int8.sh +++ b/skyrl-train/examples/flash_rl/run_dapo_repro_flashrl_32b_int8.sh @@ -57,8 +57,8 @@ uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.int8 --with generator.eval_sampling_params.top_p=$EVAL_TOP_P \ trainer.algorithm.use_kl_loss=$USE_KL_LOSS \ trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \ - trainer.algorithm.rollout_correction.tis_ratio_type=$TIS_TYPE \ - trainer.algorithm.rollout_correction.token_tis_ratio_cap_high=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \ trainer.policy.model.path="Qwen/Qwen2.5-32B" \ trainer.placement.colocate_all=true \ trainer.strategy=fsdp2 \ diff --git a/skyrl-train/examples/fully_async/async_run_gsm8k.sh b/skyrl-train/examples/fully_async/async_run_gsm8k.sh index 8cc7ff9f7..2b2412875 100644 --- a/skyrl-train/examples/fully_async/async_run_gsm8k.sh +++ b/skyrl-train/examples/fully_async/async_run_gsm8k.sh @@ -39,8 +39,8 @@ uv run --isolated --extra $INFERENCE_BACKEND -m examples.fully_async.main_async trainer.fully_async.max_staleness_steps=${MAX_STALENESS_STEPS} \ trainer.fully_async.num_parallel_generation_workers=${NUM_PARALLEL_GENERATION_WORKERS} \ trainer.algorithm.advantage_estimator="grpo" \ - trainer.algorithm.rollout_correction.tis_ratio_type=$TIS_TYPE \ - trainer.algorithm.rollout_correction.token_tis_ratio_cap_high=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \ trainer.policy.model.path="Qwen/Qwen2.5-1.5B-Instruct" \ trainer.placement.colocate_all=false \ trainer.strategy=fsdp2 \ diff --git a/skyrl-train/examples/gsm8k/run_gsm8k.sh b/skyrl-train/examples/gsm8k/run_gsm8k.sh index 55328a706..6b558da62 100755 --- a/skyrl-train/examples/gsm8k/run_gsm8k.sh +++ b/skyrl-train/examples/gsm8k/run_gsm8k.sh @@ -10,7 +10,7 @@ set -x # You can override the default values with e.g.: `NUM_GPUS=1 bash examples/gsm8k/run_gsm8k.sh`. 
-: "${DATA_DIR:="/mnt/cluster_storage/data/gsm8k"}" +: "${DATA_DIR:="$HOME/data/gsm8k"}" : "${NUM_GPUS:=4}" : "${LOGGER:=wandb}" # change to "console" to print to stdout diff --git a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_30b_a3b.sh b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_30b_a3b.sh index 584234160..3c10e4f18 100644 --- a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_30b_a3b.sh +++ b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_30b_a3b.sh @@ -83,8 +83,8 @@ uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \ trainer.policy.megatron_config.context_parallel_size=$MEGATRON_CP \ trainer.policy.megatron_config.expert_model_parallel_size=$MEGATRON_EP \ trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \ - trainer.algorithm.rollout_correction.tis_ratio_type=$TIS_TYPE \ - trainer.algorithm.rollout_correction.token_tis_ratio_cap_high=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \ trainer.epochs=20 \ trainer.algorithm.eps_clip_low=$CLIP_RATIO_LOW \ trainer.algorithm.eps_clip_high=$CLIP_RATIO_HIGH \ diff --git a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_30b_a3b_lora.sh b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_30b_a3b_lora.sh index 61514f9b4..035d5a88e 100644 --- a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_30b_a3b_lora.sh +++ b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_30b_a3b_lora.sh @@ -59,9 +59,9 @@ TIS_IMP_RATIO_CAP=3.0 # rollout correction parameters TIS_RATIO_TYPE="sequence" -REJECTION_MASK_TYPE="geometric" -GEO_REJECTION_MASK_RATIO_CAP_HIGH=1.05 -GEO_REJECTION_MASK_RATIO_CAP_LOW=0.95 +sequence_mask_metric="geometric" +geo_mask_high=1.05 +geo_mask_low=0.95 uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \ data.train_data="['$TRAIN_FILE']" \ @@ -93,11 +93,11 @@ uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \ trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \ trainer.policy.model.lora.rank=$LORA_RANK \ trainer.policy.model.lora.alpha=$LORA_ALPHA \ - trainer.algorithm.rollout_correction.tis_ratio_type=$TIS_RATIO_TYPE \ - trainer.algorithm.rollout_correction.sequence_tis_ratio_cap_high=$TIS_IMP_RATIO_CAP \ - trainer.algorithm.rollout_correction.rejection_mask_type=$REJECTION_MASK_TYPE \ - trainer.algorithm.rollout_correction.geo_rejection_mask_ratio_cap_high=$GEO_REJECTION_MASK_RATIO_CAP_HIGH \ - trainer.algorithm.rollout_correction.geo_rejection_mask_ratio_cap_low=$GEO_REJECTION_MASK_RATIO_CAP_LOW \ + trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_RATIO_TYPE \ + trainer.algorithm.off_policy_correction.sequence_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.off_policy_correction.sequence_mask_metric=$sequence_mask_metric \ + trainer.algorithm.off_policy_correction.geo_mask_high=$geo_mask_high \ + trainer.algorithm.off_policy_correction.geo_mask_low=$geo_mask_low \ trainer.epochs=20 \ trainer.algorithm.eps_clip_low=$CLIP_RATIO_LOW \ trainer.algorithm.eps_clip_high=$CLIP_RATIO_HIGH \ @@ -127,10 +127,10 @@ uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \ generator.gpu_memory_utilization=0.7 \ trainer.logger="$LOGGER" \ trainer.project_name="dapo_aime" \ - 
trainer.run_name="dapo_qwen3_30b_a3b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_ep${MEGATRON_EP}_etp${MEGATRON_ETP}_lora_rank${LORA_RANK}_alpha${LORA_ALPHA}_seq_is_${TIS_RATIO_TYPE}_${TIS_IMP_RATIO_CAP}_rej_${REJECTION_MASK_TYPE}_${GEO_REJECTION_MASK_RATIO_CAP_HIGH}_${GEO_REJECTION_MASK_RATIO_CAP_LOW}" \ - trainer.export_path="$HOME/exports/dapo_qwen3_30b_a3b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_ep${MEGATRON_EP}_etp${MEGATRON_ETP}_lora_rank${LORA_RANK}_alpha${LORA_ALPHA}_seq_is_${TIS_RATIO_TYPE}_${TIS_IMP_RATIO_CAP}_rej_${REJECTION_MASK_TYPE}_${GEO_REJECTION_MASK_RATIO_CAP_HIGH}_${GEO_REJECTION_MASK_RATIO_CAP_LOW}" \ + trainer.run_name="dapo_qwen3_30b_a3b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_ep${MEGATRON_EP}_etp${MEGATRON_ETP}_lora_rank${LORA_RANK}_alpha${LORA_ALPHA}_seq_is_${TIS_RATIO_TYPE}_${TIS_IMP_RATIO_CAP}_rej_${sequence_mask_metric}_${geo_mask_high}_${geo_mask_low}" \ + trainer.export_path="$HOME/exports/dapo_qwen3_30b_a3b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_ep${MEGATRON_EP}_etp${MEGATRON_ETP}_lora_rank${LORA_RANK}_alpha${LORA_ALPHA}_seq_is_${TIS_RATIO_TYPE}_${TIS_IMP_RATIO_CAP}_rej_${sequence_mask_metric}_${geo_mask_high}_${geo_mask_low}" \ trainer.hf_save_interval=300 \ trainer.resume_mode=latest \ trainer.max_ckpts_to_keep=3 \ - trainer.ckpt_path="$HOME/ckpts/dapo_qwen3_30b_a3b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_ep${MEGATRON_EP}_etp${MEGATRON_ETP}_lora_rank${LORA_RANK}_alpha${LORA_ALPHA}_seq_is_${TIS_RATIO_TYPE}_${TIS_IMP_RATIO_CAP}_rej_${REJECTION_MASK_TYPE}_${GEO_REJECTION_MASK_RATIO_CAP_HIGH}_${GEO_REJECTION_MASK_RATIO_CAP_LOW}" \ + trainer.ckpt_path="$HOME/ckpts/dapo_qwen3_30b_a3b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_ep${MEGATRON_EP}_etp${MEGATRON_ETP}_lora_rank${LORA_RANK}_alpha${LORA_ALPHA}_seq_is_${TIS_RATIO_TYPE}_${TIS_IMP_RATIO_CAP}_rej_${sequence_mask_metric}_${geo_mask_high}_${geo_mask_low}" \ $@ \ No newline at end of file diff --git a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_4b.sh b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_4b.sh index 8e9393445..86a6a7ec6 100644 --- a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_4b.sh +++ b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_4b.sh @@ -80,8 +80,8 @@ uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \ trainer.policy.megatron_config.context_parallel_size=$MEGATRON_CP \ trainer.policy.megatron_config.expert_model_parallel_size=$MEGATRON_EP \ trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \ - trainer.algorithm.rollout_correction.tis_ratio_type=$TIS_TYPE \ - trainer.algorithm.rollout_correction.token_tis_ratio_cap_high=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \ trainer.epochs=20 \ trainer.algorithm.eps_clip_low=$CLIP_RATIO_LOW \ trainer.algorithm.eps_clip_high=$CLIP_RATIO_HIGH \ diff --git a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_4b_lora.sh b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_4b_lora.sh index 9486a9d9f..801c69ab0 100644 --- a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_4b_lora.sh +++ b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_4b_lora.sh @@ -85,8 +85,8 @@ uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \ trainer.policy.megatron_config.context_parallel_size=$MEGATRON_CP 
\ trainer.policy.megatron_config.expert_model_parallel_size=$MEGATRON_EP \ trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \ - trainer.algorithm.rollout_correction.tis_ratio_type=$TIS_TYPE \ - trainer.algorithm.rollout_correction.token_tis_ratio_cap_high=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \ trainer.policy.model.lora.rank=$LORA_RANK \ trainer.policy.model.lora.alpha=$LORA_ALPHA \ trainer.policy.model.lora.init_method=$LORA_A_INIT_METHOD \ diff --git a/skyrl-train/examples/megatron/run_megatron_lora_qwen3-30b-a3b.sh b/skyrl-train/examples/megatron/run_megatron_lora_qwen3-30b-a3b.sh index b27f8f869..efe6eaf4c 100644 --- a/skyrl-train/examples/megatron/run_megatron_lora_qwen3-30b-a3b.sh +++ b/skyrl-train/examples/megatron/run_megatron_lora_qwen3-30b-a3b.sh @@ -63,8 +63,8 @@ uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_base \ trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \ trainer.ref.megatron_config.expert_model_parallel_size=$MEGATRON_EP \ trainer.ref.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \ - trainer.algorithm.rollout_correction.tis_ratio_type=$TIS_TYPE \ - trainer.algorithm.rollout_correction.token_tis_ratio_cap_high=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \ trainer.policy.model.lora.rank=$LORA_RANK \ trainer.policy.model.lora.alpha=$LORA_ALPHA \ trainer.policy.model.lora.init_method=$LORA_A_INIT_METHOD \ diff --git a/skyrl-train/examples/search/run_search.sh b/skyrl-train/examples/search/run_search.sh index 3203525b1..7166f6283 100755 --- a/skyrl-train/examples/search/run_search.sh +++ b/skyrl-train/examples/search/run_search.sh @@ -23,8 +23,8 @@ uv run --isolated --frozen --extra vllm -m skyrl_train.entrypoints.main_base \ trainer.policy.optimizer_config.num_warmup_steps=94 \ trainer.algorithm.use_kl_loss=true \ trainer.algorithm.kl_loss_coef=0.001 \ - trainer.algorithm.rollout_correction.tis_ratio_type=$TIS_TYPE \ - trainer.algorithm.rollout_correction.token_tis_ratio_cap_high=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \ trainer.policy.model.path="Qwen/Qwen2.5-3B-Instruct" \ trainer.placement.colocate_all=true \ trainer.strategy=fsdp2 \ diff --git a/skyrl-train/examples/search/run_search_conversation_format.sh b/skyrl-train/examples/search/run_search_conversation_format.sh index 3679edb1e..285bafd29 100755 --- a/skyrl-train/examples/search/run_search_conversation_format.sh +++ b/skyrl-train/examples/search/run_search_conversation_format.sh @@ -30,8 +30,8 @@ uv run --isolated --frozen --extra vllm -m skyrl_train.entrypoints.main_base \ trainer.policy.optimizer_config.num_warmup_steps=94 \ trainer.algorithm.use_kl_loss=true \ trainer.algorithm.kl_loss_coef=0.001 \ - trainer.algorithm.rollout_correction.tis_ratio_type=$TIS_TYPE \ - trainer.algorithm.rollout_correction.token_tis_ratio_cap_high=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \ trainer.policy.model.path="Qwen/Qwen2.5-3B-Instruct" \ trainer.placement.colocate_all=true \ trainer.strategy=fsdp2 \ diff --git 
a/skyrl-train/examples/text_to_sql/run_skyrl_sql_megatron_lora.sh b/skyrl-train/examples/text_to_sql/run_skyrl_sql_megatron_lora.sh index eaf4dd4fe..47b87d85c 100644 --- a/skyrl-train/examples/text_to_sql/run_skyrl_sql_megatron_lora.sh +++ b/skyrl-train/examples/text_to_sql/run_skyrl_sql_megatron_lora.sh @@ -63,8 +63,8 @@ uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_base \ trainer.policy.optimizer_config.lr=3.0e-5 \ trainer.policy_mini_batch_size=256 \ trainer.algorithm.use_kl_loss=false \ - trainer.algorithm.rollout_correction.tis_ratio_type=$TIS_TYPE \ - trainer.algorithm.rollout_correction.token_tis_ratio_cap_high=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \ trainer.ckpt_interval=60 \ trainer.hf_save_interval=30 \ trainer.dump_data_batch=true \ diff --git a/skyrl-train/examples/tis_correction/run_dapo_tis.sh b/skyrl-train/examples/tis_correction/run_dapo_tis.sh index 04bc4cd8a..674d0aaa3 100644 --- a/skyrl-train/examples/tis_correction/run_dapo_tis.sh +++ b/skyrl-train/examples/tis_correction/run_dapo_tis.sh @@ -55,8 +55,8 @@ uv run --isolated --extra vllm -m examples.tis_correction.main_tis_dapo \ generator.eval_sampling_params.top_p=$EVAL_TOP_P \ trainer.algorithm.use_kl_loss=$USE_KL_LOSS \ trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \ - trainer.algorithm.rollout_correction.tis_ratio_type=$TIS_TYPE \ - trainer.algorithm.rollout_correction.token_tis_ratio_cap_high=$TIS_IMP_RATIO_CAP \ + trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ + trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \ trainer.policy.model.path="Qwen/Qwen2.5-1.5B-Instruct" \ trainer.placement.colocate_all=true \ trainer.strategy=fsdp2 \ diff --git a/skyrl-train/skyrl_train/config/ppo_base_config.yaml b/skyrl-train/skyrl_train/config/ppo_base_config.yaml index 376546d87..e7ed8db9a 100644 --- a/skyrl-train/skyrl_train/config/ppo_base_config.yaml +++ b/skyrl-train/skyrl_train/config/ppo_base_config.yaml @@ -121,41 +121,41 @@ trainer: # dual clip parameters clip_ratio_c: 3.0 - # To be deprecated in favor of rollout_correction.tis_ratio_type = "token" - # and "token_tis_ratio_cap_high" + # To be deprecated in favor of off_policy_correction.tis_ratio_type = "token" + # and "token_tis_ratio_clip_high" tis_imp_ratio_cap: -1.0 use_tis: false # references # - https://github.com/szrlee/verl/blob/yingru/rollout_correction/docs/advance/rollout_corr_math.md # - https://fengyao.notion.site/off-policy-rl - rollout_correction: + off_policy_correction: # type of importance sampling ratio to use for ppo loss correction # here importance sampling ratio refers to exp(logprobs_{policy_old} - logprobs_{rollout_policy}) tis_ratio_type: null # null, "token", "sequence" # used if tis_ratio_type = "token", 1.5-5.0 is recommended for "token" tis_ratio_type - token_tis_ratio_cap_high: 2.0 + token_tis_ratio_clip_high: 2.0 # used if tis_ratio_type = "sequence", 2.0-10.0 is recommended for "sequence" tis_ratio_type - sequence_tis_ratio_cap_high: 5.0 + sequence_tis_ratio_clip_high: 5.0 # method of masking out sequences with cumulative importance sampling ratios outside the cap - # "sequence" mask masks out sequences with product of importance ratios outside the cap - # "geometric" mask masks out sequences with geometric mean of importance ratios outside the cap - rejection_mask_type: null # null, "sequence", "geometric" + # "product" masks out 
sequences with product of importance ratios outside the cap + # "geometric" masks out sequences with geometric mean of importance ratios outside the cap + sequence_mask_metric: null # null, "product", "geometric" - # used if rejection_mask_type = "geometric" - # values around 0.99-1.01 are recommended for "geometric" rejection_mask_type - MoE models may need larger allowed ranges due to higher mismatch - geo_rejection_mask_ratio_cap_high: 1.01 - geo_rejection_mask_ratio_cap_low: 0.99 + # used if sequence_mask_metric = "geometric" + # values around 0.99-1.01 are recommended for "geometric" sequence_mask_metric - MoE models may need larger allowed ranges due to higher mismatch + geo_mask_high: 1.01 + geo_mask_low: 0.99 - # used if rejection_mask_type = "sequence" - # values around 0.5-2.0 are recommended for "sequence" rejection_mask_type - sequence_rejection_mask_ratio_cap_high: 2.0 - sequence_rejection_mask_ratio_cap_low: 0.5 + # used if sequence_mask_metric = "product" + # values around 0.5-2.0 are recommended for "product" sequence_mask_metric + product_mask_high: 2.0 + product_mask_low: 0.5 - # separate from rejection_mask and tis_ratio_type - # if either is enabled, masks out sequences with any token having importance ratio + # separate from sequence_mask_metric and tis_ratio_type + # if any off_policy_correction is enabled, masks out sequences with any token having importance ratio # far outside an acceptable range (low and high thresholds) outlier_token_is_threshold_low: 1e-4 outlier_token_is_threshold_high: 100 diff --git a/skyrl-train/skyrl_train/trainer.py b/skyrl-train/skyrl_train/trainer.py index d43fa8cb3..9b27112c4 100644 --- a/skyrl-train/skyrl_train/trainer.py +++ b/skyrl-train/skyrl_train/trainer.py @@ -580,14 +580,14 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis logprobs, ) - # sanity check for rollout_correction - rollout_corr = self.cfg.trainer.algorithm.rollout_correction - tis_ratio_type = rollout_corr.tis_ratio_type - rejection_mask_type = rollout_corr.rejection_mask_type - if tis_ratio_type is not None or rejection_mask_type is not None: + # sanity check for off_policy_correction + off_policy_correction = self.cfg.trainer.algorithm.off_policy_correction + tis_ratio_type = off_policy_correction.tis_ratio_type + sequence_mask_metric = off_policy_correction.sequence_mask_metric + if tis_ratio_type is not None or sequence_mask_metric is not None: assert ( rollout_logprobs_tensor is not None - ), "expected non-null rollout logprobs tensor when rollout_correction is enabled" + ), "expected non-null rollout logprobs tensor when off_policy_correction is enabled" assert rollout_logprobs_tensor.shape == loss_masks_tensor.shape, "Logprobs should look like responses" training_input = TrainingInputBatch( diff --git a/skyrl-train/skyrl_train/utils/ppo_utils.py b/skyrl-train/skyrl_train/utils/ppo_utils.py index 6f6bfefac..e738c9fea 100644 --- a/skyrl-train/skyrl_train/utils/ppo_utils.py +++ b/skyrl-train/skyrl_train/utils/ppo_utils.py @@ -19,7 +19,7 @@ from collections import defaultdict from enum import StrEnum from functools import wraps -from typing import Callable, List, Literal, Optional, Tuple, Union, TypedDict +from typing import Callable, List, Literal, Optional, Tuple, Union, TypedDict, NotRequired import numpy as np import ray @@ -196,8 +196,8 @@ def ppo_critic_loss( return 0.5 * loss, clipfrac -class LossMetrics(TypedDict): - clip_ratio: float +class LossMetrics(TypedDict, total=False): + clip_ratio: NotRequired[float] # Shared
registry actor class for both policy loss and advantage estimator registries @@ -561,7 +561,7 @@ def compute_tis_ratio( rollout_logprobs: torch.Tensor, loss_mask: torch.Tensor, tis_ratio_type: str, - rollout_corr: DictConfig, + off_policy_correction: DictConfig, ) -> torch.Tensor: """ Compute truncated importance sampling (TIS) ratio for rollout correction. @@ -571,7 +571,7 @@ def compute_tis_ratio( rollout_logprobs: Log probabilities from the rollout policy. loss_mask: Mask indicating valid tokens. tis_ratio_type: Type of TIS ratio ("token" or "sequence"). - rollout_corr: Rollout correction config containing cap values. + off_policy_correction: Off-policy correction config containing cap values. Returns: TIS ratio tensor to multiply with the loss. @@ -585,21 +585,21 @@ def compute_tis_ratio( metrics = {} if tis_ratio_type == "token": - token_tis_ratio_cap = rollout_corr.token_tis_ratio_cap_high + token_tis_ratio_cap = off_policy_correction.token_tis_ratio_clip_high # Compute proportion of tokens capped tokens_capped = (token_tis_ratio > token_tis_ratio_cap) & (loss_mask > 0) total_tokens = (loss_mask > 0).sum() - metrics["tis_token_capped_ratio"] = (tokens_capped.sum() / total_tokens.clamp(min=1)).detach().item() + metrics["tis_token_clip_high_ratio"] = (tokens_capped.sum() / total_tokens.clamp(min=1)).detach().item() return torch.clamp(token_tis_ratio, max=token_tis_ratio_cap), metrics elif tis_ratio_type == "sequence": # Compute sequence-level importance ratio as product of token ratios (sum of log ratios) seq_tis_log_ratio = (token_tis_log_ratio * loss_mask).sum(dim=-1, keepdim=True) seq_tis_ratio = _safe_exp_delta(seq_tis_log_ratio, clip=20.0, out_dtype=old_log_probs.dtype) - seq_tis_ratio_cap = rollout_corr.sequence_tis_ratio_cap_high + seq_tis_ratio_cap = off_policy_correction.sequence_tis_ratio_clip_high # Compute proportion of sequences capped num_sequences = seq_tis_ratio.shape[0] seqs_capped = (seq_tis_ratio > seq_tis_ratio_cap).sum() - metrics["tis_seq_capped_ratio"] = (seqs_capped / num_sequences).detach().item() + metrics["tis_seq_clip_high_ratio"] = (seqs_capped / num_sequences).detach().item() return torch.clamp(seq_tis_ratio, max=seq_tis_ratio_cap), metrics else: raise ValueError(f"Unknown tis_ratio_type: {tis_ratio_type}") @@ -609,20 +609,20 @@ def compute_outlier_token_mask( old_log_probs: torch.Tensor, rollout_logprobs: torch.Tensor, loss_mask: torch.Tensor, - rollout_corr: DictConfig, + off_policy_correction: DictConfig, ) -> torch.Tensor: """ - Compute outlier token mask that rejects sequences with any token having + Compute outlier token mask that masks out sequences with any token having importance ratio outside acceptable bounds. - This is applied independently of TIS ratio type or rejection mask type, - whenever rollout correction is enabled. + This is applied independently of TIS ratio type or sequence mask type, + whenever off policy correction is enabled. Args: old_log_probs: Log probabilities from the old policy (before update). rollout_logprobs: Log probabilities from the rollout policy. loss_mask: Mask indicating valid tokens. - rollout_corr: Rollout correction config containing threshold values. + off_policy_correction: Off-policy correction config containing threshold values. 
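A hedged standalone sketch of this outlier filter (made-up tensors, default 1e-4 / 100 thresholds from the config; the real function also reports masking metrics): any single unmasked token with a ratio outside the bounds drops the whole sequence.

import torch

outlier_token_is_threshold_low, outlier_token_is_threshold_high = 1e-4, 100.0

old_log_probs = torch.tensor([[-1.0, -1.0, 4.0]])     # illustrative values only
rollout_logprobs = torch.tensor([[-1.5, -1.5, -1.5]])
loss_mask = torch.ones_like(old_log_probs)

token_ratio = torch.exp(old_log_probs - rollout_logprobs)  # last token: exp(5.5) ≈ 245
token_ok = (token_ratio > outlier_token_is_threshold_low) & (token_ratio < outlier_token_is_threshold_high)
# a sequence is kept only if every unmasked token stays within bounds
outlier_mask = (token_ok | (loss_mask == 0)).all(dim=-1, keepdim=True).float()
# -> tensor([[0.]]): one wildly mismatched token rejects the whole sequence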
Returns: Tuple of (outlier_mask, metrics): @@ -635,8 +635,8 @@ def compute_outlier_token_mask( token_tis_ratio = _safe_exp_delta(token_tis_log_ratio, clip=20.0, out_dtype=old_log_probs.dtype) # Check per-token bounds - token_mask_low = rollout_corr.outlier_token_is_threshold_low - token_mask_high = rollout_corr.outlier_token_is_threshold_high + token_mask_low = off_policy_correction.outlier_token_is_threshold_low + token_mask_high = off_policy_correction.outlier_token_is_threshold_high token_over_high = (token_tis_ratio > token_mask_high) & (loss_mask > 0) token_under_low = (token_tis_ratio < token_mask_low) & (loss_mask > 0) token_in_bounds = ~token_over_high & ~token_under_low @@ -658,15 +658,15 @@ def compute_outlier_token_mask( return all_tokens_valid.float(), metrics -def compute_rejection_mask( +def compute_sequence_mask( old_log_probs: torch.Tensor, rollout_logprobs: torch.Tensor, loss_mask: torch.Tensor, - rejection_mask_type: str, - rollout_corr: DictConfig, + sequence_mask_metric: str, + off_policy_correction: DictConfig, ) -> Tuple[torch.Tensor, dict]: """ - Compute rejection mask for rollout correction. + Compute sequence mask for off policy correction. This masks out sequences with importance ratios that fall outside acceptable bounds, helping to filter out off-policy samples that may destabilize training. @@ -675,66 +675,66 @@ def compute_rejection_mask( old_log_probs: Log probabilities from the old policy (before update). rollout_logprobs: Log probabilities from the rollout policy. loss_mask: Mask indicating valid tokens. - rejection_mask_type: Type of rejection mask ("geometric" or "sequence"). - rollout_corr: Rollout correction config containing cap values. + sequence_mask_metric: Metric to use for sequence masking ("geometric" or "product"). + off_policy_correction: Off-policy correction config containing cap values. 
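A similarly hedged sketch of the geometric variant (illustrative numbers; `geo_mask_low`/`geo_mask_high` are the defaults from `ppo_base_config.yaml`): the geometric mean of the token ratios must stay inside a narrow band around 1, otherwise the sequence is dropped.

import torch

geo_mask_low, geo_mask_high = 0.99, 1.01

old_log_probs = torch.tensor([[-1.0, -1.1, -0.9]])
rollout_logprobs = torch.tensor([[-1.05, -1.05, -1.05]])
loss_mask = torch.ones_like(old_log_probs)

log_ratio = old_log_probs - rollout_logprobs                     # [0.05, -0.05, 0.15]
num_tokens = loss_mask.sum(dim=-1, keepdim=True).clamp(min=1.0)
geo_mean = torch.exp((log_ratio * loss_mask).sum(dim=-1, keepdim=True) / num_tokens)
sequence_mask = ((geo_mean >= geo_mask_low) & (geo_mean <= geo_mask_high)).float()
# geo_mean ≈ 1.051 > 1.01, so sequence_mask -> tensor([[0.]]) and the sequence is rejected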
Returns: - Tuple of (rejection_mask, metrics): - - rejection_mask: Tensor (float) to multiply with the loss + Tuple of (sequence_mask, metrics): + - sequence_mask: Tensor (float) to multiply with the loss - metrics: Dict with masking statistics """ # Compute token-level importance ratio token_tis_log_ratio = old_log_probs - rollout_logprobs metrics = {} - if rejection_mask_type == "geometric": + if sequence_mask_metric == "geometric": # Compute geometric mean of importance ratios per sequence num_tokens = loss_mask.sum(dim=-1, keepdim=True).clamp(min=1.0) seq_tis_log_ratio = (token_tis_log_ratio * loss_mask).sum(dim=-1, keepdim=True) geo_mean_ratio = _safe_exp_delta(seq_tis_log_ratio / num_tokens, clip=20.0, out_dtype=old_log_probs.dtype) - geo_cap_high = rollout_corr.geo_rejection_mask_ratio_cap_high - geo_cap_low = rollout_corr.geo_rejection_mask_ratio_cap_low + geo_cap_high = off_policy_correction.geo_mask_high + geo_cap_low = off_policy_correction.geo_mask_low seq_over_high = geo_mean_ratio > geo_cap_high seq_under_low = geo_mean_ratio < geo_cap_low - geo_rejection_mask = ~seq_over_high & ~seq_under_low + geo_sequence_mask = ~seq_over_high & ~seq_under_low num_sequences = float(geo_mean_ratio.shape[0]) - metrics["rejection_seq_masked_ratio"] = ((~geo_rejection_mask).sum() / num_sequences).detach().item() - metrics["rejection_seq_over_high_ratio"] = (seq_over_high.sum() / num_sequences).detach().item() - metrics["rejection_seq_under_low_ratio"] = (seq_under_low.sum() / num_sequences).detach().item() + metrics["geo_sequence_mask_masked_ratio"] = ((~geo_sequence_mask).sum() / num_sequences).detach().item() + metrics["geo_sequence_mask_over_high_ratio"] = (seq_over_high.sum() / num_sequences).detach().item() + metrics["geo_sequence_mask_under_low_ratio"] = (seq_under_low.sum() / num_sequences).detach().item() - return geo_rejection_mask.float(), metrics - elif rejection_mask_type == "sequence": + return geo_sequence_mask.float(), metrics + elif sequence_mask_metric == "product": # Mask out sequences with product of importance ratios outside the cap seq_tis_log_ratio = (token_tis_log_ratio * loss_mask).sum(dim=-1, keepdim=True) seq_tis_ratio = _safe_exp_delta(seq_tis_log_ratio, clip=20.0, out_dtype=old_log_probs.dtype) - seq_cap_high = rollout_corr.sequence_rejection_mask_ratio_cap_high - seq_cap_low = rollout_corr.sequence_rejection_mask_ratio_cap_low + seq_cap_high = off_policy_correction.product_mask_high + seq_cap_low = off_policy_correction.product_mask_low seq_over_high = seq_tis_ratio > seq_cap_high seq_under_low = seq_tis_ratio < seq_cap_low seq_in_bounds = ~seq_over_high & ~seq_under_low num_sequences = float(seq_tis_ratio.shape[0]) - metrics["rejection_seq_masked_ratio"] = ((~seq_in_bounds).sum() / num_sequences).detach().item() - metrics["rejection_seq_over_high_ratio"] = (seq_over_high.sum() / num_sequences).detach().item() - metrics["rejection_seq_under_low_ratio"] = (seq_under_low.sum() / num_sequences).detach().item() + metrics["product_sequence_mask_masked_ratio"] = ((~seq_in_bounds).sum() / num_sequences).detach().item() + metrics["product_sequence_mask_over_high_ratio"] = (seq_over_high.sum() / num_sequences).detach().item() + metrics["product_sequence_mask_under_low_ratio"] = (seq_under_low.sum() / num_sequences).detach().item() return seq_in_bounds.float(), metrics else: - raise ValueError(f"Unknown rejection_mask_type: {rejection_mask_type}") + raise ValueError(f"Unknown sequence_mask_metric: {sequence_mask_metric}") -def apply_rollout_correction( +def 
apply_off_policy_correction( loss: torch.Tensor, old_log_probs: torch.Tensor, rollout_logprobs: torch.Tensor, loss_mask: torch.Tensor, - rollout_corr: DictConfig, + off_policy_correction: DictConfig, ) -> torch.Tensor: """ - Apply rollout correction to the loss using TIS ratio, rejection mask, and outlier token mask. + Apply off policy correction to the loss using TIS ratio, sequence mask, and outlier token mask. - This is a convenience function that combines compute_tis_ratio, compute_rejection_mask, + This is a convenience function that combines compute_tis_ratio, compute_sequence_mask, and compute_outlier_token_mask. Args: @@ -742,7 +742,7 @@ def apply_rollout_correction( old_log_probs: Log probabilities from the old policy (before update). rollout_logprobs: Log probabilities from the rollout policy. loss_mask: Mask indicating valid tokens. - rollout_corr: Rollout correction config. + off_policy_correction: Off-policy correction config. Returns: Corrected loss tensor. @@ -751,16 +751,16 @@ def apply_rollout_correction( - https://github.com/szrlee/verl/blob/yingru/rollout_correction/docs/advance/rollout_corr_math.md - https://fengyao.notion.site/off-policy-rl """ - tis_ratio_type = rollout_corr.tis_ratio_type - rejection_mask_type = rollout_corr.rejection_mask_type + tis_ratio_type = off_policy_correction.tis_ratio_type + sequence_mask_metric = off_policy_correction.sequence_mask_metric # Check if TIS ratio correction is enabled - apply_tis = tis_ratio_type is not None and tis_ratio_type != "null" - # Check if rejection mask is enabled - apply_rejection = rejection_mask_type is not None and rejection_mask_type != "null" + apply_tis = tis_ratio_type is not None + # Check if sequence mask is enabled + apply_sequence_mask = sequence_mask_metric is not None # Early return if no correction needed - if not apply_tis and not apply_rejection: + if not apply_tis and not apply_sequence_mask: return loss, {}, loss_mask is_ratio = _safe_exp_delta(old_log_probs - rollout_logprobs, clip=20.0, out_dtype=old_log_probs.dtype) @@ -770,27 +770,29 @@ def apply_rollout_correction( metrics["is_ratio_max"] = (is_ratio * loss_mask).max().detach().item() metrics["is_ratio_min"] = (is_ratio * loss_mask).min().detach().item() - # Apply outlier token mask whenever rollout correction is enabled + # Apply outlier token mask whenever off policy correction is enabled # This rejects sequences with any token having importance ratio outside acceptable bounds - outlier_mask, outlier_metrics = compute_outlier_token_mask(old_log_probs, rollout_logprobs, loss_mask, rollout_corr) + outlier_mask, outlier_metrics = compute_outlier_token_mask( + old_log_probs, rollout_logprobs, loss_mask, off_policy_correction + ) loss_mask = loss_mask * outlier_mask metrics.update(outlier_metrics) # Apply TIS ratio if enabled if apply_tis: tis_ratio, tis_metrics = compute_tis_ratio( - old_log_probs, rollout_logprobs, loss_mask, tis_ratio_type, rollout_corr + old_log_probs, rollout_logprobs, loss_mask, tis_ratio_type, off_policy_correction ) loss = loss * tis_ratio metrics.update(tis_metrics) - # Apply rejection mask if enabled - if apply_rejection: - rejection_mask, rejection_metrics = compute_rejection_mask( - old_log_probs, rollout_logprobs, loss_mask, rejection_mask_type, rollout_corr + # Apply sequence mask if enabled + if apply_sequence_mask: + sequence_mask, sequence_mask_metrics = compute_sequence_mask( + old_log_probs, rollout_logprobs, loss_mask, sequence_mask_metric, off_policy_correction ) - loss_mask = loss_mask * rejection_mask - 
metrics.update(rejection_metrics) + loss_mask = loss_mask * sequence_mask + metrics.update(sequence_mask_metrics) return loss, metrics, loss_mask @@ -827,12 +829,12 @@ def ppo_policy_loss( loss_metrics = LossMetrics(clip_ratio=clip_ratio) # apply rollout correction - rollout_corr = config.rollout_correction - if rollout_corr is not None and rollout_logprobs is not None: - loss, rollout_correction_metrics, loss_mask = apply_rollout_correction( - loss, old_log_probs, rollout_logprobs, loss_mask, rollout_corr + off_policy_correction = config.off_policy_correction + if rollout_logprobs is not None: + loss, off_policy_correction_metrics, loss_mask = apply_off_policy_correction( + loss, old_log_probs, rollout_logprobs, loss_mask, off_policy_correction ) - loss_metrics.update(rollout_correction_metrics) + loss_metrics.update(off_policy_correction_metrics) loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len) return loss, loss_metrics @@ -898,13 +900,13 @@ def gate_function(x, tau): loss = -gates * advantages # apply rollout correction - rollout_corr = config.rollout_correction + off_policy_correction = config.off_policy_correction loss_metrics = LossMetrics(clip_ratio=0.0) - if rollout_corr is not None and rollout_logprobs is not None: - loss, rollout_correction_metrics, loss_mask = apply_rollout_correction( - loss, old_log_probs, rollout_logprobs, loss_mask, rollout_corr + if rollout_logprobs is not None: + loss, off_policy_correction_metrics, loss_mask = apply_off_policy_correction( + loss, old_log_probs, rollout_logprobs, loss_mask, off_policy_correction ) - loss_metrics.update(rollout_correction_metrics) + loss_metrics.update(off_policy_correction_metrics) # for SAPO, we need to aggregate the loss at the sequence level (seq-mean-token-mean) loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len) @@ -963,19 +965,21 @@ def gspo_policy_loss( surr2 = ratio.clamp(1 - config.eps_clip_low, 1 + config.eps_clip_high) * advantages loss = -torch.min(surr1, surr2) - # apply rollout correction - rollout_corr = config.rollout_correction - if rollout_corr is not None and rollout_logprobs is not None: - loss, loss_metrics, loss_mask = apply_rollout_correction( - loss, old_log_probs, rollout_logprobs, loss_mask, rollout_corr - ) - # Compute clipping ratio for monitoring clip_ratio = masked_mean((-surr2 > -surr1).float(), loss_mask).mean().detach().item() + # apply rollout correction + loss_metrics = LossMetrics(clip_ratio=clip_ratio) + off_policy_correction = config.off_policy_correction + if rollout_logprobs is not None: + loss, off_policy_correction_metrics, loss_mask = apply_off_policy_correction( + loss, old_log_probs, rollout_logprobs, loss_mask, off_policy_correction + ) + loss_metrics.update(off_policy_correction_metrics) + loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len) - return loss, LossMetrics(clip_ratio=clip_ratio) + return loss, loss_metrics @register_policy_loss(PolicyLossType.CISPO) @@ -1004,14 +1008,16 @@ def compute_policy_loss_cispo( clip_ratio = masked_mean(is_clipped.float(), loss_mask).mean().detach().item() # apply rollout correction - rollout_corr = config.rollout_correction - if rollout_corr is not None and rollout_logprobs is not None: - loss, loss_metrics, loss_mask = apply_rollout_correction( - loss, old_log_probs, rollout_logprobs, loss_mask, rollout_corr + off_policy_correction = config.off_policy_correction + loss_metrics = LossMetrics(clip_ratio=clip_ratio) + if rollout_logprobs is not None: + loss, 
off_policy_correction_metrics, loss_mask = apply_off_policy_correction( + loss, old_log_probs, rollout_logprobs, loss_mask, off_policy_correction ) + loss_metrics.update(off_policy_correction_metrics) loss = reduce_loss(loss, loss_mask, config.loss_reduction, config.max_seq_len) - return loss, LossMetrics(clip_ratio=clip_ratio) + return loss, loss_metrics @register_policy_loss(PolicyLossType.CLIP_COV) @@ -1073,11 +1079,13 @@ def compute_policy_loss_clip_cov( pg_losses = torch.maximum(pg_losses1, pg_losses2) * corr # apply rollout correction - rollout_corr = config.rollout_correction - if rollout_corr is not None and rollout_logprobs is not None: - pg_losses, loss_metrics, loss_mask = apply_rollout_correction( - pg_losses, old_log_probs, rollout_logprobs, loss_mask, rollout_corr + off_policy_correction = config.off_policy_correction + loss_metrics = LossMetrics(clip_ratio=clip_frac.item()) + if rollout_logprobs is not None: + pg_losses, off_policy_correction_metrics, loss_mask = apply_off_policy_correction( + pg_losses, old_log_probs, rollout_logprobs, loss_mask, off_policy_correction ) + loss_metrics.update(off_policy_correction_metrics) pg_loss = reduce_loss( loss=pg_losses, @@ -1086,7 +1094,7 @@ def compute_policy_loss_clip_cov( max_seq_len=config.max_seq_len, ) - return pg_loss, LossMetrics(clip_frac=clip_frac.item()) + return pg_loss, loss_metrics @register_policy_loss(PolicyLossType.KL_COV) @@ -1139,11 +1147,13 @@ def compute_policy_loss_kl_cov( ] # apply rollout correction - rollout_corr = config.rollout_correction - if rollout_corr is not None and rollout_logprobs is not None: - pg_losses, loss_metrics, loss_mask = apply_rollout_correction( - pg_losses, old_log_probs, rollout_logprobs, loss_mask, rollout_corr + off_policy_correction = config.off_policy_correction + loss_metrics = LossMetrics(clip_ratio=0.0) + if rollout_logprobs is not None: + pg_losses, off_policy_correction_metrics, loss_mask = apply_off_policy_correction( + pg_losses, old_log_probs, rollout_logprobs, loss_mask, off_policy_correction ) + loss_metrics.update(off_policy_correction_metrics) pg_loss = reduce_loss( loss=pg_losses, @@ -1153,7 +1163,7 @@ def compute_policy_loss_kl_cov( ) # NOTE (sumanthrh): Since the pg clip ratio is not applicable for KL-COV so we just use 0.0 - return pg_loss, 0.0 + return pg_loss, loss_metrics def reduce_loss( diff --git a/skyrl-train/skyrl_train/utils/utils.py b/skyrl-train/skyrl_train/utils/utils.py index b4cbabba5..6ef3b99d8 100644 --- a/skyrl-train/skyrl_train/utils/utils.py +++ b/skyrl-train/skyrl_train/utils/utils.py @@ -284,20 +284,20 @@ def validate_cfg(cfg: DictConfig): # Legacy TIS validation (deprecated) if cfg.trainer.algorithm.use_tis: logger.warning( - f"`trainer.algorithm.use_tis` is deprecated. Setting `trainer.algorithm.rollout_correction` to `token` instead." - f"with `token_tis_ratio_cap_high`={cfg.trainer.algorithm.tis_imp_ratio_cap}" + f"`trainer.algorithm.use_tis` is deprecated. Setting `trainer.algorithm.off_policy_correction` to `token` instead." 
+ f" with `token_tis_ratio_clip_high`={cfg.trainer.algorithm.tis_imp_ratio_cap}" ) - cfg.trainer.algorithm.rollout_correction.tis_ratio_type = "token" - cfg.trainer.algorithm.rollout_correction.token_tis_ratio_cap_high = cfg.trainer.algorithm.tis_imp_ratio_cap + cfg.trainer.algorithm.off_policy_correction.tis_ratio_type = "token" + cfg.trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high = cfg.trainer.algorithm.tis_imp_ratio_cap - # rollout_correction config validation - rollout_corr = cfg.trainer.algorithm.rollout_correction - tis_ratio_type = rollout_corr.tis_ratio_type - rejection_mask_type = rollout_corr.rejection_mask_type + # off_policy_correction config validation + off_policy_correction = cfg.trainer.algorithm.off_policy_correction + tis_ratio_type = off_policy_correction.tis_ratio_type + sequence_mask_metric = off_policy_correction.sequence_mask_metric - uses_rollout_correction = tis_ratio_type is not None or rejection_mask_type is not None + uses_off_policy_correction = tis_ratio_type is not None or sequence_mask_metric is not None - if uses_rollout_correction: + if uses_off_policy_correction: # Validate tis_ratio_type if tis_ratio_type: assert tis_ratio_type in [ @@ -305,24 +305,24 @@ "sequence", ], f"`tis_ratio_type` must be 'None', 'token', or 'sequence', got {tis_ratio_type}" - # Validate rejection_mask_type - if rejection_mask_type: - assert rejection_mask_type in [ + # Validate sequence_mask_metric + if sequence_mask_metric: + assert sequence_mask_metric in [ - "sequence", + "product", "geometric", - ], f"`rejection_mask_type` must be 'sequence', or 'geometric', got {rejection_mask_type}" + ], f"`sequence_mask_metric` must be 'product' or 'geometric', got {sequence_mask_metric}" # Ensure logprobs are enabled for rollout correction if cfg.generator.sampling_params.logprobs is None: logger.warning( - "`generator.sampling_params.logprobs` is `None` but rollout_correction is enabled." + "`generator.sampling_params.logprobs` is `None` but off_policy_correction is enabled." " Setting `logprobs` to `True`."
) cfg.generator.sampling_params.logprobs = 0 if cfg.generator.backend == "sglang": raise NotImplementedError( - "`trainer.algorithm.rollout_correction` doesn't support Sglang backend, please use vLLM" + "`trainer.algorithm.off_policy_correction` doesn't support Sglang backend, please use vLLM" ) if cfg.trainer.policy.model.lora.rank > 0: diff --git a/skyrl-train/tests/cpu/algorithms/test_losses.py b/skyrl-train/tests/cpu/algorithms/test_losses.py index 6a0491ccb..1d891641d 100644 --- a/skyrl-train/tests/cpu/algorithms/test_losses.py +++ b/skyrl-train/tests/cpu/algorithms/test_losses.py @@ -12,14 +12,14 @@ PolicyLossRegistry, masked_mean, compute_tis_ratio, - compute_rejection_mask, + compute_sequence_mask, compute_outlier_token_mask, - apply_rollout_correction, + apply_off_policy_correction, ) -NULL_ROLLOUT_CORR = { - "tis_ratio_type": "null", - "rejection_mask_type": "null", +NULL_OFF_POLICY_CORR = { + "tis_ratio_type": None, + "sequence_mask_metric": None, "outlier_token_is_threshold_low": 1e-4, "outlier_token_is_threshold_high": 100.0, } @@ -47,7 +47,7 @@ def test_policy_loss_dual_clip(): "policy_loss_type": "dual_clip", "loss_reduction": "token_mean", "max_seq_len": 4, - "rollout_correction": NULL_ROLLOUT_CORR, + "off_policy_correction": NULL_OFF_POLICY_CORR, } ) @@ -100,7 +100,7 @@ def test_policy_loss_cispo(): "policy_loss_type": "cispo", "loss_reduction": "token_mean", "max_seq_len": 4, - "rollout_correction": NULL_ROLLOUT_CORR, + "off_policy_correction": NULL_OFF_POLICY_CORR, } ) @@ -178,7 +178,7 @@ def test_policy_loss_reduction_modes(): "policy_loss_type": "regular", "loss_reduction": "token_mean", "max_seq_len": 4, - "rollout_correction": NULL_ROLLOUT_CORR, + "off_policy_correction": NULL_OFF_POLICY_CORR, } ) @@ -190,7 +190,7 @@ def test_policy_loss_reduction_modes(): "policy_loss_type": "regular", "loss_reduction": "sequence_mean", "max_seq_len": 4, - "rollout_correction": NULL_ROLLOUT_CORR, + "off_policy_correction": NULL_OFF_POLICY_CORR, } ) @@ -263,7 +263,7 @@ def test_policy_loss_reduction_edge_cases(): "policy_loss_type": "regular", "loss_reduction": "token_mean", "max_seq_len": 4, - "rollout_correction": NULL_ROLLOUT_CORR, + "off_policy_correction": NULL_OFF_POLICY_CORR, } ) @@ -275,7 +275,7 @@ def test_policy_loss_reduction_edge_cases(): "policy_loss_type": "regular", "loss_reduction": "sequence_mean", "max_seq_len": 4, - "rollout_correction": NULL_ROLLOUT_CORR, + "off_policy_correction": NULL_OFF_POLICY_CORR, } ) @@ -361,7 +361,7 @@ def test_gspo_importance_sampling_levels(): "policy_loss_type": "regular", "loss_reduction": "token_mean", "max_seq_len": 4, - "rollout_correction": NULL_ROLLOUT_CORR, + "off_policy_correction": NULL_OFF_POLICY_CORR, } ) ppo_loss_fn = PolicyLossRegistry.get("regular") @@ -376,7 +376,7 @@ def test_gspo_importance_sampling_levels(): "policy_loss_type": "gspo", "loss_reduction": "sequence_mean", # GSPO recommended reduction "max_seq_len": 4, - "rollout_correction": NULL_ROLLOUT_CORR, + "off_policy_correction": NULL_OFF_POLICY_CORR, } ) gspo_loss_fn = PolicyLossRegistry.get("gspo") @@ -483,7 +483,7 @@ def test_clip_cov_policy_loss(): "loss_reduction": "token_mean", "max_seq_len": 4, "clip_cov": {"clip_ratio": 0.5, "clip_cov_lb": -5.0, "clip_cov_ub": 5.0}, # Large ratio for testing - "rollout_correction": NULL_ROLLOUT_CORR, + "off_policy_correction": NULL_OFF_POLICY_CORR, } ) @@ -492,11 +492,11 @@ def test_clip_cov_policy_loss(): # Calculate loss loss, loss_metrics = clip_cov_fn(log_probs, old_log_probs, advantages, config, loss_mask) - clip_frac 
= loss_metrics["clip_frac"] + clip_ratio = loss_metrics["clip_ratio"] # Basic sanity checks assert torch.isfinite(loss), "Loss should be finite" - assert 0 <= clip_frac <= 1, f"Clip fraction should be between 0 and 1, got {clip_frac}" + assert 0 <= clip_ratio <= 1, f"Clip ratio should be between 0 and 1, got {clip_ratio}" # Compare with regular PPO (should be different due to covariance correction) regular_config = DictConfig( @@ -506,7 +506,7 @@ def test_clip_cov_policy_loss(): "policy_loss_type": "regular", "loss_reduction": "token_mean", "max_seq_len": 4, - "rollout_correction": NULL_ROLLOUT_CORR, + "off_policy_correction": NULL_OFF_POLICY_CORR, } ) @@ -547,7 +547,7 @@ def test_kl_cov_policy_loss(): "loss_reduction": "token_mean", "max_seq_len": 4, "kl_cov": {"kl_cov_frac": 0.5, "ppo_kl_coef": 1.0}, # Apply KL to 50% of tokens - "rollout_correction": NULL_ROLLOUT_CORR, + "off_policy_correction": NULL_OFF_POLICY_CORR, } ) @@ -555,11 +555,11 @@ def test_kl_cov_policy_loss(): kl_cov_fn = PolicyLossRegistry.get("kl_cov") # Calculate loss - loss, clip_frac = kl_cov_fn(log_probs, old_log_probs, advantages, config, loss_mask) + loss, loss_metrics = kl_cov_fn(log_probs, old_log_probs, advantages, config, loss_mask) # Basic sanity checks assert torch.isfinite(loss), "Loss should be finite" - assert clip_frac == 0.0, "KL-Cov should return 0.0 for clipfrac value" + assert loss_metrics["clip_ratio"] == 0.0, "KL-Cov should return 0.0 for clip_ratio value" # Compare with regular PPO (should be different due to KL regularization) regular_config = DictConfig( @@ -569,7 +569,7 @@ def test_kl_cov_policy_loss(): "policy_loss_type": "regular", "loss_reduction": "token_mean", "max_seq_len": 4, - "rollout_correction": NULL_ROLLOUT_CORR, + "off_policy_correction": NULL_OFF_POLICY_CORR, } ) @@ -602,7 +602,7 @@ def test_sapo_policy_loss_basic(): "loss_reduction": "sequence_mean", "max_seq_len": 4, "sapo": {"tau_pos": 1.0, "tau_neg": 2.0}, - "rollout_correction": NULL_ROLLOUT_CORR, + "off_policy_correction": NULL_OFF_POLICY_CORR, } ) @@ -659,7 +659,7 @@ def test_compute_tis_ratio_token_level(): config = DictConfig( { "tis_ratio_type": "token", - "token_tis_ratio_cap_high": 2.0, + "token_tis_ratio_clip_high": 2.0, } ) @@ -669,8 +669,8 @@ def test_compute_tis_ratio_token_level(): expected = torch.tensor([[1.6487, 0.6065, 2.0]], device=device) torch.testing.assert_close(tis_ratio, expected, rtol=1e-3, atol=1e-4) # One token out of 3 was capped - assert "tis_token_capped_frac" in metrics - assert abs(metrics["tis_token_capped_frac"] - 1 / 3) < 0.01 + assert "tis_token_clip_high_ratio" in metrics + assert abs(metrics["tis_token_clip_high_ratio"] - 1 / 3) < 0.01 def test_compute_tis_ratio_sequence_level(): @@ -687,7 +687,7 @@ def test_compute_tis_ratio_sequence_level(): config = DictConfig( { "tis_ratio_type": "sequence", - "sequence_tis_ratio_cap_high": 5.0, + "sequence_tis_ratio_clip_high": 5.0, } ) @@ -697,8 +697,8 @@ def test_compute_tis_ratio_sequence_level(): expected = torch.tensor([[2.7183]], device=device) torch.testing.assert_close(tis_ratio, expected, rtol=1e-3, atol=1e-4) # No sequence was capped (2.7183 < 5.0) - assert "tis_seq_capped_frac" in metrics - assert metrics["tis_seq_capped_frac"] == 0.0 + assert "tis_seq_clip_high_ratio" in metrics + assert metrics["tis_seq_clip_high_ratio"] == 0.0 def test_compute_tis_ratio_sequence_level_with_cap(): @@ -715,7 +715,7 @@ def test_compute_tis_ratio_sequence_level_with_cap(): config = DictConfig( { "tis_ratio_type": "sequence", - "sequence_tis_ratio_cap_high": 
5.0, + "sequence_tis_ratio_clip_high": 5.0, } ) @@ -725,8 +725,8 @@ def test_compute_tis_ratio_sequence_level_with_cap(): expected = torch.tensor([[5.0]], device=device) torch.testing.assert_close(tis_ratio, expected, rtol=1e-3, atol=1e-4) # One sequence out of 1 was capped - assert "tis_seq_capped_frac" in metrics - assert metrics["tis_seq_capped_frac"] == 1.0 + assert "tis_seq_clip_high_ratio" in metrics + assert metrics["tis_seq_clip_high_ratio"] == 1.0 def test_compute_tis_ratio_with_mask(): @@ -742,7 +742,7 @@ def test_compute_tis_ratio_with_mask(): config = DictConfig( { "tis_ratio_type": "sequence", - "sequence_tis_ratio_cap_high": 10.0, + "sequence_tis_ratio_clip_high": 10.0, } ) @@ -753,12 +753,12 @@ def test_compute_tis_ratio_with_mask(): expected = expected_val.reshape(1, 1) torch.testing.assert_close(tis_ratio, expected, rtol=1e-3, atol=1e-4) # No sequence was capped (4.4817 < 10.0) - assert "tis_seq_capped_frac" in metrics - assert metrics["tis_seq_capped_frac"] == 0.0 + assert "tis_seq_clip_high_ratio" in metrics + assert metrics["tis_seq_clip_high_ratio"] == 0.0 -def test_compute_rejection_mask_geometric(): - """Tests geometric rejection mask computation.""" +def test_compute_sequence_mask_geometric(): + """Tests geometric sequence mask computation.""" device = "cpu" # Token log ratios: [0.1, -0.1, 0.0] -> sum = 0.0, geometric mean = exp(0/3) = 1.0 @@ -768,26 +768,26 @@ def test_compute_rejection_mask_geometric(): config = DictConfig( { - "rejection_mask_type": "geometric", - "geo_rejection_mask_ratio_cap_high": 1.1, - "geo_rejection_mask_ratio_cap_low": 0.9, + "sequence_mask_metric": "geometric", + "geo_mask_high": 1.1, + "geo_mask_low": 0.9, } ) - rejection_mask, metrics = compute_rejection_mask(old_log_probs, rollout_logprobs, loss_mask, "geometric", config) + sequence_mask, metrics = compute_sequence_mask(old_log_probs, rollout_logprobs, loss_mask, "geometric", config) # Geometric mean ≈ 1.0, which is within [0.9, 1.1], so mask should be 1.0 # Shape is [batch, 1] for sequence-level mask expected = torch.tensor([[1.0]], device=device) - torch.testing.assert_close(rejection_mask, expected, rtol=1e-3, atol=1e-4) + torch.testing.assert_close(sequence_mask, expected, rtol=1e-3, atol=1e-4) # No sequence was masked - assert metrics["rejection_seq_masked_frac"] == 0.0 - assert metrics["rejection_seq_over_high_frac"] == 0.0 - assert metrics["rejection_seq_under_low_frac"] == 0.0 + assert metrics["geo_sequence_mask_masked_ratio"] == 0.0 + assert metrics["geo_sequence_mask_over_high_ratio"] == 0.0 + assert metrics["geo_sequence_mask_under_low_ratio"] == 0.0 -def test_compute_rejection_mask_geometric_rejects(): - """Tests geometric rejection mask correctly rejects sequences outside bounds.""" +def test_compute_sequence_mask_geometric_rejects(): + """Tests geometric sequence mask correctly rejects sequences outside bounds.""" device = "cpu" # Token log ratios: [0.5, 0.5, 0.5] -> sum = 1.5, geometric mean = exp(1.5/3) = exp(0.5) ≈ 1.6487 @@ -797,26 +797,26 @@ def test_compute_rejection_mask_geometric_rejects(): config = DictConfig( { - "rejection_mask_type": "geometric", - "geo_rejection_mask_ratio_cap_high": 1.1, - "geo_rejection_mask_ratio_cap_low": 0.9, + "sequence_mask_metric": "geometric", + "geo_mask_high": 1.1, + "geo_mask_low": 0.9, } ) - rejection_mask, metrics = compute_rejection_mask(old_log_probs, rollout_logprobs, loss_mask, "geometric", config) + sequence_mask, metrics = compute_sequence_mask(old_log_probs, rollout_logprobs, loss_mask, "geometric", config) # Geometric 
mean ≈ 1.6487, which is outside [0.9, 1.1], so mask should be 0.0 # Shape is [batch, 1] for sequence-level mask expected = torch.tensor([[0.0]], device=device) - torch.testing.assert_close(rejection_mask, expected, rtol=1e-3, atol=1e-4) + torch.testing.assert_close(sequence_mask, expected, rtol=1e-3, atol=1e-4) # One sequence masked, over high cap - assert metrics["rejection_seq_masked_frac"] == 1.0 - assert metrics["rejection_seq_over_high_frac"] == 1.0 - assert metrics["rejection_seq_under_low_frac"] == 0.0 + assert metrics["geo_sequence_mask_masked_ratio"] == 1.0 + assert metrics["geo_sequence_mask_over_high_ratio"] == 1.0 + assert metrics["geo_sequence_mask_under_low_ratio"] == 0.0 -def test_compute_rejection_mask_sequence(): - """Tests sequence rejection mask computation.""" +def test_compute_sequence_mask_sequence(): + """Tests sequence sequence mask computation.""" device = "cpu" # Token log ratios: [0.2, 0.1, 0.0] -> sum = 0.3, seq ratio = exp(0.3) ≈ 1.35 @@ -826,26 +826,26 @@ def test_compute_rejection_mask_sequence(): config = DictConfig( { - "rejection_mask_type": "sequence", - "sequence_rejection_mask_ratio_cap_high": 2.0, - "sequence_rejection_mask_ratio_cap_low": 0.5, + "sequence_mask_metric": "product", + "product_mask_high": 2.0, + "product_mask_low": 0.5, } ) - rejection_mask, metrics = compute_rejection_mask(old_log_probs, rollout_logprobs, loss_mask, "sequence", config) + sequence_mask, metrics = compute_sequence_mask(old_log_probs, rollout_logprobs, loss_mask, "product", config) # Sequence ratio ≈ 1.35, which is within [0.5, 2.0] # Shape is [batch, 1] for sequence-level mask expected = torch.tensor([[1.0]], device=device) - torch.testing.assert_close(rejection_mask, expected, rtol=1e-3, atol=1e-4) + torch.testing.assert_close(sequence_mask, expected, rtol=1e-3, atol=1e-4) # No sequence was masked - assert metrics["rejection_seq_masked_frac"] == 0.0 - assert metrics["rejection_seq_over_high_frac"] == 0.0 - assert metrics["rejection_seq_under_low_frac"] == 0.0 + assert metrics["product_sequence_mask_masked_ratio"] == 0.0 + assert metrics["product_sequence_mask_over_high_ratio"] == 0.0 + assert metrics["product_sequence_mask_under_low_ratio"] == 0.0 -def test_compute_rejection_mask_sequence_rejects_by_seq_ratio(): - """Tests sequence rejection mask rejects when sequence ratio is out of bounds.""" +def test_compute_sequence_mask_sequence_rejects_by_seq_ratio(): + """Tests product sequence mask rejects when sequence ratio is out of bounds.""" device = "cpu" # Token log ratios: [1.0, 1.0, 1.0] -> sum = 3.0, seq ratio = exp(3.0) ≈ 20.09 @@ -855,25 +855,25 @@ def test_compute_rejection_mask_sequence_rejects_by_seq_ratio(): config = DictConfig( { - "rejection_mask_type": "sequence", - "sequence_rejection_mask_ratio_cap_high": 2.0, - "sequence_rejection_mask_ratio_cap_low": 0.5, + "sequence_mask_metric": "product", + "product_mask_high": 2.0, + "product_mask_low": 0.5, } ) - rejection_mask, metrics = compute_rejection_mask(old_log_probs, rollout_logprobs, loss_mask, "sequence", config) + sequence_mask, metrics = compute_sequence_mask(old_log_probs, rollout_logprobs, loss_mask, "product", config) # Sequence ratio ≈ 20.09, which is outside [0.5, 2.0], so mask should be 0.0 # Shape is [batch, 1] for sequence-level mask expected = torch.tensor([[0.0]], device=device) - torch.testing.assert_close(rejection_mask, expected, rtol=1e-3, atol=1e-4) + torch.testing.assert_close(sequence_mask, expected, rtol=1e-3, atol=1e-4) # One sequence masked, over high cap - assert 
metrics["rejection_seq_masked_frac"] == 1.0 - assert metrics["rejection_seq_over_high_frac"] == 1.0 - assert metrics["rejection_seq_under_low_frac"] == 0.0 + assert metrics["product_sequence_mask_masked_ratio"] == 1.0 + assert metrics["product_sequence_mask_over_high_ratio"] == 1.0 + assert metrics["product_sequence_mask_under_low_ratio"] == 0.0 -def test_compute_outlier_token_mask_rejects_by_token_bounds(): +def test_compute_outlier_token_mask_masks_by_token_bounds(): """Tests outlier token mask rejects when a token ratio is out of bounds.""" device = "cpu" @@ -886,7 +886,7 @@ def test_compute_outlier_token_mask_rejects_by_token_bounds(): config = DictConfig( { "outlier_token_is_threshold_low": 1e-4, - "outlier_token_is_threshold_high": 100.0, # This should cause rejection + "outlier_token_is_threshold_high": 100.0, # This should cause masking } ) @@ -897,9 +897,9 @@ def test_compute_outlier_token_mask_rejects_by_token_bounds(): expected = torch.tensor([[0.0]], device=device) torch.testing.assert_close(outlier_mask, expected, rtol=1e-3, atol=1e-4) # One sequence masked, has token over high threshold - assert metrics["outlier_seq_masked_frac"] == 1.0 - assert metrics["outlier_seq_over_high_frac"] == 1.0 - assert metrics["outlier_seq_under_low_frac"] == 0.0 + assert metrics["outlier_seq_masked_ratio"] == 1.0 + assert metrics["outlier_seq_over_high_ratio"] == 1.0 + assert metrics["outlier_seq_under_low_ratio"] == 0.0 def test_compute_outlier_token_mask_accepts_in_bounds(): @@ -926,9 +926,9 @@ def test_compute_outlier_token_mask_accepts_in_bounds(): expected = torch.tensor([[1.0]], device=device) torch.testing.assert_close(outlier_mask, expected, rtol=1e-3, atol=1e-4) # No sequence was masked - assert metrics["outlier_seq_masked_frac"] == 0.0 - assert metrics["outlier_seq_over_high_frac"] == 0.0 - assert metrics["outlier_seq_under_low_frac"] == 0.0 + assert metrics["outlier_seq_masked_ratio"] == 0.0 + assert metrics["outlier_seq_over_high_ratio"] == 0.0 + assert metrics["outlier_seq_under_low_ratio"] == 0.0 def test_compute_outlier_token_mask_respects_loss_mask(): @@ -954,11 +954,11 @@ def test_compute_outlier_token_mask_respects_loss_mask(): expected = torch.tensor([[1.0]], device=device) torch.testing.assert_close(outlier_mask, expected, rtol=1e-3, atol=1e-4) # No sequence was masked (the out-of-bounds token was in a masked position) - assert metrics["outlier_seq_masked_frac"] == 0.0 + assert metrics["outlier_seq_masked_ratio"] == 0.0 -def test_apply_rollout_correction_null_configs(): - """Tests that apply_rollout_correction returns loss unchanged when both configs are null.""" +def test_apply_off_policy_correction_null_configs(): + """Tests that apply_off_policy_correction returns loss unchanged when both configs are null.""" device = "cpu" loss = torch.tensor([[1.0, 2.0, 3.0]], device=device) @@ -968,12 +968,12 @@ def test_apply_rollout_correction_null_configs(): config = DictConfig( { - "tis_ratio_type": "null", - "rejection_mask_type": "null", + "tis_ratio_type": None, + "sequence_mask_metric": None, } ) - corrected_loss, metrics, loss_mask = apply_rollout_correction( + corrected_loss, metrics, loss_mask = apply_off_policy_correction( loss, old_log_probs, rollout_logprobs, loss_mask, config ) @@ -982,8 +982,8 @@ def test_apply_rollout_correction_null_configs(): assert metrics == {} -def test_apply_rollout_correction_tis_only(): - """Tests apply_rollout_correction with only TIS enabled.""" +def test_apply_off_policy_correction_tis_only(): + """Tests apply_off_policy_correction with only 
TIS enabled.""" device = "cpu" loss = torch.tensor([[1.0, 1.0, 1.0]], device=device) @@ -995,14 +995,14 @@ def test_apply_rollout_correction_tis_only(): config = DictConfig( { "tis_ratio_type": "token", - "token_tis_ratio_cap_high": 2.0, - "rejection_mask_type": "null", + "token_tis_ratio_clip_high": 2.0, + "sequence_mask_metric": None, "outlier_token_is_threshold_low": 1e-4, "outlier_token_is_threshold_high": 100.0, } ) - corrected_loss, metrics, loss_mask = apply_rollout_correction( + corrected_loss, metrics, loss_mask = apply_off_policy_correction( loss, old_log_probs, rollout_logprobs, loss_mask, config ) @@ -1011,11 +1011,11 @@ def test_apply_rollout_correction_tis_only(): torch.testing.assert_close(corrected_loss, expected, rtol=1e-3, atol=1e-4) # Check metrics are populated assert "is_ratio_mean" in metrics - assert "tis_token_capped_frac" in metrics + assert "tis_token_clip_high_ratio" in metrics -def test_apply_rollout_correction_rejection_only(): - """Tests apply_rollout_correction with only rejection mask enabled.""" +def test_apply_off_policy_correction_sequence_mask_only(): + """Tests apply_off_policy_correction with only geometric sequence mask enabled.""" device = "cpu" loss = torch.tensor([[1.0, 2.0, 3.0]], device=device) @@ -1026,16 +1026,16 @@ def test_apply_rollout_correction_rejection_only(): config = DictConfig( { - "tis_ratio_type": "null", - "rejection_mask_type": "geometric", - "geo_rejection_mask_ratio_cap_high": 1.1, - "geo_rejection_mask_ratio_cap_low": 0.9, + "tis_ratio_type": None, + "sequence_mask_metric": "geometric", + "geo_mask_high": 1.1, + "geo_mask_low": 0.9, "outlier_token_is_threshold_low": 1e-4, "outlier_token_is_threshold_high": 100.0, } ) - corrected_loss, metrics, loss_mask = apply_rollout_correction( + corrected_loss, metrics, loss_mask = apply_off_policy_correction( loss, old_log_probs, rollout_logprobs, loss_mask, config ) @@ -1043,11 +1043,11 @@ def test_apply_rollout_correction_rejection_only(): torch.testing.assert_close(corrected_loss, loss, rtol=1e-3, atol=1e-4) # Check metrics are populated assert "is_ratio_mean" in metrics - assert "rejection_seq_masked_frac" in metrics + assert "geo_sequence_mask_masked_ratio" in metrics -def test_apply_rollout_correction_both_enabled(): - """Tests apply_rollout_correction with both TIS and rejection mask enabled.""" +def test_apply_off_policy_correction_both_enabled(): + """Tests apply_off_policy_correction with both TIS and geometric sequence mask enabled.""" device = "cpu" loss = torch.tensor([[1.0, 1.0, 1.0]], device=device) @@ -1060,16 +1060,16 @@ def test_apply_rollout_correction_both_enabled(): config = DictConfig( { "tis_ratio_type": "token", - "token_tis_ratio_cap_high": 2.0, - "rejection_mask_type": "geometric", - "geo_rejection_mask_ratio_cap_high": 1.2, - "geo_rejection_mask_ratio_cap_low": 0.8, + "token_tis_ratio_clip_high": 2.0, + "sequence_mask_metric": "geometric", + "geo_mask_high": 1.2, + "geo_mask_low": 0.8, "outlier_token_is_threshold_low": 1e-4, "outlier_token_is_threshold_high": 100.0, } ) - corrected_loss, metrics, loss_mask = apply_rollout_correction( + corrected_loss, metrics, loss_mask = apply_off_policy_correction( loss, old_log_probs, rollout_logprobs, loss_mask, config ) @@ -1077,13 +1077,13 @@ def test_apply_rollout_correction_both_enabled(): # Expected: loss * 1.105 * 1.0 = loss * 1.105 expected = loss * torch.exp(torch.tensor(0.1)) torch.testing.assert_close(corrected_loss, expected, rtol=1e-3, atol=1e-4) - # Check metrics from both TIS and rejection are populated - assert 
"tis_token_capped_frac" in metrics - assert "rejection_seq_masked_frac" in metrics + # Check metrics from both TIS and sequence mask are populated + assert "tis_token_clip_high_ratio" in metrics + assert "geo_sequence_mask_masked_ratio" in metrics -def test_apply_rollout_correction_rejection_zeros_loss(): - """Tests that rejection mask can zero out the loss entirely.""" +def test_apply_off_policy_correction_sequence_mask_zeros_loss(): + """Tests that sequence mask can zero out the loss entirely.""" device = "cpu" loss = torch.tensor([[1.0, 2.0, 3.0]], device=device) @@ -1094,27 +1094,27 @@ def test_apply_rollout_correction_rejection_zeros_loss(): config = DictConfig( { - "tis_ratio_type": "null", - "rejection_mask_type": "geometric", - "geo_rejection_mask_ratio_cap_high": 1.1, - "geo_rejection_mask_ratio_cap_low": 0.9, + "tis_ratio_type": None, + "sequence_mask_metric": "geometric", + "geo_mask_high": 1.1, + "geo_mask_low": 0.9, "outlier_token_is_threshold_low": 1e-4, "outlier_token_is_threshold_high": 100.0, } ) - corrected_loss, metrics, loss_mask = apply_rollout_correction( + corrected_loss, metrics, loss_mask = apply_off_policy_correction( loss, old_log_probs, rollout_logprobs, loss_mask, config ) # Geometric mean ≈ 2.718, outside [0.9, 1.1], so loss should be zeroed expected = torch.tensor([[0.0, 0.0, 0.0]], device=device) torch.testing.assert_close(corrected_loss * loss_mask, expected, rtol=1e-3, atol=1e-4) - # Check that the rejection metrics show rejection happened - assert metrics["rejection_seq_masked_frac"] == 1.0 + # Check that the sequence mask metrics show sequence mask happened + assert metrics["geo_sequence_mask_masked_ratio"] == 1.0 -def test_ppo_policy_loss_with_rollout_correction(): +def test_ppo_policy_loss_with_off_policy_correction(): """Integration test for PPO policy loss with rollout correction enabled.""" device = "cpu" @@ -1131,10 +1131,10 @@ def test_ppo_policy_loss_with_rollout_correction(): "policy_loss_type": "regular", "loss_reduction": "token_mean", "max_seq_len": 4, - "rollout_correction": { + "off_policy_correction": { "tis_ratio_type": "token", - "token_tis_ratio_cap_high": 2.0, - "rejection_mask_type": "null", + "token_tis_ratio_clip_high": 2.0, + "sequence_mask_metric": None, "outlier_token_is_threshold_low": 1e-4, "outlier_token_is_threshold_high": 100.0, }, @@ -1161,9 +1161,9 @@ def test_ppo_policy_loss_with_rollout_correction(): "policy_loss_type": "regular", "loss_reduction": "token_mean", "max_seq_len": 4, - "rollout_correction": { - "tis_ratio_type": "null", - "rejection_mask_type": "null", + "off_policy_correction": { + "tis_ratio_type": None, + "sequence_mask_metric": None, }, } ) @@ -1198,15 +1198,15 @@ def test_compute_tis_ratio_invalid_type(): compute_tis_ratio(old_log_probs, rollout_logprobs, loss_mask, "invalid", config) -def test_compute_rejection_mask_invalid_type(): - """Tests that compute_rejection_mask raises error for invalid rejection_mask_type.""" +def test_compute_sequence_mask_invalid_type(): + """Tests that compute_sequence_mask raises error for invalid sequence_mask_metric.""" device = "cpu" old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device) rollout_logprobs = torch.tensor([[-1.5, -1.5, -1.5]], device=device) loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device) - config = DictConfig({"rejection_mask_type": "invalid"}) + config = DictConfig({"sequence_mask_metric": "invalid"}) - with pytest.raises(ValueError, match="Unknown rejection_mask_type"): - compute_rejection_mask(old_log_probs, rollout_logprobs, 
loss_mask, "invalid", config) + with pytest.raises(ValueError, match="Unknown sequence_mask_metric"): + compute_sequence_mask(old_log_probs, rollout_logprobs, loss_mask, "invalid", config) diff --git a/skyrl-train/tests/gpu/gpu_ci/test_ppo_train.py b/skyrl-train/tests/gpu/gpu_ci/test_ppo_train.py index 5131ae446..9b5bc6e3b 100644 --- a/skyrl-train/tests/gpu/gpu_ci/test_ppo_train.py +++ b/skyrl-train/tests/gpu/gpu_ci/test_ppo_train.py @@ -49,11 +49,11 @@ def test_ppo_train_basic_execution(ray_init_fixture, cfg, use_entropy_loss, use_ cfg.trainer.algorithm.use_kl_loss = True cfg.trainer.algorithm.kl_loss_coef = 0.001 - cfg.trainer.algorithm.rollout_correction.tis_ratio_type = "sequence" + cfg.trainer.algorithm.off_policy_correction.tis_ratio_type = "sequence" - cfg.trainer.algorithm.rollout_correction.rejection_mask_type = "geometric" - cfg.trainer.algorithm.rollout_correction.geo_rejection_mask_ratio_cap_high = 1.02 - cfg.trainer.algorithm.rollout_correction.geo_rejection_mask_ratio_cap_low = 0.98 + cfg.trainer.algorithm.off_policy_correction.sequence_mask_metric = "geometric" + cfg.trainer.algorithm.off_policy_correction.geo_mask_high = 1.02 + cfg.trainer.algorithm.off_policy_correction.geo_mask_low = 0.98 actor_group = init_worker_with_type( "policy", From cef7121aac5d5e6b9dbce0ee4fd3a66b9b8d2593 Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Thu, 8 Jan 2026 23:37:36 +0000 Subject: [PATCH 14/23] x --- skyrl-train/skyrl_train/distributed/strategy.py | 2 -- skyrl-train/skyrl_train/utils/utils.py | 2 +- .../workers/megatron/megatron_worker.py | 15 ++------------- skyrl-train/skyrl_train/workers/worker.py | 13 ++----------- skyrl-train/skyrl_train/workers/worker_utils.py | 12 ++++++++++++ 5 files changed, 17 insertions(+), 27 deletions(-) diff --git a/skyrl-train/skyrl_train/distributed/strategy.py b/skyrl-train/skyrl_train/distributed/strategy.py index 785722c93..503cfcce4 100644 --- a/skyrl-train/skyrl_train/distributed/strategy.py +++ b/skyrl-train/skyrl_train/distributed/strategy.py @@ -93,10 +93,8 @@ def all_reduce(self, data: DataT, op="mean") -> DataT: data /= self.world_size dist.all_reduce(data, op=dist.ReduceOp.SUM) elif op == "max": - data = torch.max(data) dist.all_reduce(data, op=dist.ReduceOp.MAX) elif op == "min": - data = torch.min(data) dist.all_reduce(data, op=dist.ReduceOp.MIN) elif op == "sum": dist.all_reduce(data, op=dist.ReduceOp.SUM) diff --git a/skyrl-train/skyrl_train/utils/utils.py b/skyrl-train/skyrl_train/utils/utils.py index 6ef3b99d8..76b9f7587 100644 --- a/skyrl-train/skyrl_train/utils/utils.py +++ b/skyrl-train/skyrl_train/utils/utils.py @@ -281,7 +281,7 @@ def validate_cfg(cfg: DictConfig): algorithm_config.kl_estimator_type = "k3" cfg.trainer.algorithm = algorithm_config - # Legacy TIS validation (deprecated) + # TODO (erictang000): remove this after deprecation period if cfg.trainer.algorithm.use_tis: logger.warning( f"`trainer.algorithm.use_tis` is deprecated. Setting `trainer.algorithm.off_policy_correction` to `token` instead." 
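
For context on the deprecation shim above, here is a standalone sketch (not part of this patch) of how the legacy `use_tis` / `tis_imp_ratio_cap` flags can be mapped onto the new `off_policy_correction` fields. The field names follow the config introduced in this series; the exact translation performed inside `validate_cfg` may differ.

from omegaconf import OmegaConf

def migrate_legacy_tis(algorithm_cfg):
    # Assumed mapping: use_tis -> token-level TIS, tis_imp_ratio_cap -> token clip cap.
    if algorithm_cfg.use_tis:
        opc = algorithm_cfg.off_policy_correction
        opc.tis_ratio_type = "token"
        if algorithm_cfg.tis_imp_ratio_cap > 0:
            opc.token_tis_ratio_clip_high = algorithm_cfg.tis_imp_ratio_cap
    return algorithm_cfg

cfg = OmegaConf.create({
    "use_tis": True,
    "tis_imp_ratio_cap": 2.0,
    "off_policy_correction": {"tis_ratio_type": None, "token_tis_ratio_clip_high": 2.0},
})
assert migrate_legacy_tis(cfg).off_policy_correction.tis_ratio_type == "token"
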
diff --git a/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py b/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py index cc5e1e10d..c10eaa9a2 100644 --- a/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py @@ -30,7 +30,7 @@ from skyrl_train.utils.utils import update_model_config, str_to_torch_dtype from skyrl_train.utils.constants import SKYRL_WORKER_NCCL_TIMEOUT_IN_S from skyrl_train.training_batch import TrainingOutputBatch -from skyrl_train.workers.worker_utils import BatchIterator, reduce_metrics +from skyrl_train.workers.worker_utils import BatchIterator, reduce_metrics, all_reduce_metrics from skyrl_train.workers.worker import ( PolicyWorkerBase, RefWorkerBase, @@ -598,18 +598,7 @@ def ppo_train(self, train_data) -> "TrainingOutputBatch": # attach response_length status["response_length"] = micro_buffer[i]["num_actions"] - min_metrics = {k: v for k, v in status.items() if k.endswith("_min")} - max_metrics = {k: v for k, v in status.items() if k.endswith("_max")} - mean_metrics = { - k: v for k, v in status.items() if k not in min_metrics and k not in max_metrics - } - - status_mean = self.strategy.all_reduce(mean_metrics, op="mean") - status_min = self.strategy.all_reduce(min_metrics, op="min") - status_max = self.strategy.all_reduce(max_metrics, op="max") - status_mean.update(status_min) - status_mean.update(status_max) - status = status_mean + status_mean = all_reduce_metrics(status, self.strategy) status_list.append(status) for k, v in status.items(): diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py index df73bdf33..f062f9b20 100644 --- a/skyrl-train/skyrl_train/workers/worker.py +++ b/skyrl-train/skyrl_train/workers/worker.py @@ -32,7 +32,7 @@ from loguru import logger from skyrl_train.distributed.ulysses import set_ulysses_sequence_parallel_group, apply_monkey_patch from skyrl_train.utils.ppo_utils import PolicyLossRegistry, ppo_critic_loss, compute_approx_kl -from skyrl_train.workers.worker_utils import BatchIterator, reduce_metrics +from skyrl_train.workers.worker_utils import BatchIterator, reduce_metrics, all_reduce_metrics from skyrl_train.dataset.replay_buffer import Experience from skyrl_train.training_batch import TrainingInputBatch, TrainingOutputBatch from skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient @@ -742,16 +742,7 @@ def record_status(status: Dict[str, float]): # for DP # TODO (sumanthrh): this assumes all workers are data parallel. # We assume that outputs are replicated within tp or sp group, otherwise this is not correct. - min_metrics = {k: v for k, v in status.items() if k.endswith("_min")} - max_metrics = {k: v for k, v in status.items() if k.endswith("_max")} - mean_metrics = {k: v for k, v in status.items() if k not in min_metrics and k not in max_metrics} - - status_mean = self.strategy.all_reduce(mean_metrics, op="mean") - status_min = self.strategy.all_reduce(min_metrics, op="min") - status_max = self.strategy.all_reduce(max_metrics, op="max") - status_mean.update(status_min) - status_mean.update(status_max) - status = status_mean + status = all_reduce_metrics(status, self.strategy) # weighted mean for kl # TODO (sumanthrh): this weighted mean is no longer correct since we use the max response length in the batch. 
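
The `all_reduce_metrics` helper factored out above buckets keys by suffix, so `*_min` / `*_max` stats are reduced with min/max and everything else with a mean. Below is a minimal single-process sketch of that bucketing, with the distributed `strategy.all_reduce` replaced by plain Python so it runs standalone (illustrative only).

from typing import Dict, List

def local_reduce_metrics(per_rank: List[Dict[str, float]]) -> Dict[str, float]:
    # Mimic the suffix-based reduction across "ranks" held in a plain list.
    out = {}
    for k in per_rank[0]:
        vals = [m[k] for m in per_rank]
        if k.endswith("_min"):
            out[k] = min(vals)
        elif k.endswith("_max"):
            out[k] = max(vals)
        else:
            out[k] = sum(vals) / len(vals)
    return out

ranks = [{"is_ratio_mean": 1.0, "is_ratio_max": 3.0}, {"is_ratio_mean": 2.0, "is_ratio_max": 5.0}]
print(local_reduce_metrics(ranks))  # {'is_ratio_mean': 1.5, 'is_ratio_max': 5.0}
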
diff --git a/skyrl-train/skyrl_train/workers/worker_utils.py b/skyrl-train/skyrl_train/workers/worker_utils.py index 28dbb5888..32410d17d 100644 --- a/skyrl-train/skyrl_train/workers/worker_utils.py +++ b/skyrl-train/skyrl_train/workers/worker_utils.py @@ -2,6 +2,7 @@ from skyrl_train.dataset.replay_buffer import Experience from typing import List, Dict from skyrl_train.training_batch import TrainingInputBatch +from skyrl_train.distributed.strategy import DistributedStrategy def reduce_metrics(metrics: Dict[str, List[float]]) -> Dict[str, float]: @@ -20,6 +21,17 @@ def reduce_metrics(metrics: Dict[str, List[float]]) -> Dict[str, float]: reduced_metrics[k] = sum(v) / len(v) return reduced_metrics +def all_reduce_metrics(metrics: Dict[str, List[float]], strategy: DistributedStrategy) -> Dict[str, float]: + """All reduce metrics across all processes.""" + min_metrics = {k: v for k, v in metrics.items() if k.endswith("_min")} + max_metrics = {k: v for k, v in metrics.items() if k.endswith("_max")} + mean_metrics = {k: v for k, v in metrics.items() if k not in min_metrics and k not in max_metrics} + status_mean = strategy.all_reduce(mean_metrics, op="mean") + status_min = strategy.all_reduce(min_metrics, op="min") + status_max = strategy.all_reduce(max_metrics, op="max") + status_mean.update(status_min) + status_mean.update(status_max) + return status_mean class BatchIterator: """A simple iterator to yield micro batches of data from the training batch.""" From 9485bddb6255cc0026381ccebec11c058756ce01 Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Thu, 8 Jan 2026 23:37:47 +0000 Subject: [PATCH 15/23] x --- skyrl-train/skyrl_train/workers/worker_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/skyrl-train/skyrl_train/workers/worker_utils.py b/skyrl-train/skyrl_train/workers/worker_utils.py index 32410d17d..ebc25f746 100644 --- a/skyrl-train/skyrl_train/workers/worker_utils.py +++ b/skyrl-train/skyrl_train/workers/worker_utils.py @@ -21,6 +21,7 @@ def reduce_metrics(metrics: Dict[str, List[float]]) -> Dict[str, float]: reduced_metrics[k] = sum(v) / len(v) return reduced_metrics + def all_reduce_metrics(metrics: Dict[str, List[float]], strategy: DistributedStrategy) -> Dict[str, float]: """All reduce metrics across all processes.""" min_metrics = {k: v for k, v in metrics.items() if k.endswith("_min")} @@ -33,6 +34,7 @@ def all_reduce_metrics(metrics: Dict[str, List[float]], strategy: DistributedStr status_mean.update(status_max) return status_mean + class BatchIterator: """A simple iterator to yield micro batches of data from the training batch.""" From c06747c3a2547ca283b3fbf14f2bb68ad6942c1d Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Mon, 12 Jan 2026 22:59:19 +0000 Subject: [PATCH 16/23] x --- skyrl-train/skyrl_train/utils/ppo_utils.py | 91 +++++++++---------- skyrl-train/skyrl_train/utils/utils.py | 4 + .../tests/cpu/algorithms/test_losses.py | 74 +++++++-------- 3 files changed, 82 insertions(+), 87 deletions(-) diff --git a/skyrl-train/skyrl_train/utils/ppo_utils.py b/skyrl-train/skyrl_train/utils/ppo_utils.py index e738c9fea..c192e8e4e 100644 --- a/skyrl-train/skyrl_train/utils/ppo_utils.py +++ b/skyrl-train/skyrl_train/utils/ppo_utils.py @@ -562,9 +562,9 @@ def compute_tis_ratio( loss_mask: torch.Tensor, tis_ratio_type: str, off_policy_correction: DictConfig, -) -> torch.Tensor: +) -> Tuple[torch.Tensor, dict]: """ - Compute truncated importance sampling (TIS) ratio for rollout correction. + Compute truncated importance sampling (TIS) ratio for off policy correction. 
Args: old_log_probs: Log probabilities from the old policy (before update). @@ -574,7 +574,9 @@ def compute_tis_ratio( off_policy_correction: Off-policy correction config containing cap values. Returns: - TIS ratio tensor to multiply with the loss. + Tuple of (tis_ratio, metrics): + - tis_ratio: Tensor (float) to multiply with the loss + - metrics: Dict with masking statistics Reference: https://github.com/szrlee/verl/blob/yingru/rollout_correction/docs/advance/rollout_corr_math.md """ @@ -590,7 +592,7 @@ def compute_tis_ratio( tokens_capped = (token_tis_ratio > token_tis_ratio_cap) & (loss_mask > 0) total_tokens = (loss_mask > 0).sum() metrics["tis_token_clip_high_ratio"] = (tokens_capped.sum() / total_tokens.clamp(min=1)).detach().item() - return torch.clamp(token_tis_ratio, max=token_tis_ratio_cap), metrics + return torch.clamp(token_tis_ratio, max=token_tis_ratio_cap).detach(), metrics elif tis_ratio_type == "sequence": # Compute sequence-level importance ratio as product of token ratios (sum of log ratios) seq_tis_log_ratio = (token_tis_log_ratio * loss_mask).sum(dim=-1, keepdim=True) @@ -600,7 +602,7 @@ def compute_tis_ratio( num_sequences = seq_tis_ratio.shape[0] seqs_capped = (seq_tis_ratio > seq_tis_ratio_cap).sum() metrics["tis_seq_clip_high_ratio"] = (seqs_capped / num_sequences).detach().item() - return torch.clamp(seq_tis_ratio, max=seq_tis_ratio_cap), metrics + return torch.clamp(seq_tis_ratio, max=seq_tis_ratio_cap).detach(), metrics else: raise ValueError(f"Unknown tis_ratio_type: {tis_ratio_type}") @@ -610,7 +612,7 @@ def compute_outlier_token_mask( rollout_logprobs: torch.Tensor, loss_mask: torch.Tensor, off_policy_correction: DictConfig, -) -> torch.Tensor: +) -> Tuple[torch.Tensor, dict]: """ Compute outlier token mask that masks out sequences with any token having importance ratio outside acceptable bounds. @@ -626,7 +628,7 @@ def compute_outlier_token_mask( Returns: Tuple of (outlier_mask, metrics): - - outlier_mask: Tensor (float) to multiply with the loss, shape [batch, 1] + - outlier_mask: Tensor (bool) to mask out sequences with any token having importance ratio outside acceptable bounds - metrics: Dict with masking statistics """ metrics = {} @@ -724,28 +726,29 @@ def compute_sequence_mask( raise ValueError(f"Unknown sequence_mask_metric: {sequence_mask_metric}") -def apply_off_policy_correction( - loss: torch.Tensor, +def compute_off_policy_correction( old_log_probs: torch.Tensor, rollout_logprobs: torch.Tensor, loss_mask: torch.Tensor, off_policy_correction: DictConfig, -) -> torch.Tensor: +) -> Tuple[Optional[torch.Tensor], dict, torch.Tensor]: """ - Apply off policy correction to the loss using TIS ratio, sequence mask, and outlier token mask. + Compute TIS ratio, sequence mask, and outlier token mask for off policy correction. This is a convenience function that combines compute_tis_ratio, compute_sequence_mask, and compute_outlier_token_mask. Args: - loss: The loss tensor to correct. old_log_probs: Log probabilities from the old policy (before update). rollout_logprobs: Log probabilities from the rollout policy. loss_mask: Mask indicating valid tokens. off_policy_correction: Off-policy correction config. Returns: - Corrected loss tensor. 
+ Tuple of (tis_ratio, metrics, loss_mask): + - tis_ratio: Tensor (float) to multiply with the loss + - metrics: Dict with masking statistics + - loss_mask: Mask indicating valid tokens after applying off policy correction References: - https://github.com/szrlee/verl/blob/yingru/rollout_correction/docs/advance/rollout_corr_math.md @@ -761,7 +764,7 @@ def apply_off_policy_correction( # Early return if no correction needed if not apply_tis and not apply_sequence_mask: - return loss, {}, loss_mask + return None, {}, loss_mask is_ratio = _safe_exp_delta(old_log_probs - rollout_logprobs, clip=20.0, out_dtype=old_log_probs.dtype) metrics = {} @@ -778,12 +781,14 @@ def apply_off_policy_correction( loss_mask = loss_mask * outlier_mask metrics.update(outlier_metrics) + # Initialize tis_ratio to None (only set if TIS is enabled) + tis_ratio = None + # Apply TIS ratio if enabled if apply_tis: tis_ratio, tis_metrics = compute_tis_ratio( old_log_probs, rollout_logprobs, loss_mask, tis_ratio_type, off_policy_correction ) - loss = loss * tis_ratio metrics.update(tis_metrics) # Apply sequence mask if enabled @@ -794,7 +799,7 @@ def apply_off_policy_correction( loss_mask = loss_mask * sequence_mask metrics.update(sequence_mask_metrics) - return loss, metrics, loss_mask + return tis_ratio, metrics, loss_mask @register_policy_loss(PolicyLossType.REGULAR) @@ -828,12 +833,14 @@ def ppo_policy_loss( loss_metrics = LossMetrics(clip_ratio=clip_ratio) - # apply rollout correction + # apply off policy correction off_policy_correction = config.off_policy_correction if rollout_logprobs is not None: - loss, off_policy_correction_metrics, loss_mask = apply_off_policy_correction( - loss, old_log_probs, rollout_logprobs, loss_mask, off_policy_correction + tis_ratio, off_policy_correction_metrics, loss_mask = compute_off_policy_correction( + old_log_probs, rollout_logprobs, loss_mask, off_policy_correction ) + if tis_ratio is not None: + loss = loss * tis_ratio loss_metrics.update(off_policy_correction_metrics) loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len) @@ -899,13 +906,15 @@ def gate_function(x, tau): # compute policy gradient loss loss = -gates * advantages - # apply rollout correction + # apply off policy correction off_policy_correction = config.off_policy_correction loss_metrics = LossMetrics(clip_ratio=0.0) if rollout_logprobs is not None: - loss, off_policy_correction_metrics, loss_mask = apply_off_policy_correction( - loss, old_log_probs, rollout_logprobs, loss_mask, off_policy_correction + tis_ratio, off_policy_correction_metrics, loss_mask = compute_off_policy_correction( + old_log_probs, rollout_logprobs, loss_mask, off_policy_correction ) + if tis_ratio is not None: + loss = loss * tis_ratio loss_metrics.update(off_policy_correction_metrics) # for SAPO, we need to aggregate the loss at the sequence level (seq-mean-token-mean) loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len) @@ -968,13 +977,15 @@ def gspo_policy_loss( # Compute clipping ratio for monitoring clip_ratio = masked_mean((-surr2 > -surr1).float(), loss_mask).mean().detach().item() - # apply rollout correction + # apply off policy correction loss_metrics = LossMetrics(clip_ratio=clip_ratio) off_policy_correction = config.off_policy_correction if rollout_logprobs is not None: - loss, off_policy_correction_metrics, loss_mask = apply_off_policy_correction( - loss, old_log_probs, rollout_logprobs, loss_mask, off_policy_correction + tis_ratio, off_policy_correction_metrics, loss_mask = 
compute_off_policy_correction( + old_log_probs, rollout_logprobs, loss_mask, off_policy_correction ) + if tis_ratio is not None: + loss = loss * tis_ratio loss_metrics.update(off_policy_correction_metrics) loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len) @@ -1007,13 +1018,15 @@ def compute_policy_loss_cispo( is_clipped = (ratio < 1 - config.cispo.cispo_eps_clip_low) | (ratio > 1 + config.cispo.cispo_eps_clip_high) clip_ratio = masked_mean(is_clipped.float(), loss_mask).mean().detach().item() - # apply rollout correction + # apply off policy correction off_policy_correction = config.off_policy_correction loss_metrics = LossMetrics(clip_ratio=clip_ratio) if rollout_logprobs is not None: - loss, off_policy_correction_metrics, loss_mask = apply_off_policy_correction( - loss, old_log_probs, rollout_logprobs, loss_mask, off_policy_correction + tis_ratio, off_policy_correction_metrics, loss_mask = compute_off_policy_correction( + old_log_probs, rollout_logprobs, loss_mask, off_policy_correction ) + if tis_ratio is not None: + loss = loss * tis_ratio loss_metrics.update(off_policy_correction_metrics) loss = reduce_loss(loss, loss_mask, config.loss_reduction, config.max_seq_len) @@ -1078,15 +1091,6 @@ def compute_policy_loss_clip_cov( # Apply correction mask to losses pg_losses = torch.maximum(pg_losses1, pg_losses2) * corr - # apply rollout correction - off_policy_correction = config.off_policy_correction - loss_metrics = LossMetrics(clip_ratio=clip_frac.item()) - if rollout_logprobs is not None: - pg_losses, off_policy_correction_metrics, loss_mask = apply_off_policy_correction( - pg_losses, old_log_probs, rollout_logprobs, loss_mask, off_policy_correction - ) - loss_metrics.update(off_policy_correction_metrics) - pg_loss = reduce_loss( loss=pg_losses, loss_mask=loss_mask, @@ -1094,7 +1098,7 @@ def compute_policy_loss_clip_cov( max_seq_len=config.max_seq_len, ) - return pg_loss, loss_metrics + return pg_loss, LossMetrics(clip_ratio=clip_frac.item()) @register_policy_loss(PolicyLossType.KL_COV) @@ -1146,15 +1150,6 @@ def compute_policy_loss_kl_cov( large_cov_idxs % advantages.shape[1], ] - # apply rollout correction - off_policy_correction = config.off_policy_correction - loss_metrics = LossMetrics(clip_ratio=0.0) - if rollout_logprobs is not None: - pg_losses, off_policy_correction_metrics, loss_mask = apply_off_policy_correction( - pg_losses, old_log_probs, rollout_logprobs, loss_mask, off_policy_correction - ) - loss_metrics.update(off_policy_correction_metrics) - pg_loss = reduce_loss( loss=pg_losses, loss_mask=loss_mask, @@ -1163,7 +1158,7 @@ def compute_policy_loss_kl_cov( ) # NOTE (sumanthrh): Since the pg clip ratio is not applicable for KL-COV so we just use 0.0 - return pg_loss, loss_metrics + return pg_loss, LossMetrics(clip_ratio=0.0) def reduce_loss( diff --git a/skyrl-train/skyrl_train/utils/utils.py b/skyrl-train/skyrl_train/utils/utils.py index 76b9f7587..c1d286575 100644 --- a/skyrl-train/skyrl_train/utils/utils.py +++ b/skyrl-train/skyrl_train/utils/utils.py @@ -324,6 +324,10 @@ def validate_cfg(cfg: DictConfig): raise NotImplementedError( "`trainer.algorithm.off_policy_correction` doesn't support Sglang backend, please use vLLM" ) + if cfg.algorithm.policy_loss_type in ["clip_cov", "kl_cov"]: + raise NotImplementedError( + "`trainer.algorithm.off_policy_correction` doesn't support clip_cov or kl_cov policy loss types" + ) if cfg.trainer.policy.model.lora.rank > 0: # LoRA enabled diff --git a/skyrl-train/tests/cpu/algorithms/test_losses.py 
b/skyrl-train/tests/cpu/algorithms/test_losses.py index 1d891641d..e68251814 100644 --- a/skyrl-train/tests/cpu/algorithms/test_losses.py +++ b/skyrl-train/tests/cpu/algorithms/test_losses.py @@ -14,7 +14,7 @@ compute_tis_ratio, compute_sequence_mask, compute_outlier_token_mask, - apply_off_policy_correction, + compute_off_policy_correction, ) NULL_OFF_POLICY_CORR = { @@ -957,11 +957,10 @@ def test_compute_outlier_token_mask_respects_loss_mask(): assert metrics["outlier_seq_masked_ratio"] == 0.0 -def test_apply_off_policy_correction_null_configs(): - """Tests that apply_off_policy_correction returns loss unchanged when both configs are null.""" +def test_compute_off_policy_correction_null_configs(): + """Tests that compute_off_policy_correction returns None tis_ratio when both configs are null.""" device = "cpu" - loss = torch.tensor([[1.0, 2.0, 3.0]], device=device) old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device) rollout_logprobs = torch.tensor([[-2.0, -2.0, -2.0]], device=device) loss_mask = torch.tensor([[1.0, 1.0, 1.0]], device=device) @@ -973,20 +972,19 @@ def test_apply_off_policy_correction_null_configs(): } ) - corrected_loss, metrics, loss_mask = apply_off_policy_correction( - loss, old_log_probs, rollout_logprobs, loss_mask, config + tis_ratio, metrics, new_loss_mask = compute_off_policy_correction( + old_log_probs, rollout_logprobs, loss_mask, config ) - # Should return the same tensor (early return) and empty metrics - assert corrected_loss is loss + # Should return None tis_ratio (early return) and empty metrics + assert tis_ratio is None assert metrics == {} -def test_apply_off_policy_correction_tis_only(): - """Tests apply_off_policy_correction with only TIS enabled.""" +def test_compute_off_policy_correction_tis_only(): + """Tests compute_off_policy_correction with only TIS enabled.""" device = "cpu" - loss = torch.tensor([[1.0, 1.0, 1.0]], device=device) # Token log ratios: [0.5, 0.5, 0.5] -> token ratios = [1.6487, 1.6487, 1.6487] old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device) rollout_logprobs = torch.tensor([[-1.5, -1.5, -1.5]], device=device) @@ -1002,23 +1000,22 @@ def test_apply_off_policy_correction_tis_only(): } ) - corrected_loss, metrics, loss_mask = apply_off_policy_correction( - loss, old_log_probs, rollout_logprobs, loss_mask, config + tis_ratio, metrics, new_loss_mask = compute_off_policy_correction( + old_log_probs, rollout_logprobs, loss_mask, config ) - # Expected: loss * 1.6487 (no capping needed) - expected = loss * torch.exp(torch.tensor(0.5)) - torch.testing.assert_close(corrected_loss, expected, rtol=1e-3, atol=1e-4) + # Expected tis_ratio: 1.6487 (no capping needed) + expected_tis_ratio = torch.exp(torch.tensor(0.5)) + torch.testing.assert_close(tis_ratio, torch.full_like(old_log_probs, expected_tis_ratio.item()), rtol=1e-3, atol=1e-4) # Check metrics are populated assert "is_ratio_mean" in metrics assert "tis_token_clip_high_ratio" in metrics -def test_apply_off_policy_correction_sequence_mask_only(): - """Tests apply_off_policy_correction with only geometric sequence mask enabled.""" +def test_compute_off_policy_correction_sequence_mask_only(): + """Tests compute_off_policy_correction with only geometric sequence mask enabled.""" device = "cpu" - loss = torch.tensor([[1.0, 2.0, 3.0]], device=device) # Token log ratios: [0.0, 0.0, 0.0] -> geometric mean = 1.0 old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device) rollout_logprobs = torch.tensor([[-1.0, -1.0, -1.0]], device=device) @@ -1035,22 +1032,23 
@@ def test_apply_off_policy_correction_sequence_mask_only(): } ) - corrected_loss, metrics, loss_mask = apply_off_policy_correction( - loss, old_log_probs, rollout_logprobs, loss_mask, config + tis_ratio, metrics, new_loss_mask = compute_off_policy_correction( + old_log_probs, rollout_logprobs, loss_mask, config ) - # Geometric mean = 1.0, within bounds, so loss unchanged - torch.testing.assert_close(corrected_loss, loss, rtol=1e-3, atol=1e-4) + # Geometric mean = 1.0, within bounds, so loss_mask unchanged + # tis_ratio is None since tis_ratio_type is None + assert tis_ratio is None + torch.testing.assert_close(new_loss_mask, loss_mask, rtol=1e-3, atol=1e-4) # Check metrics are populated assert "is_ratio_mean" in metrics assert "geo_sequence_mask_masked_ratio" in metrics -def test_apply_off_policy_correction_both_enabled(): - """Tests apply_off_policy_correction with both TIS and geometric sequence mask enabled.""" +def test_compute_off_policy_correction_both_enabled(): + """Tests compute_off_policy_correction with both TIS and geometric sequence mask enabled.""" device = "cpu" - loss = torch.tensor([[1.0, 1.0, 1.0]], device=device) # Token log ratios: [0.1, 0.1, 0.1] -> token ratios ≈ [1.105, 1.105, 1.105] # Geometric mean ≈ 1.105 old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device) @@ -1069,24 +1067,22 @@ def test_apply_off_policy_correction_both_enabled(): } ) - corrected_loss, metrics, loss_mask = apply_off_policy_correction( - loss, old_log_probs, rollout_logprobs, loss_mask, config + tis_ratio, metrics, new_loss_mask = compute_off_policy_correction( + old_log_probs, rollout_logprobs, loss_mask, config ) # TIS ratio ≈ 1.105, geometric mean ≈ 1.105 (within bounds, mask=1) - # Expected: loss * 1.105 * 1.0 = loss * 1.105 - expected = loss * torch.exp(torch.tensor(0.1)) - torch.testing.assert_close(corrected_loss, expected, rtol=1e-3, atol=1e-4) + expected_tis_ratio = torch.exp(torch.tensor(0.1)) + torch.testing.assert_close(tis_ratio, torch.full_like(old_log_probs, expected_tis_ratio.item()), rtol=1e-3, atol=1e-4) # Check metrics from both TIS and sequence mask are populated assert "tis_token_clip_high_ratio" in metrics assert "geo_sequence_mask_masked_ratio" in metrics -def test_apply_off_policy_correction_sequence_mask_zeros_loss(): - """Tests that sequence mask can zero out the loss entirely.""" +def test_compute_off_policy_correction_sequence_mask_zeros_loss(): + """Tests that sequence mask can zero out the loss_mask entirely.""" device = "cpu" - loss = torch.tensor([[1.0, 2.0, 3.0]], device=device) # Token log ratios: [1.0, 1.0, 1.0] -> geometric mean = exp(1.0) ≈ 2.718 old_log_probs = torch.tensor([[-1.0, -1.0, -1.0]], device=device) rollout_logprobs = torch.tensor([[-2.0, -2.0, -2.0]], device=device) @@ -1103,13 +1099,13 @@ def test_apply_off_policy_correction_sequence_mask_zeros_loss(): } ) - corrected_loss, metrics, loss_mask = apply_off_policy_correction( - loss, old_log_probs, rollout_logprobs, loss_mask, config + tis_ratio, metrics, new_loss_mask = compute_off_policy_correction( + old_log_probs, rollout_logprobs, loss_mask, config ) - # Geometric mean ≈ 2.718, outside [0.9, 1.1], so loss should be zeroed - expected = torch.tensor([[0.0, 0.0, 0.0]], device=device) - torch.testing.assert_close(corrected_loss * loss_mask, expected, rtol=1e-3, atol=1e-4) + # Geometric mean ≈ 2.718, outside [0.9, 1.1], so loss_mask should be zeroed + expected_mask = torch.tensor([[0.0, 0.0, 0.0]], device=device) + torch.testing.assert_close(new_loss_mask, expected_mask, rtol=1e-3, 
atol=1e-4) # Check that the sequence mask metrics show sequence mask happened assert metrics["geo_sequence_mask_masked_ratio"] == 1.0 From 0b5ebfd7bc251df83bc5d9b57e38e1666ee1e35d Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Mon, 12 Jan 2026 23:31:13 +0000 Subject: [PATCH 17/23] x --- skyrl-train/tests/cpu/algorithms/test_losses.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/skyrl-train/tests/cpu/algorithms/test_losses.py b/skyrl-train/tests/cpu/algorithms/test_losses.py index e68251814..a4abce5b8 100644 --- a/skyrl-train/tests/cpu/algorithms/test_losses.py +++ b/skyrl-train/tests/cpu/algorithms/test_losses.py @@ -1006,7 +1006,9 @@ def test_compute_off_policy_correction_tis_only(): # Expected tis_ratio: 1.6487 (no capping needed) expected_tis_ratio = torch.exp(torch.tensor(0.5)) - torch.testing.assert_close(tis_ratio, torch.full_like(old_log_probs, expected_tis_ratio.item()), rtol=1e-3, atol=1e-4) + torch.testing.assert_close( + tis_ratio, torch.full_like(old_log_probs, expected_tis_ratio.item()), rtol=1e-3, atol=1e-4 + ) # Check metrics are populated assert "is_ratio_mean" in metrics assert "tis_token_clip_high_ratio" in metrics @@ -1073,7 +1075,9 @@ def test_compute_off_policy_correction_both_enabled(): # TIS ratio ≈ 1.105, geometric mean ≈ 1.105 (within bounds, mask=1) expected_tis_ratio = torch.exp(torch.tensor(0.1)) - torch.testing.assert_close(tis_ratio, torch.full_like(old_log_probs, expected_tis_ratio.item()), rtol=1e-3, atol=1e-4) + torch.testing.assert_close( + tis_ratio, torch.full_like(old_log_probs, expected_tis_ratio.item()), rtol=1e-3, atol=1e-4 + ) # Check metrics from both TIS and sequence mask are populated assert "tis_token_clip_high_ratio" in metrics assert "geo_sequence_mask_masked_ratio" in metrics From 0697957e371449315ea28cc655bc7cac10c13058 Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Tue, 13 Jan 2026 00:17:36 +0000 Subject: [PATCH 18/23] x --- skyrl-train/examples/megatron/run_megatron.sh | 21 +++++++++++++++---- skyrl-train/skyrl_train/utils/utils.py | 2 +- .../tests/gpu/gpu_ci/test_megatron_worker.py | 13 ++++++++++++ .../tests/gpu/gpu_ci/test_ppo_train.py | 9 ++++---- 4 files changed, 35 insertions(+), 10 deletions(-) diff --git a/skyrl-train/examples/megatron/run_megatron.sh b/skyrl-train/examples/megatron/run_megatron.sh index cf0d9f9ed..1ac4f2ca5 100644 --- a/skyrl-train/examples/megatron/run_megatron.sh +++ b/skyrl-train/examples/megatron/run_megatron.sh @@ -13,8 +13,8 @@ MODEL_NAME="Qwen/Qwen3-0.6B" INFERENCE_BACKEND="vllm" # currently only vllm is supported for megatron -MEGATRON_TP=2 -MEGATRON_PP=2 +MEGATRON_TP=1 +MEGATRON_PP=1 MEGATRON_CP=1 # torch profiler config @@ -22,10 +22,23 @@ ENABLE_TORCH_PROFILER=false RANKS_TO_PROFILE="[0]" SAVE_PATH="$HOME/megatron_prof/tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_${MODEL_NAME}" +TIS_RATIO_TYPE="sequence" +TIS_RATIO_HIGH=2.0 +SEQUENCE_MASK_METRIC="geometric" +SEQUENCE_MASK_HIGH=1.02 +SEQUENCE_MASK_LOW=0.98 + + uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_base \ data.train_data="['$DATA_DIR/train.parquet']" \ data.val_data="['$DATA_DIR/validation.parquet']" \ trainer.algorithm.advantage_estimator="grpo" \ + trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_RATIO_TYPE \ + trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_RATIO_HIGH \ + trainer.algorithm.off_policy_correction.sequence_tis_ratio_clip_high=$TIS_RATIO_HIGH \ + trainer.algorithm.off_policy_correction.sequence_mask_metric=$SEQUENCE_MASK_METRIC \ + 
trainer.algorithm.off_policy_correction.geo_mask_high=$SEQUENCE_MASK_HIGH \ + trainer.algorithm.off_policy_correction.geo_mask_low=$SEQUENCE_MASK_LOW \ trainer.policy.model.path=$MODEL_NAME \ trainer.placement.colocate_all=true \ trainer.strategy=megatron \ @@ -48,8 +61,8 @@ uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_base \ trainer.eval_before_train=false \ trainer.eval_interval=5 \ trainer.update_epochs_per_batch=1 \ - trainer.train_batch_size=128 \ - trainer.policy_mini_batch_size=64 \ + trainer.train_batch_size=64 \ + trainer.policy_mini_batch_size=16 \ trainer.micro_forward_batch_size_per_gpu=4 \ trainer.micro_train_batch_size_per_gpu=4 \ trainer.ckpt_interval=10 \ diff --git a/skyrl-train/skyrl_train/utils/utils.py b/skyrl-train/skyrl_train/utils/utils.py index c1d286575..1655dc85d 100644 --- a/skyrl-train/skyrl_train/utils/utils.py +++ b/skyrl-train/skyrl_train/utils/utils.py @@ -324,7 +324,7 @@ def validate_cfg(cfg: DictConfig): raise NotImplementedError( "`trainer.algorithm.off_policy_correction` doesn't support Sglang backend, please use vLLM" ) - if cfg.algorithm.policy_loss_type in ["clip_cov", "kl_cov"]: + if cfg.trainer.algorithm.policy_loss_type in ["clip_cov", "kl_cov"]: raise NotImplementedError( "`trainer.algorithm.off_policy_correction` doesn't support clip_cov or kl_cov policy loss types" ) diff --git a/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py b/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py index 7ecd9eb9b..028b2d0bd 100644 --- a/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py +++ b/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py @@ -472,6 +472,11 @@ async def test_megatron_train( ) transformer_config_kwargs["num_layers"] = 2 cfg.trainer.policy.megatron_config.transformer_config_kwargs = transformer_config_kwargs + # test off policy correction config propagates correctly + cfg.trainer.algorithm.off_policy_correction.tis_ratio_type = "sequence" + cfg.trainer.algorithm.off_policy_correction.sequence_mask_metric = "geometric" + cfg.trainer.algorithm.off_policy_correction.geo_mask_high = 1.02 + cfg.trainer.algorithm.off_policy_correction.geo_mask_low = 0.98 # set batch sizes correctly cfg.trainer.train_batch_size = gpus_per_node @@ -547,6 +552,14 @@ async def test_megatron_train( "policy_kl", "final_loss", ] + if ep > 1: + keys_to_compare.extend( + [ + "loss_metrics/is_ratio_mean", + "loss_metrics/outlier_seq_masked_ratio", + "loss_metrics/geo_sequence_mask_masked_ratio", + ] + ) for i, result in enumerate(results_fsdp): for k in keys_to_compare: if k == "policy_entropy": diff --git a/skyrl-train/tests/gpu/gpu_ci/test_ppo_train.py b/skyrl-train/tests/gpu/gpu_ci/test_ppo_train.py index 9b5bc6e3b..429ff5ede 100644 --- a/skyrl-train/tests/gpu/gpu_ci/test_ppo_train.py +++ b/skyrl-train/tests/gpu/gpu_ci/test_ppo_train.py @@ -29,8 +29,7 @@ def cfg() -> DictConfig: return cfg -# @pytest.mark.parametrize("use_entropy_loss, use_kl_loss", [(False, False), (True, True), (True, False), (False, True)]) -@pytest.mark.parametrize("use_entropy_loss, use_kl_loss", [(False, False)]) +@pytest.mark.parametrize("use_entropy_loss, use_kl_loss", [(False, False), (True, True), (True, False), (False, True)]) def test_ppo_train_basic_execution(ray_init_fixture, cfg, use_entropy_loss, use_kl_loss): """ Test that ppo_train runs and returns correct structure. 
@@ -50,7 +49,6 @@ def test_ppo_train_basic_execution(ray_init_fixture, cfg, use_entropy_loss, use_ cfg.trainer.algorithm.kl_loss_coef = 0.001 cfg.trainer.algorithm.off_policy_correction.tis_ratio_type = "sequence" - cfg.trainer.algorithm.off_policy_correction.sequence_mask_metric = "geometric" cfg.trainer.algorithm.off_policy_correction.geo_mask_high = 1.02 cfg.trainer.algorithm.off_policy_correction.geo_mask_low = 0.98 @@ -82,6 +80,9 @@ def test_ppo_train_basic_execution(ray_init_fixture, cfg, use_entropy_loss, use_ "policy_update_steps", "policy_lr", "loss_metrics/clip_ratio", + "loss_metrics/is_ratio_mean", + "loss_metrics/outlier_seq_masked_ratio", + "loss_metrics/geo_sequence_mask_masked_ratio", "policy_entropy", "final_loss", ] @@ -90,8 +91,6 @@ def test_ppo_train_basic_execution(ray_init_fixture, cfg, use_entropy_loss, use_ assert metric in train_status, f"Should have {metric} in train_status" assert isinstance(train_status[metric], (int, float)), f"{metric} should be numeric" - print(train_status) - # Simple check for metric values assert train_status["policy_update_steps"] > 0, "Should have completed at least one update step" assert train_status["policy_lr"] > 0, "Should have positive learning rate" From 6b9e1e4cd90bffcf8d7925992f5bff77e6dddd3d Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Tue, 13 Jan 2026 00:27:05 +0000 Subject: [PATCH 19/23] add docs --- skyrl-train/docs/configuration/config.rst | 76 ++++++++++++++++++++--- 1 file changed, 68 insertions(+), 8 deletions(-) diff --git a/skyrl-train/docs/configuration/config.rst b/skyrl-train/docs/configuration/config.rst index 4da00d4e0..65c35870c 100644 --- a/skyrl-train/docs/configuration/config.rst +++ b/skyrl-train/docs/configuration/config.rst @@ -389,6 +389,45 @@ Algorithm Configuration # dual clip parameters clip_ratio_c: 3.0 + # To be deprecated in favor of off_policy_correction.tis_ratio_type = "token" + # and "token_tis_ratio_clip_high" + use_tis: false + tis_imp_ratio_cap: -1.0 + + # references + # - https://github.com/szrlee/verl/blob/yingru/rollout_correction/docs/advance/rollout_corr_math.md + # - https://fengyao.notion.site/off-policy-rl + off_policy_correction: + # type of importance sampling ratio to use for ppo loss correction + # here importance sampling ratio refers to exp(logprobs_{policy_old} - logprobs_{rollout_policy}) + tis_ratio_type: null # null, "token", "sequence" + + # used if tis_ratio_type = "token", 1.5-5.0 is recommended for "token" tis_ratio_type + token_tis_ratio_clip_high: 2.0 + # used if tis_ratio_type = "sequence", 2.0-10.0 is recommended for "sequence" tis_ratio_type + sequence_tis_ratio_clip_high: 5.0 + + # method of masking out sequences with cumulative importance sampling ratios outside the cap + # "product" masks out sequences with product of importance ratios outside the cap + # "geometric" masks out sequences with geometric mean of importance ratios outside the cap + sequence_mask_metric: null # null, "product", "geometric" + + # used if sequence_mask_metric = "geometric" + # values around 0.99-1.01 are recommended for "geometric" sequence_mask_metric - MoE models may need larger allowed ranges due to higher mismatch + geo_mask_high: 1.01 + geo_mask_low: 0.99 + + # used if sequence_mask_metric = "product" + # values around 0.5-2.0 are recommended for "sequence" sequence_mask_metric + product_mask_high: 2.0 + product_mask_low: 0.5 + + # separate from sequence_mask_metric and tis_ratio_type + # if any off_policy_correction is enabled, masks out sequences with any token having importance 
ratio + # far outside an acceptable range (low and high thresholds) + outlier_token_is_threshold_low: 1e-4 + outlier_token_is_threshold_high: 100 + # clip-cov parameters (only used when policy_loss_type: "clip_cov") clip_cov: clip_ratio: 0.0002 # fraction of tokens to clip based on covariance @@ -413,10 +452,6 @@ Algorithm Configuration type: null # filter (DAPO), replace (POLARIS/WebSailor), or null max_sample_batches: 30 # sample at most this many batches before stopping, -1 to sample forever min_replace_ratio: 0.3 # minimum proportion of good samples with which to replace bad samples (for replace strategy only) - - # Truncated Importance Sampling as proposed in https://fengyao.notion.site/off-policy-rl - use_tis: false - tis_imp_ratio_cap: -1.0 # SAPO parameters (only used when policy_loss_type: "sapo") (https://arxiv.org/pdf/2511.20347) sapo: @@ -466,8 +501,8 @@ Algorithm Configuration - ``algorithm.dynamic_sampling.type``: Type of dynamic sampling to use. Currently, we support ``filter`` (`DAPO `_), ``replace`` (`POLARIS `_ / `WebSailor `_), or ``null`` for no dynamic sampling. - ``algorithm.dynamic_sampling.max_sample_batches``: Maximum number of batches to sample before stopping. Set to ``-1`` to sample forever. - ``algorithm.dynamic_sampling.min_replace_ratio``: Minimum proportion of good samples with which to replace bad samples for ``replace`` strategy. -- ``algorithm.use_tis``: Whether to use Truncated Importance Sampling (TIS) as proposed in `this blog `_. -- ``algorithm.tis_imp_ratio_cap``: Cap parameter for the importance ratio in TIS. +- ``algorithm.use_tis``: Whether to use Truncated Importance Sampling (TIS) as proposed in `this blog `_. This flag is to be deprecated, use ``off_policy_correction.tis_ratio_type = "token"`` instead. +- ``algorithm.tis_imp_ratio_cap``: Cap parameter for the importance ratio in TIS. This flag is to be deprecated, use ``off_policy_correction.token_tis_ratio_clip_high`` instead. - ``algorithm.clip_cov``: Clip-Cov parameters (only used when ``policy_loss_type`` is ``clip_cov``): - ``clip_ratio``: Fraction of tokens to clip based on covariance values. @@ -489,9 +524,34 @@ Algorithm Configuration - ``tau_pos``: Temperature for gating function for tokens with positive advantages. - ``tau_neg``: Temperature for gating function for tokens with negative (or zero) advantages. -Rollout Correction Configuration +Off Policy Correction Configuration ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -- ``algorithm.off_policy_correction``: Off policy correction configuration. +- ``algorithm.off_policy_correction``: Off policy correction configuration. See the full configuration below + +.. code-block:: yaml + + off_policy_correction: + tis_ratio_type: null # null, "token", "sequence" + token_tis_ratio_clip_high: 2.0 + sequence_tis_ratio_clip_high: 5.0 + sequence_mask_metric: null # null, "product", "geometric" + geo_mask_high: 1.01 + geo_mask_low: 0.99 + product_mask_high: 2.0 + product_mask_low: 0.5 + outlier_token_is_threshold_low: 1e-4 + outlier_token_is_threshold_high: 100 + +- ``algorithm.off_policy_correction.tis_ratio_type``: Type of importance sampling ratio to use for ppo loss correction. Options include: ``null``, ``token``, ``sequence``. +- ``algorithm.off_policy_correction.token_tis_ratio_clip_high``: Cap parameter for "token" tis_ratio_type. +- ``algorithm.off_policy_correction.sequence_tis_ratio_clip_high``: Cap parameter for "sequence" tis_ratio_type. 
+- ``algorithm.off_policy_correction.sequence_mask_metric``: Method of masking out sequences with cumulative importance sampling ratios outside the cap. Options include: ``null``, ``product``, ``geometric``. +- ``algorithm.off_policy_correction.geo_mask_high``: High threshold for "geometric" sequence_mask_metric. +- ``algorithm.off_policy_correction.geo_mask_low``: Low threshold for "geometric" sequence_mask_metric. +- ``algorithm.off_policy_correction.product_mask_high``: High threshold for "product" sequence_mask_metric. +- ``algorithm.off_policy_correction.product_mask_low``: Low threshold for "product" sequence_mask_metric. +- ``algorithm.off_policy_correction.outlier_token_is_threshold_low``: Low threshold for outlier token mask - masks out sequences with any token having importance ratio far outside an acceptable range (low and high thresholds). +- ``algorithm.off_policy_correction.outlier_token_is_threshold_high``: High threshold for outlier token mask - masks out sequences with any token having importance ratio far outside an acceptable range (low and high thresholds). Policy Loss Formulation ~~~~~~~~~~~~~~~~~~~~~~~ From 46b6fe5d509fd26c9db2e442d49dc7b796994e2b Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Tue, 13 Jan 2026 00:48:02 +0000 Subject: [PATCH 20/23] x --- skyrl-train/docs/configuration/config.rst | 2 +- .../examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_fp8.sh | 1 - .../flash_rl/run_dapo_gsm8k_flashrl_0.5b_int8.sh | 1 - skyrl-train/skyrl_train/config/ppo_base_config.yaml | 2 +- skyrl-train/skyrl_train/distributed/strategy.py | 9 +++++---- skyrl-train/skyrl_train/utils/utils.py | 4 ++-- .../skyrl_train/workers/megatron/megatron_worker.py | 2 +- skyrl-train/tests/cpu/algorithms/test_losses.py | 8 ++++---- 8 files changed, 14 insertions(+), 15 deletions(-) diff --git a/skyrl-train/docs/configuration/config.rst b/skyrl-train/docs/configuration/config.rst index 65c35870c..740be7bb2 100644 --- a/skyrl-train/docs/configuration/config.rst +++ b/skyrl-train/docs/configuration/config.rst @@ -418,7 +418,7 @@ Algorithm Configuration geo_mask_low: 0.99 # used if sequence_mask_metric = "product" - # values around 0.5-2.0 are recommended for "sequence" sequence_mask_metric + # values around 0.5-2.0 are recommended for "product" sequence_mask_metric product_mask_high: 2.0 product_mask_low: 0.5 diff --git a/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_fp8.sh b/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_fp8.sh index bd3844397..9d16cc709 100644 --- a/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_fp8.sh +++ b/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_fp8.sh @@ -57,7 +57,6 @@ uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.fp8 --with v trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \ trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \ - trainer.algorithm.tis_imp_ratio_cap=2.0 \ trainer.policy.model.path="Qwen/Qwen2.5-0.5B-Instruct" \ trainer.placement.colocate_all=true \ trainer.strategy=fsdp2 \ diff --git a/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_int8.sh b/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_int8.sh index cdfa54ef9..5c31645bc 100644 --- a/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_int8.sh +++ b/skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_int8.sh @@ -57,7 +57,6 @@ uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.int8 --with 
trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \ trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \ - trainer.algorithm.tis_imp_ratio_cap=2.0 \ trainer.policy.model.path="Qwen/Qwen2.5-0.5B-Instruct" \ trainer.placement.colocate_all=true \ trainer.strategy=fsdp2 \ diff --git a/skyrl-train/skyrl_train/config/ppo_base_config.yaml b/skyrl-train/skyrl_train/config/ppo_base_config.yaml index 28b4da504..0e8b6385c 100644 --- a/skyrl-train/skyrl_train/config/ppo_base_config.yaml +++ b/skyrl-train/skyrl_train/config/ppo_base_config.yaml @@ -150,7 +150,7 @@ trainer: geo_mask_low: 0.99 # used if sequence_mask_metric = "product" - # values around 0.5-2.0 are recommended for "sequence" sequence_mask_metric + # values around 0.5-2.0 are recommended for "product" sequence_mask_metric product_mask_high: 2.0 product_mask_low: 0.5 diff --git a/skyrl-train/skyrl_train/distributed/strategy.py b/skyrl-train/skyrl_train/distributed/strategy.py index 503cfcce4..566e6233a 100644 --- a/skyrl-train/skyrl_train/distributed/strategy.py +++ b/skyrl-train/skyrl_train/distributed/strategy.py @@ -74,11 +74,12 @@ def all_reduce(self, data: DataT, op="mean") -> DataT: ret = {} for k, v in data.items(): options = ["min", "max", "mean"] - for op in options: - if op in k: - op = op + detected_op = op + for option in options: + if option in k: + detected_op = option break - ret[k] = self.all_reduce(v, op) + ret[k] = self.all_reduce(v, detected_op) return ret else: is_tensor = True diff --git a/skyrl-train/skyrl_train/utils/utils.py b/skyrl-train/skyrl_train/utils/utils.py index 1655dc85d..8c3039c54 100644 --- a/skyrl-train/skyrl_train/utils/utils.py +++ b/skyrl-train/skyrl_train/utils/utils.py @@ -308,9 +308,9 @@ def validate_cfg(cfg: DictConfig): # Validate sequence_mask_metric if sequence_mask_metric: assert sequence_mask_metric in [ - "sequence", + "product", "geometric", - ], f"`sequence_mask_metric` must be 'sequence', or 'geometric', got {sequence_mask_metric}" + ], f"`sequence_mask_metric` must be 'product', or 'geometric', got {sequence_mask_metric}" # Ensure logprobs are enabled for rollout correction if cfg.generator.sampling_params.logprobs is None: diff --git a/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py b/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py index c10eaa9a2..e308b4526 100644 --- a/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py @@ -598,7 +598,7 @@ def ppo_train(self, train_data) -> "TrainingOutputBatch": # attach response_length status["response_length"] = micro_buffer[i]["num_actions"] - status_mean = all_reduce_metrics(status, self.strategy) + status = all_reduce_metrics(status, self.strategy) status_list.append(status) for k, v in status.items(): diff --git a/skyrl-train/tests/cpu/algorithms/test_losses.py b/skyrl-train/tests/cpu/algorithms/test_losses.py index a4abce5b8..f3aa2a0e8 100644 --- a/skyrl-train/tests/cpu/algorithms/test_losses.py +++ b/skyrl-train/tests/cpu/algorithms/test_losses.py @@ -815,8 +815,8 @@ def test_compute_sequence_mask_geometric_rejects(): assert metrics["geo_sequence_mask_under_low_ratio"] == 0.0 -def test_compute_sequence_mask_sequence(): - """Tests sequence sequence mask computation.""" +def test_compute_sequence_mask_product(): + """Tests product sequence mask computation.""" device = "cpu" # Token log ratios: [0.2, 0.1, 0.0] -> sum = 0.3, seq ratio = exp(0.3) ≈ 1.35 
@@ -844,8 +844,8 @@ def test_compute_sequence_mask_sequence(): assert metrics["product_sequence_mask_under_low_ratio"] == 0.0 -def test_compute_sequence_mask_sequence_rejects_by_seq_ratio(): - """Tests product sequence mask rejects when sequence ratio is out of bounds.""" +def test_compute_sequence_mask_product_rejects_by_seq_ratio(): + """Tests product sequence mask rejects when product ratio is out of bounds.""" device = "cpu" # Token log ratios: [1.0, 1.0, 1.0] -> sum = 3.0, seq ratio = exp(3.0) ≈ 20.09 From d72d9c656b6191f51eb8f90d48e97b654ea2fdab Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Tue, 13 Jan 2026 00:54:26 +0000 Subject: [PATCH 21/23] x --- skyrl-train/examples/megatron/run_megatron.sh | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/skyrl-train/examples/megatron/run_megatron.sh b/skyrl-train/examples/megatron/run_megatron.sh index 1ac4f2ca5..550fa8149 100644 --- a/skyrl-train/examples/megatron/run_megatron.sh +++ b/skyrl-train/examples/megatron/run_megatron.sh @@ -22,23 +22,11 @@ ENABLE_TORCH_PROFILER=false RANKS_TO_PROFILE="[0]" SAVE_PATH="$HOME/megatron_prof/tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_${MODEL_NAME}" -TIS_RATIO_TYPE="sequence" -TIS_RATIO_HIGH=2.0 -SEQUENCE_MASK_METRIC="geometric" -SEQUENCE_MASK_HIGH=1.02 -SEQUENCE_MASK_LOW=0.98 - uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_base \ data.train_data="['$DATA_DIR/train.parquet']" \ data.val_data="['$DATA_DIR/validation.parquet']" \ trainer.algorithm.advantage_estimator="grpo" \ - trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_RATIO_TYPE \ - trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_RATIO_HIGH \ - trainer.algorithm.off_policy_correction.sequence_tis_ratio_clip_high=$TIS_RATIO_HIGH \ - trainer.algorithm.off_policy_correction.sequence_mask_metric=$SEQUENCE_MASK_METRIC \ - trainer.algorithm.off_policy_correction.geo_mask_high=$SEQUENCE_MASK_HIGH \ - trainer.algorithm.off_policy_correction.geo_mask_low=$SEQUENCE_MASK_LOW \ trainer.policy.model.path=$MODEL_NAME \ trainer.placement.colocate_all=true \ trainer.strategy=megatron \ From db76d011a47bc90ba11dd1f82063952f9ca21b33 Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Tue, 13 Jan 2026 01:08:40 +0000 Subject: [PATCH 22/23] gemini: --- .../megatron/run_megatron_dapo_qwen3_30b_a3b_lora.sh | 8 ++++---- skyrl-train/skyrl_train/distributed/strategy.py | 11 +---------- 2 files changed, 5 insertions(+), 14 deletions(-) diff --git a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_30b_a3b_lora.sh b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_30b_a3b_lora.sh index 035d5a88e..d9d9c90e3 100644 --- a/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_30b_a3b_lora.sh +++ b/skyrl-train/examples/megatron/run_megatron_dapo_qwen3_30b_a3b_lora.sh @@ -7,7 +7,7 @@ set -x # bash examples/megatron/run_megatron_dapo_qwen3_30b_a3b_lora.sh MODEL_NAME="Qwen/Qwen3-30B-A3B-Base" -DATA_DIR="/mnt/cluster_storage/data/dapo" +DATA_DIR="$HOME/data/dapo" TRAIN_FILE="$DATA_DIR/dapo-math-17k-cleaned.parquet" TEST_FILE="$DATA_DIR/aime-2024-cleaned.parquet" NUM_NODES=2 @@ -127,10 +127,10 @@ uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \ generator.gpu_memory_utilization=0.7 \ trainer.logger="$LOGGER" \ trainer.project_name="dapo_aime" \ - 
trainer.run_name="dapo_qwen3_30b_a3b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_ep${MEGATRON_EP}_etp${MEGATRON_ETP}_lora_rank${LORA_RANK}_alpha${LORA_ALPHA}_seq_is_${TIS_RATIO_TYPE}_${TIS_IMP_RATIO_CAP}_rej_${sequence_mask_metric}_${geo_mask_high}_${geo_mask_low}" \ - trainer.export_path="$HOME/exports/dapo_qwen3_30b_a3b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_ep${MEGATRON_EP}_etp${MEGATRON_ETP}_lora_rank${LORA_RANK}_alpha${LORA_ALPHA}_seq_is_${TIS_RATIO_TYPE}_${TIS_IMP_RATIO_CAP}_rej_${sequence_mask_metric}_${geo_mask_high}_${geo_mask_low}" \ + trainer.run_name="dapo_qwen3_30b_a3b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_ep${MEGATRON_EP}_etp${MEGATRON_ETP}_lora_rank${LORA_RANK}_alpha${LORA_ALPHA}" \ + trainer.export_path="$HOME/exports/dapo_qwen3_30b_a3b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_ep${MEGATRON_EP}_etp${MEGATRON_ETP}_lora_rank${LORA_RANK}_alpha${LORA_ALPHA}" \ trainer.hf_save_interval=300 \ trainer.resume_mode=latest \ trainer.max_ckpts_to_keep=3 \ - trainer.ckpt_path="$HOME/ckpts/dapo_qwen3_30b_a3b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_ep${MEGATRON_EP}_etp${MEGATRON_ETP}_lora_rank${LORA_RANK}_alpha${LORA_ALPHA}_seq_is_${TIS_RATIO_TYPE}_${TIS_IMP_RATIO_CAP}_rej_${sequence_mask_metric}_${geo_mask_high}_${geo_mask_low}" \ + trainer.ckpt_path="$HOME/ckpts/dapo_qwen3_30b_a3b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_ep${MEGATRON_EP}_etp${MEGATRON_ETP}_lora_rank${LORA_RANK}_alpha${LORA_ALPHA}" \ $@ \ No newline at end of file diff --git a/skyrl-train/skyrl_train/distributed/strategy.py b/skyrl-train/skyrl_train/distributed/strategy.py index 566e6233a..fb8269e97 100644 --- a/skyrl-train/skyrl_train/distributed/strategy.py +++ b/skyrl-train/skyrl_train/distributed/strategy.py @@ -71,16 +71,7 @@ def all_reduce(self, data: DataT, op="mean") -> DataT: """Perform all_reduce across all processes""" assert op in ("mean", "max", "sum", "min") if isinstance(data, dict): - ret = {} - for k, v in data.items(): - options = ["min", "max", "mean"] - detected_op = op - for option in options: - if option in k: - detected_op = option - break - ret[k] = self.all_reduce(v, detected_op) - return ret + return {k: self.all_reduce(v, op) for k, v in data.items()} else: is_tensor = True if not isinstance(data, torch.Tensor): From ac0659caa7a7b7872e4bfa5ff22c3e545bbfbc82 Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Tue, 13 Jan 2026 01:16:29 +0000 Subject: [PATCH 23/23] x --- skyrl-train/examples/megatron/run_megatron.sh | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/skyrl-train/examples/megatron/run_megatron.sh b/skyrl-train/examples/megatron/run_megatron.sh index 550fa8149..cf0d9f9ed 100644 --- a/skyrl-train/examples/megatron/run_megatron.sh +++ b/skyrl-train/examples/megatron/run_megatron.sh @@ -13,8 +13,8 @@ MODEL_NAME="Qwen/Qwen3-0.6B" INFERENCE_BACKEND="vllm" # currently only vllm is supported for megatron -MEGATRON_TP=1 -MEGATRON_PP=1 +MEGATRON_TP=2 +MEGATRON_PP=2 MEGATRON_CP=1 # torch profiler config @@ -22,7 +22,6 @@ ENABLE_TORCH_PROFILER=false RANKS_TO_PROFILE="[0]" SAVE_PATH="$HOME/megatron_prof/tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_${MODEL_NAME}" - uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_base \ data.train_data="['$DATA_DIR/train.parquet']" \ data.val_data="['$DATA_DIR/validation.parquet']" \ @@ -49,8 +48,8 @@ uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_base \ 
trainer.eval_before_train=false \ trainer.eval_interval=5 \ trainer.update_epochs_per_batch=1 \ - trainer.train_batch_size=64 \ - trainer.policy_mini_batch_size=16 \ + trainer.train_batch_size=128 \ + trainer.policy_mini_batch_size=64 \ trainer.micro_forward_batch_size_per_gpu=4 \ trainer.micro_train_batch_size_per_gpu=4 \ trainer.ckpt_interval=10 \