6 changes: 3 additions & 3 deletions skyrl-train/docs/algorithms/custom_algorithms.rst
@@ -48,14 +48,14 @@ Similarly, you can register custom policy loss functions:

.. code-block:: python

from skyrl_train.utils.ppo_utils import register_policy_loss, PolicyLossRegistry
from skyrl_train.utils.ppo_utils import register_policy_loss, PolicyLossRegistry, LossMetrics

@register_policy_loss("reinforce")
def compute_reinforce_policy_loss(log_probs, old_log_probs, advantages, config, loss_mask=None, rollout_log_probs=None):
# Your custom policy loss implementation (like REINFORCE)
loss = (-log_probs * advantages).mean()
# return loss and clip ratio
return loss, 0.0
# return loss and loss metrics
return loss, LossMetrics(clip_ratio=0.0)

Registry Ray Distribution
~~~~~~~~~~~~~~~~~~~~~~~~~~
80 changes: 72 additions & 8 deletions skyrl-train/docs/configuration/config.rst
@@ -389,6 +389,45 @@ Algorithm Configuration
# dual clip parameters
clip_ratio_c: 3.0

# To be deprecated in favor of off_policy_correction.tis_ratio_type = "token"
CharlieFRuan (Collaborator) commented on Jan 22, 2026:

I think off_policy_correction deserves a documentation page under the Algorithms section. We can learn from veRL's and give some canonical preset examples, plus some intuition on when to use which. Especially since these configs are kind of hierarchical (token_tis_ratio_clip_high is only applicable when tis_ratio_type is token). We can do it in a followup PR.

Perhaps we can tell users what the basic way of doing TIS (token-level) is, which only involves two configs. Then, if they're advanced enough, they can refer to the blogs and further tune the configs.

From my understanding of these two blogs:

The best config seems to be token plus geometric (figures 16-18), especially for long-horizon tool calls? If that's the impression you got from the blogs, we should add a comment so users can pick that "preset".

# and "token_tis_ratio_clip_high"
use_tis: false
tis_imp_ratio_cap: -1.0

# references
# - https://github.com/szrlee/verl/blob/yingru/rollout_correction/docs/advance/rollout_corr_math.md
# - https://fengyao.notion.site/off-policy-rl
off_policy_correction:
# type of importance sampling ratio to use for ppo loss correction
# here importance sampling ratio refers to exp(logprobs_{policy_old} - logprobs_{rollout_policy})
tis_ratio_type: null # null, "token", "sequence"

# used if tis_ratio_type = "token", 1.5-5.0 is recommended for "token" tis_ratio_type
token_tis_ratio_clip_high: 2.0
# used if tis_ratio_type = "sequence", 2.0-10.0 is recommended for "sequence" tis_ratio_type
sequence_tis_ratio_clip_high: 5.0

# method of masking out sequences with cumulative importance sampling ratios outside the cap
# "product" masks out sequences with product of importance ratios outside the cap
# "geometric" masks out sequences with geometric mean of importance ratios outside the cap
sequence_mask_metric: null # null, "product", "geometric"

# used if sequence_mask_metric = "geometric"
# values around 0.99-1.01 are recommended for "geometric" sequence_mask_metric; MoE models may need wider allowed ranges due to higher mismatch
geo_mask_high: 1.01
geo_mask_low: 0.99

# used if sequence_mask_metric = "product"
# values around 0.5-2.0 are recommended for "product" sequence_mask_metric
product_mask_high: 2.0
product_mask_low: 0.5

# separate from sequence_mask_metric and tis_ratio_type
# if any off_policy_correction is enabled, masks out sequences with any token having importance ratio
# far outside an acceptable range (low and high thresholds)
outlier_token_is_threshold_low: 1e-4
A collaborator commented:

Can we set the default value to null for outlier_token_is_xxx? I know the current values are really high, but I feel like null and no outlier masking is a better default. Seems to align with veRL's approach:

https://verl.readthedocs.io/en/latest/examples/config.html

   rollout_correction:
     rollout_is: token # IS weights: token/sequence/null
     rollout_is_threshold: 2.0 # Upper threshold for IS weights
     rollout_rs: null # Rejection sampling: token/sequence/geometric/null
     rollout_rs_threshold: null # RS upper threshold
     rollout_rs_threshold_lower: null # RS lower threshold
     rollout_token_veto_threshold: null # Per-token veto (null to disable)

Not sure if you'd need to make changes on the implementation side to treat None as not masking

outlier_token_is_threshold_high: 100

# clip-cov parameters (only used when policy_loss_type: "clip_cov")
clip_cov:
clip_ratio: 0.0002 # fraction of tokens to clip based on covariance
@@ -413,10 +452,6 @@ Algorithm Configuration
type: null # filter (DAPO), replace (POLARIS/WebSailor), or null
max_sample_batches: 30 # sample at most this many batches before stopping, -1 to sample forever
min_replace_ratio: 0.3 # minimum proportion of good samples with which to replace bad samples (for replace strategy only)

# Truncated Importance Sampling as proposed in https://fengyao.notion.site/off-policy-rl
use_tis: false
tis_imp_ratio_cap: -1.0

# SAPO parameters (only used when policy_loss_type: "sapo") (https://arxiv.org/pdf/2511.20347)
sapo:
@@ -466,8 +501,8 @@ Algorithm Configuration
- ``algorithm.dynamic_sampling.type``: Type of dynamic sampling to use. Currently, we support ``filter`` (`DAPO <https://dapo-sia.github.io/>`_), ``replace`` (`POLARIS <https://hkunlp.github.io/blog/2025/Polaris/>`_ / `WebSailor <https://arxiv.org/abs/2507.02592>`_), or ``null`` for no dynamic sampling.
- ``algorithm.dynamic_sampling.max_sample_batches``: Maximum number of batches to sample before stopping. Set to ``-1`` to sample forever.
- ``algorithm.dynamic_sampling.min_replace_ratio``: Minimum proportion of good samples with which to replace bad samples for ``replace`` strategy.
- ``algorithm.use_tis``: Whether to use Truncated Importance Sampling (TIS) as proposed in `this blog <https://fengyao.notion.site/off-policy-rl>`_.
- ``algorithm.tis_imp_ratio_cap``: Cap parameter for the importance ratio in TIS.
- ``algorithm.use_tis``: Whether to use Truncated Importance Sampling (TIS) as proposed in `this blog <https://fengyao.notion.site/off-policy-rl>`_. This flag will be deprecated; use ``off_policy_correction.tis_ratio_type = "token"`` instead.
- ``algorithm.tis_imp_ratio_cap``: Cap parameter for the importance ratio in TIS. This flag will be deprecated; use ``off_policy_correction.token_tis_ratio_clip_high`` instead.
- ``algorithm.clip_cov``: Clip-Cov parameters (only used when ``policy_loss_type`` is ``clip_cov``):

- ``clip_ratio``: Fraction of tokens to clip based on covariance values.
@@ -489,6 +524,35 @@ Algorithm Configuration
- ``tau_pos``: Temperature for gating function for tokens with positive advantages.
- ``tau_neg``: Temperature for gating function for tokens with negative (or zero) advantages.

Off Policy Correction Configuration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
A collaborator commented:

Let's cite the blog post here as well.

Another collaborator commented:

Depends on how you add the separate correction doc page (see other comment). But it'd be easier for the user if we can do the following: basically, help the users understand each config (3 groups of them) one-by-one by pointing them to other resources.

1. Group these three together

  • algorithm.off_policy_correction.tis_ratio_type
  • algorithm.off_policy_correction.token_tis_ratio_clip_high
  • algorithm.off_policy_correction.sequence_tis_ratio_clip_high

and tell them:

2. Then group these together

sequence_mask_metric: null # null, "product", "geometric"
geo_mask_high: 1.01
geo_mask_low: 0.99
product_mask_high: 2.0
product_mask_low: 0.5

3. Then group the outlier thresholds together

outlier_token_is_threshold_low: 1e-4
outlier_token_is_threshold_high: 100

Other remarks:

Pointing to our implementation would also be helpful, namely rollout_corrections.py or whatever name you decided on in the end.

- ``algorithm.off_policy_correction``: Off policy correction configuration. See the full configuration below:

.. code-block:: yaml

off_policy_correction:
tis_ratio_type: null # null, "token", "sequence"
token_tis_ratio_clip_high: 2.0
sequence_tis_ratio_clip_high: 5.0
sequence_mask_metric: null # null, "product", "geometric"
geo_mask_high: 1.01
geo_mask_low: 0.99
product_mask_high: 2.0
product_mask_low: 0.5
outlier_token_is_threshold_low: 1e-4
outlier_token_is_threshold_high: 100

- ``algorithm.off_policy_correction.tis_ratio_type``: Type of importance sampling ratio to use for PPO loss correction. Options include: ``null``, ``token``, ``sequence``.
- ``algorithm.off_policy_correction.token_tis_ratio_clip_high``: Upper clip value for the importance ratio when ``tis_ratio_type`` is ``token``.
- ``algorithm.off_policy_correction.sequence_tis_ratio_clip_high``: Upper clip value for the importance ratio when ``tis_ratio_type`` is ``sequence``.
- ``algorithm.off_policy_correction.sequence_mask_metric``: Method of masking out sequences with cumulative importance sampling ratios outside the cap. Options include: ``null``, ``product``, ``geometric``.
- ``algorithm.off_policy_correction.geo_mask_high``: High threshold for "geometric" sequence_mask_metric.
- ``algorithm.off_policy_correction.geo_mask_low``: Low threshold for "geometric" sequence_mask_metric.
- ``algorithm.off_policy_correction.product_mask_high``: High threshold for "product" sequence_mask_metric.
- ``algorithm.off_policy_correction.product_mask_low``: Low threshold for "product" sequence_mask_metric.
- ``algorithm.off_policy_correction.outlier_token_is_threshold_low``: Low threshold for the outlier token mask, which masks out sequences containing any token whose importance ratio falls outside the acceptable (low, high) range.
- ``algorithm.off_policy_correction.outlier_token_is_threshold_high``: High threshold for the outlier token mask (see the sketch below).
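
For intuition on how the sequence-level settings fit together, below is a minimal sketch of how the ``product``/``geometric`` sequence masks and the outlier token veto could be computed from per-token importance ratios. This is illustrative only and not the SkyRL implementation; the function and variable names are ours, and the ratio follows the definition above, ``exp(logprobs_{policy_old} - logprobs_{rollout_policy})``.

.. code-block:: python

import torch

def sequence_keep_mask(old_log_probs, rollout_log_probs, loss_mask, cfg):
    # Illustrative sketch only -- not the SkyRL implementation.
    # Per-token importance ratio: exp(logprobs_{policy_old} - logprobs_{rollout_policy})
    log_ratio = old_log_probs - rollout_log_probs  # [batch, seq_len]
    token_is = log_ratio.exp()
    keep = torch.ones(log_ratio.shape[0], dtype=torch.bool, device=log_ratio.device)

    # Outlier token veto: drop sequences containing any token with an extreme ratio.
    outlier = (token_is < cfg.outlier_token_is_threshold_low) | (
        token_is > cfg.outlier_token_is_threshold_high
    )
    keep &= ~(outlier & loss_mask.bool()).any(dim=-1)

    if cfg.sequence_mask_metric == "product":
        # Product of per-token ratios == exp(sum of per-token log-ratios).
        seq_is = (log_ratio * loss_mask).sum(dim=-1).exp()
        keep &= (seq_is >= cfg.product_mask_low) & (seq_is <= cfg.product_mask_high)
    elif cfg.sequence_mask_metric == "geometric":
        # Geometric mean of ratios == exp(mean of per-token log-ratios).
        n_tokens = loss_mask.sum(dim=-1).clamp(min=1)
        seq_is = ((log_ratio * loss_mask).sum(dim=-1) / n_tokens).exp()
        keep &= (seq_is >= cfg.geo_mask_low) & (seq_is <= cfg.geo_mask_high)

    return keep  # [batch] boolean; False sequences are excluded from the loss

As noted in the review discussion above, basic usage only requires the token-level settings (``tis_ratio_type: token`` plus ``token_tis_ratio_clip_high``); the sequence masks and the outlier veto are independent knobs that can be layered on top.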

Policy Loss Formulation
~~~~~~~~~~~~~~~~~~~~~~~

@@ -502,7 +566,7 @@ It can be helpful to understand the final loss formulation to see how the differ
advantages: torch.Tensor,
config: DictConfig, # trainer.algorithm config
loss_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, LossMetrics]:

ratio = (log_probs - old_log_probs).exp()
surr1 = ratio * advantages
@@ -515,7 +579,7 @@ It can be helpful to understand the final loss formulation to see how the differ
clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
loss = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
loss = reduce_loss(loss, loss_mask, config.loss_reduction)
return loss, clip_ratio
return loss, LossMetrics(clip_ratio=clip_ratio)
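
When token-level off-policy correction is enabled, the per-token loss above is additionally weighted by a truncated importance ratio between the old policy log-probs and the rollout policy log-probs. The following is a minimal sketch of that weighting under the definitions above; it is illustrative only, and the names are ours rather than the exact SkyRL code path.

.. code-block:: python

import torch

def apply_token_tis(per_token_loss, old_log_probs, rollout_log_probs, clip_high=2.0):
    # Illustrative sketch only -- not the exact SkyRL implementation.
    # Importance ratio: exp(logprobs_{policy_old} - logprobs_{rollout_policy})
    tis_ratio = (old_log_probs - rollout_log_probs).exp()
    # Truncate from above (token_tis_ratio_clip_high) to bound variance, and detach
    # so no gradient flows through the correction weight.
    tis_ratio = torch.clamp(tis_ratio, max=clip_high).detach()
    return per_token_loss * tis_ratio

In the formulation above, such a weight would multiply the per-token loss before ``reduce_loss`` is applied.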


Generator Configuration
6 changes: 3 additions & 3 deletions skyrl-train/docs/examples/flash_rl.rst
@@ -60,13 +60,13 @@ We highlight some important training parameters configured for FlashRL from our
DATA_DIR="$HOME/data/dapo"

# TIS parameters
USE_TIS=true
TIS_TYPE=token
TIS_IMP_RATIO_CAP=8.0

uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.int8 -m examples.flash_rl.main_dapo_flashrl \
...
trainer.algorithm.use_tis=$USE_TIS \
trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \
trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \
trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
generator.sampling_params.logprobs=0 \
...

@@ -9,7 +9,7 @@
from omegaconf import DictConfig
from skyrl_train.utils import initialize_ray
from skyrl_train.entrypoints.main_base import BasePPOExp, config_dir, validate_cfg
from skyrl_train.utils.ppo_utils import PolicyLossRegistry
from skyrl_train.utils.ppo_utils import PolicyLossRegistry, LossMetrics


# Example of custom policy loss: "reinforce"
@@ -27,7 +27,7 @@ def compute_reinforce_policy_loss(
loss = (-log_probs * advantages).mean()

# Return loss and dummy clip_ratio (no clipping in REINFORCE)
return loss, 0.0
return loss, LossMetrics(clip_ratio=0.0)


# Register the custom policy loss
@@ -33,6 +33,8 @@ MAX_RESPONSE_LENGTH=1024

CKPT_PATH="$HOME/ckpts/gsm8k_0.5B_ckpt"

TIS_TYPE=token
TIS_IMP_RATIO_CAP=2.0

uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.fp8 --with vllm@https://github.com/NovaSky-AI/SkyRL/releases/download/skyrl_train-v0.1.0/vllm-0.1.dev7509+gcc487699a.d20250821-cp312-cp312-linux_x86_64.whl --with transformers==4.53.3 -- python -m examples.flash_rl.main_dapo_flashrl \
data.train_data="['$DATA_DIR/train.parquet']" \
@@ -53,8 +55,8 @@ uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.fp8 --with v
generator.eval_sampling_params.top_p=$EVAL_TOP_P \
trainer.algorithm.use_kl_loss=$USE_KL_LOSS \
trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \
trainer.algorithm.use_tis=true \
trainer.algorithm.tis_imp_ratio_cap=2.0 \
trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \
trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
trainer.policy.model.path="Qwen/Qwen2.5-0.5B-Instruct" \
trainer.placement.colocate_all=true \
trainer.strategy=fsdp2 \
@@ -33,6 +33,9 @@ MAX_RESPONSE_LENGTH=1024

CKPT_PATH="$HOME/ckpts/gsm8k_0.5B_ckpt"

TIS_TYPE=token
TIS_IMP_RATIO_CAP=2.0

uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.int8 --with vllm@https://github.com/NovaSky-AI/SkyRL/releases/download/skyrl_train-v0.1.0/vllm-0.1.dev7509+gcc487699a.d20250821-cp312-cp312-linux_x86_64.whl --with transformers==4.53.3 -- python -m examples.flash_rl.main_dapo_flashrl \
data.train_data="['$DATA_DIR/train.parquet']" \
data.val_data="['$DATA_DIR/validation.parquet']" \
@@ -52,8 +55,8 @@ uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.int8 --with
trainer.algorithm.dynamic_sampling.max_sample_batches=$DYNAMIC_SAMPLING_MAX_SAMPLE_BATCHES \
trainer.algorithm.use_kl_loss=$USE_KL_LOSS \
trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \
trainer.algorithm.use_tis=true \
trainer.algorithm.tis_imp_ratio_cap=2.0 \
trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \
trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
trainer.policy.model.path="Qwen/Qwen2.5-0.5B-Instruct" \
trainer.placement.colocate_all=true \
trainer.strategy=fsdp2 \
@@ -33,6 +33,9 @@ MAX_RESPONSE_LENGTH=1024

CKPT_PATH="$HOME/ckpts/gsm8k_32B_ckpt"

TIS_TYPE=token
TIS_IMP_RATIO_CAP=2.0

uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.int8 --with vllm@https://github.com/NovaSky-AI/SkyRL/releases/download/skyrl_train-v0.1.0/vllm-0.1.dev7509+gcc487699a.d20250821-cp312-cp312-linux_x86_64.whl --with transformers==4.53.3 -m examples.flash_rl.main_dapo_flashrl \
data.train_data="['$DATA_DIR/train.parquet']" \
data.val_data="['$DATA_DIR/validation.parquet']" \
@@ -52,8 +55,8 @@ uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.int8 --with
generator.eval_sampling_params.top_p=$EVAL_TOP_P \
trainer.algorithm.use_kl_loss=$USE_KL_LOSS \
trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \
trainer.algorithm.use_tis=true \
trainer.algorithm.tis_imp_ratio_cap=2.0 \
trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \
trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
trainer.policy.model.path="Qwen/Qwen2.5-32B" \
trainer.placement.colocate_all=true \
trainer.strategy=fsdp2 \
@@ -31,6 +31,8 @@ EVAL_TOP_P=0.7
CLIP_RATIO_C=10.0
MAX_RESPONSE_LENGTH=4096

TIS_TYPE=token
TIS_IMP_RATIO_CAP=2.0
uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.0.5b_int8 --with vllm@https://github.com/NovaSky-AI/SkyRL/releases/download/skyrl_train-v0.1.0/vllm-0.1.dev7509+gcc487699a.d20250821-cp312-cp312-linux_x86_64.whl --with transformers==4.53.3 -- python -m examples.flash_rl.main_dapo_flashrl \
data.train_data="['$DATA_DIR/dapo-math-17k-cleaned.parquet']" \
data.val_data="['$DATA_DIR/aime-2024-cleaned.parquet']" \
@@ -50,8 +52,8 @@ uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.0.5b_int8 --
trainer.algorithm.dynamic_sampling.max_sample_batches=$DYNAMIC_SAMPLING_MAX_SAMPLE_BATCHES \
trainer.algorithm.use_kl_loss=$USE_KL_LOSS \
trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \
trainer.algorithm.use_tis=true \
trainer.algorithm.tis_imp_ratio_cap=2.0 \
trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \
trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
trainer.policy.model.path="Qwen/Qwen2.5-0.5B-Instruct" \
trainer.placement.colocate_all=true \
trainer.strategy=fsdp2 \
@@ -33,7 +33,7 @@ MAX_RESPONSE_LENGTH=20480
MAX_PROMPT_LENGTH=2048

TIS_IMP_RATIO_CAP=8.0
USE_TIS=true
TIS_TYPE=token
LOGPROBS=0

CKPT_PATH="$HOME/ckpts/dapo_32b_ckpt"
@@ -57,8 +57,8 @@ uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.int8 --with
generator.eval_sampling_params.top_p=$EVAL_TOP_P \
trainer.algorithm.use_kl_loss=$USE_KL_LOSS \
trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \
trainer.algorithm.use_tis=$USE_TIS \
trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \
trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \
trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
trainer.policy.model.path="Qwen/Qwen2.5-32B" \
trainer.placement.colocate_all=true \
trainer.strategy=fsdp2 \
8 changes: 4 additions & 4 deletions skyrl-train/examples/fully_async/async_run_gsm8k.sh
@@ -28,19 +28,19 @@ set -x
: "${MAX_STALENESS_STEPS:=4}"
: "${NUM_PARALLEL_GENERATION_WORKERS:=$(( MINI_BATCH_SIZE * (MAX_STALENESS_STEPS + 1) ))}"

USE_TIS=true
TIS_TYPE=token
TIS_IMP_RATIO_CAP=2.0

RUN_NAME=gsm8k-async-qwen2.5_1.5B-useTIS_${USE_TIS}-maxStale${MAX_STALENESS_STEPS}-numCon${NUM_PARALLEL_GENERATION_WORKERS}-${NUM_POLICY_GPUS}train${NUM_INFERENCE_GPUS}gen
RUN_NAME=gsm8k-async-qwen2.5_1.5B-TIS_TYPE_${TIS_TYPE}-maxStale${MAX_STALENESS_STEPS}-numCon${NUM_PARALLEL_GENERATION_WORKERS}-${NUM_POLICY_GPUS}train${NUM_INFERENCE_GPUS}gen

uv run --isolated --extra $INFERENCE_BACKEND -m examples.fully_async.main_async \
data.train_data="['$DATA_DIR/train.parquet']" \
data.val_data="['$DATA_DIR/validation.parquet']" \
trainer.fully_async.max_staleness_steps=${MAX_STALENESS_STEPS} \
trainer.fully_async.num_parallel_generation_workers=${NUM_PARALLEL_GENERATION_WORKERS} \
trainer.algorithm.advantage_estimator="grpo" \
trainer.algorithm.use_tis=$USE_TIS \
trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \
trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \
trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
trainer.policy.model.path="Qwen/Qwen2.5-1.5B-Instruct" \
trainer.placement.colocate_all=false \
trainer.strategy=fsdp2 \
9 changes: 5 additions & 4 deletions skyrl-train/examples/megatron/run_megatron.sh
@@ -13,15 +13,16 @@ MODEL_NAME="Qwen/Qwen3-0.6B"

INFERENCE_BACKEND="vllm" # currently only vllm is supported for megatron

MEGATRON_TP=2
MEGATRON_PP=2
MEGATRON_TP=1
MEGATRON_PP=1
MEGATRON_CP=1

# torch profiler config
ENABLE_TORCH_PROFILER=false
RANKS_TO_PROFILE="[0]"
SAVE_PATH="$HOME/megatron_prof/tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_${MODEL_NAME}"


uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_base \
data.train_data="['$DATA_DIR/train.parquet']" \
data.val_data="['$DATA_DIR/validation.parquet']" \
@@ -48,8 +49,8 @@ uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_base \
trainer.eval_before_train=false \
trainer.eval_interval=5 \
trainer.update_epochs_per_batch=1 \
trainer.train_batch_size=128 \
trainer.policy_mini_batch_size=64 \
trainer.train_batch_size=64 \
trainer.policy_mini_batch_size=16 \
trainer.micro_forward_batch_size_per_gpu=4 \
trainer.micro_train_batch_size_per_gpu=4 \
trainer.ckpt_interval=10 \
@@ -53,7 +53,7 @@ MEGATRON_ETP=1

# TIS parameters
TIS_IMP_RATIO_CAP=2.0
USE_TIS=true
TIS_TYPE=token

uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \
data.train_data="['$TRAIN_FILE']" \
@@ -83,8 +83,8 @@ uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \
trainer.policy.megatron_config.context_parallel_size=$MEGATRON_CP \
trainer.policy.megatron_config.expert_model_parallel_size=$MEGATRON_EP \
trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \
trainer.algorithm.use_tis=$USE_TIS \
trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \
trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \
trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
trainer.epochs=20 \
trainer.algorithm.eps_clip_low=$CLIP_RATIO_LOW \
trainer.algorithm.eps_clip_high=$CLIP_RATIO_HIGH \