6 changes: 3 additions & 3 deletions skyrl-train/docs/algorithms/custom_algorithms.rst
@@ -48,14 +48,14 @@ Similarly, you can register custom policy loss functions:

.. code-block:: python

from skyrl_train.utils.ppo_utils import register_policy_loss, PolicyLossRegistry
from skyrl_train.utils.ppo_utils import register_policy_loss, PolicyLossRegistry, LossMetrics

@register_policy_loss("reinforce")
def compute_reinforce_policy_loss(log_probs, old_log_probs, advantages, config, loss_mask=None, rollout_log_probs=None):
# Your custom policy loss implementation (like REINFORCE)
loss = (-log_probs * advantages).mean()
# return loss and clip ratio
return loss, 0.0
# return loss and loss metrics
return loss, LossMetrics(clip_ratio=0.0)
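
A registered loss can be smoke-tested in isolation before wiring it into a training run. Below is a minimal sketch using dummy tensors; since this particular loss ignores ``config`` and ``loss_mask``, ``None`` is passed for both.

.. code-block:: python

    import torch

    # dummy batch: 2 sequences of 4 tokens each
    log_probs = torch.randn(2, 4)
    old_log_probs = torch.randn(2, 4)
    advantages = torch.randn(2, 4)

    # call the function registered above directly
    loss, metrics = compute_reinforce_policy_loss(
        log_probs, old_log_probs, advantages, config=None, loss_mask=None
    )
    print(loss.item(), metrics)

At training time, the loss is selected by its registered name (``"reinforce"``) via the algorithm's ``policy_loss_type`` setting.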

Registry Ray Distribution
~~~~~~~~~~~~~~~~~~~~~~~~~~
80 changes: 72 additions & 8 deletions skyrl-train/docs/configuration/config.rst
@@ -389,6 +389,45 @@ Algorithm Configuration
# dual clip parameters
clip_ratio_c: 3.0

# To be deprecated in favor of off_policy_correction.tis_ratio_type = "token"
# and "token_tis_ratio_clip_high"
use_tis: false
tis_imp_ratio_cap: -1.0

# references
# - https://github.com/szrlee/verl/blob/yingru/rollout_correction/docs/advance/rollout_corr_math.md
# - https://fengyao.notion.site/off-policy-rl
off_policy_correction:
# type of importance sampling ratio to use for ppo loss correction
# here importance sampling ratio refers to exp(logprobs_{policy_old} - logprobs_{rollout_policy})
tis_ratio_type: null # null, "token", "sequence"

# used if tis_ratio_type = "token", 1.5-5.0 is recommended for "token" tis_ratio_type
token_tis_ratio_clip_high: 2.0
# used if tis_ratio_type = "sequence", 2.0-10.0 is recommended for "sequence" tis_ratio_type
sequence_tis_ratio_clip_high: 5.0

# method of masking out sequences with cumulative importance sampling ratios outside the cap
# "product" masks out sequences with product of importance ratios outside the cap
# "geometric" masks out sequences with geometric mean of importance ratios outside the cap
sequence_mask_metric: null # null, "product", "geometric"

# used if sequence_mask_metric = "geometric"
# values around 0.99-1.01 are recommended for "geometric" sequence_mask_metric - MoE models may need larger allowed ranges due to higher mismatch
geo_mask_high: 1.01
geo_mask_low: 0.99

# used if sequence_mask_metric = "product"
# values around 0.5-2.0 are recommended for "product" sequence_mask_metric
product_mask_high: 2.0
product_mask_low: 0.5

# separate from sequence_mask_metric and tis_ratio_type
# if any off_policy_correction is enabled, masks out sequences with any token having importance ratio
# far outside an acceptable range (low and high thresholds)
outlier_token_is_threshold_low: 1e-4
outlier_token_is_threshold_high: 100

# clip-cov parameters (only used when policy_loss_type: "clip_cov")
clip_cov:
clip_ratio: 0.0002 # fraction of tokens to clip based on covariance
@@ -413,10 +452,6 @@
type: null # filter (DAPO), replace (POLARIS/WebSailor), or null
max_sample_batches: 30 # sample at most this many batches before stopping, -1 to sample forever
min_replace_ratio: 0.3 # minimum proportion of good samples with which to replace bad samples (for replace strategy only)

# Truncated Importance Sampling as proposed in https://fengyao.notion.site/off-policy-rl
use_tis: false
tis_imp_ratio_cap: -1.0

# SAPO parameters (only used when policy_loss_type: "sapo") (https://arxiv.org/pdf/2511.20347)
sapo:
@@ -466,8 +501,8 @@ Algorithm Configuration
- ``algorithm.dynamic_sampling.type``: Type of dynamic sampling to use. Currently, we support ``filter`` (`DAPO <https://dapo-sia.github.io/>`_), ``replace`` (`POLARIS <https://hkunlp.github.io/blog/2025/Polaris/>`_ / `WebSailor <https://arxiv.org/abs/2507.02592>`_), or ``null`` for no dynamic sampling.
- ``algorithm.dynamic_sampling.max_sample_batches``: Maximum number of batches to sample before stopping. Set to ``-1`` to sample forever.
- ``algorithm.dynamic_sampling.min_replace_ratio``: Minimum proportion of good samples with which to replace bad samples for ``replace`` strategy.
- ``algorithm.use_tis``: Whether to use Truncated Importance Sampling (TIS) as proposed in `this blog <https://fengyao.notion.site/off-policy-rl>`_.
- ``algorithm.tis_imp_ratio_cap``: Cap parameter for the importance ratio in TIS.
- ``algorithm.use_tis``: Whether to use Truncated Importance Sampling (TIS) as proposed in `this blog <https://fengyao.notion.site/off-policy-rl>`_. This flag will be deprecated; use ``off_policy_correction.tis_ratio_type = "token"`` instead.
- ``algorithm.tis_imp_ratio_cap``: Cap parameter for the importance ratio in TIS. This flag will be deprecated; use ``off_policy_correction.token_tis_ratio_clip_high`` instead.
- ``algorithm.clip_cov``: Clip-Cov parameters (only used when ``policy_loss_type`` is ``clip_cov``):

- ``clip_ratio``: Fraction of tokens to clip based on covariance values.
@@ -489,6 +524,35 @@ Algorithm Configuration
- ``tau_pos``: Temperature for gating function for tokens with positive advantages.
- ``tau_neg``: Temperature for gating function for tokens with negative (or zero) advantages.

Off Policy Correction Configuration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- ``algorithm.off_policy_correction``: Off-policy correction configuration. The full set of options is shown below, followed by per-parameter descriptions and an illustrative sketch.

.. code-block:: yaml

off_policy_correction:
tis_ratio_type: null # null, "token", "sequence"
token_tis_ratio_clip_high: 2.0
sequence_tis_ratio_clip_high: 5.0
sequence_mask_metric: null # null, "product", "geometric"
geo_mask_high: 1.01
geo_mask_low: 0.99
product_mask_high: 2.0
product_mask_low: 0.5
outlier_token_is_threshold_low: 1e-4
outlier_token_is_threshold_high: 100

- ``algorithm.off_policy_correction.tis_ratio_type``: Type of importance sampling ratio to use for PPO loss correction. Options: ``null``, ``token``, ``sequence``.
- ``algorithm.off_policy_correction.token_tis_ratio_clip_high``: Upper clip on the per-token importance ratio when ``tis_ratio_type`` is ``token``.
- ``algorithm.off_policy_correction.sequence_tis_ratio_clip_high``: Upper clip on the per-sequence importance ratio when ``tis_ratio_type`` is ``sequence``.
- ``algorithm.off_policy_correction.sequence_mask_metric``: Metric used to mask out sequences whose cumulative importance sampling ratio falls outside the configured bounds. Options: ``null``, ``product``, ``geometric``.
- ``algorithm.off_policy_correction.geo_mask_high``: High threshold for the ``geometric`` ``sequence_mask_metric``.
- ``algorithm.off_policy_correction.geo_mask_low``: Low threshold for the ``geometric`` ``sequence_mask_metric``.
- ``algorithm.off_policy_correction.product_mask_high``: High threshold for the ``product`` ``sequence_mask_metric``.
- ``algorithm.off_policy_correction.product_mask_low``: Low threshold for the ``product`` ``sequence_mask_metric``.
- ``algorithm.off_policy_correction.outlier_token_is_threshold_low``: Low threshold for the outlier-token mask, which masks out any sequence containing a token whose importance ratio falls outside the ``[low, high]`` range.
- ``algorithm.off_policy_correction.outlier_token_is_threshold_high``: High threshold for the outlier-token mask (see above).
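
For intuition, here is a minimal, self-contained sketch of how these options could translate into per-token loss weights. It is illustrative only and not the SkyRL implementation: the helper name is hypothetical, the config is passed as a plain dict mirroring the keys above, and the importance ratio is ``exp(logprobs_policy_old - logprobs_rollout_policy)`` as described in the config comments.

.. code-block:: python

    import torch

    def off_policy_correction_sketch(old_log_probs, rollout_log_probs, loss_mask, cfg):
        # per-token importance ratio between the (old) training policy and the rollout policy
        log_ratio = old_log_probs - rollout_log_probs
        ratio = log_ratio.exp()

        # TIS: clip the importance ratio from above, per token or per sequence
        tis_weight = torch.ones_like(ratio)
        if cfg["tis_ratio_type"] == "token":
            tis_weight = ratio.clamp(max=cfg["token_tis_ratio_clip_high"])
        elif cfg["tis_ratio_type"] == "sequence":
            seq_ratio = (log_ratio * loss_mask).sum(dim=-1).exp()
            tis_weight = seq_ratio.clamp(max=cfg["sequence_tis_ratio_clip_high"]).unsqueeze(-1).expand_as(ratio)

        # sequence mask on the cumulative importance ratio (product or geometric mean)
        seq_keep = torch.ones(ratio.shape[0], dtype=torch.bool, device=ratio.device)
        if cfg["sequence_mask_metric"] == "product":
            prod = (log_ratio * loss_mask).sum(dim=-1).exp()
            seq_keep &= (prod >= cfg["product_mask_low"]) & (prod <= cfg["product_mask_high"])
        elif cfg["sequence_mask_metric"] == "geometric":
            num_tokens = loss_mask.sum(dim=-1).clamp(min=1)
            geo_mean = ((log_ratio * loss_mask).sum(dim=-1) / num_tokens).exp()
            seq_keep &= (geo_mean >= cfg["geo_mask_low"]) & (geo_mean <= cfg["geo_mask_high"])

        # outlier-token mask: drop any sequence containing a token whose ratio is far outside [low, high]
        outlier = (ratio < cfg["outlier_token_is_threshold_low"]) | (ratio > cfg["outlier_token_is_threshold_high"])
        seq_keep &= ~(outlier & loss_mask.bool()).any(dim=-1)

        # per-token weight to multiply into the policy loss; masked-out sequences contribute zero
        return tis_weight * seq_keep.unsqueeze(-1).to(tis_weight.dtype)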

Policy Loss Formulation
~~~~~~~~~~~~~~~~~~~~~~~

@@ -502,7 +566,7 @@ It can be helpful to understand the final loss formulation to see how the differ
advantages: torch.Tensor,
config: DictConfig, # trainer.algorithm config
loss_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, LossMetrics]:

ratio = (log_probs - old_log_probs).exp()
surr1 = ratio * advantages
@@ -515,7 +579,7 @@ It can be helpful to understand the final loss formulation to see how the differ
clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
loss = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
loss = reduce_loss(loss, loss_mask, config.loss_reduction)
return loss, clip_ratio
return loss, LossMetrics(clip_ratio=clip_ratio)


Generator Configuration
6 changes: 3 additions & 3 deletions skyrl-train/docs/examples/flash_rl.rst
@@ -60,13 +60,13 @@ We highlight some important training parameters configured for FlashRL from our
DATA_DIR="$HOME/data/dapo"

# TIS parameters
USE_TIS=true
TIS_TYPE=token
TIS_IMP_RATIO_CAP=8.0

uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.int8 -m examples.flash_rl.main_dapo_flashrl \
...
trainer.algorithm.use_tis=$USE_TIS \
trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \
trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \
trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
generator.sampling_params.logprobs=0 \
...

@@ -9,7 +9,7 @@
from omegaconf import DictConfig
from skyrl_train.utils import initialize_ray
from skyrl_train.entrypoints.main_base import BasePPOExp, config_dir, validate_cfg
from skyrl_train.utils.ppo_utils import PolicyLossRegistry
from skyrl_train.utils.ppo_utils import PolicyLossRegistry, LossMetrics


# Example of custom policy loss: "reinforce"
@@ -27,7 +27,7 @@ def compute_reinforce_policy_loss(
loss = (-log_probs * advantages).mean()

# Return loss and dummy clip_ratio (no clipping in REINFORCE)
return loss, 0.0
return loss, LossMetrics(clip_ratio=0.0)


# Register the custom policy loss
@@ -33,6 +33,8 @@ MAX_RESPONSE_LENGTH=1024

CKPT_PATH="$HOME/ckpts/gsm8k_0.5B_ckpt"

TIS_TYPE=token
TIS_IMP_RATIO_CAP=2.0

uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.fp8 --with vllm@https://github.com/NovaSky-AI/SkyRL/releases/download/skyrl_train-v0.1.0/vllm-0.1.dev7509+gcc487699a.d20250821-cp312-cp312-linux_x86_64.whl --with transformers==4.53.3 -- python -m examples.flash_rl.main_dapo_flashrl \
data.train_data="['$DATA_DIR/train.parquet']" \
@@ -53,8 +55,8 @@ uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.fp8 --with v
generator.eval_sampling_params.top_p=$EVAL_TOP_P \
trainer.algorithm.use_kl_loss=$USE_KL_LOSS \
trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \
trainer.algorithm.use_tis=true \
trainer.algorithm.tis_imp_ratio_cap=2.0 \
trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \
trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
trainer.policy.model.path="Qwen/Qwen2.5-0.5B-Instruct" \
trainer.placement.colocate_all=true \
trainer.strategy=fsdp2 \
@@ -33,6 +33,9 @@ MAX_RESPONSE_LENGTH=1024

CKPT_PATH="$HOME/ckpts/gsm8k_0.5B_ckpt"

TIS_TYPE=token
TIS_IMP_RATIO_CAP=2.0

uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.int8 --with vllm@https://github.com/NovaSky-AI/SkyRL/releases/download/skyrl_train-v0.1.0/vllm-0.1.dev7509+gcc487699a.d20250821-cp312-cp312-linux_x86_64.whl --with transformers==4.53.3 -- python -m examples.flash_rl.main_dapo_flashrl \
data.train_data="['$DATA_DIR/train.parquet']" \
data.val_data="['$DATA_DIR/validation.parquet']" \
@@ -52,8 +55,8 @@ uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.int8 --with
trainer.algorithm.dynamic_sampling.max_sample_batches=$DYNAMIC_SAMPLING_MAX_SAMPLE_BATCHES \
trainer.algorithm.use_kl_loss=$USE_KL_LOSS \
trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \
trainer.algorithm.use_tis=true \
trainer.algorithm.tis_imp_ratio_cap=2.0 \
trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \
trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
trainer.policy.model.path="Qwen/Qwen2.5-0.5B-Instruct" \
trainer.placement.colocate_all=true \
trainer.strategy=fsdp2 \
@@ -33,6 +33,9 @@ MAX_RESPONSE_LENGTH=1024

CKPT_PATH="$HOME/ckpts/gsm8k_32B_ckpt"

TIS_TYPE=token
TIS_IMP_RATIO_CAP=2.0

uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.int8 --with vllm@https://github.com/NovaSky-AI/SkyRL/releases/download/skyrl_train-v0.1.0/vllm-0.1.dev7509+gcc487699a.d20250821-cp312-cp312-linux_x86_64.whl --with transformers==4.53.3 -m examples.flash_rl.main_dapo_flashrl \
data.train_data="['$DATA_DIR/train.parquet']" \
data.val_data="['$DATA_DIR/validation.parquet']" \
@@ -52,8 +55,8 @@ uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.int8 --with
generator.eval_sampling_params.top_p=$EVAL_TOP_P \
trainer.algorithm.use_kl_loss=$USE_KL_LOSS \
trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \
trainer.algorithm.use_tis=true \
trainer.algorithm.tis_imp_ratio_cap=2.0 \
trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \
trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
trainer.policy.model.path="Qwen/Qwen2.5-32B" \
trainer.placement.colocate_all=true \
trainer.strategy=fsdp2 \
@@ -31,6 +31,8 @@ EVAL_TOP_P=0.7
CLIP_RATIO_C=10.0
MAX_RESPONSE_LENGTH=4096

TIS_TYPE=token
TIS_IMP_RATIO_CAP=2.0
uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.0.5b_int8 --with vllm@https://github.com/NovaSky-AI/SkyRL/releases/download/skyrl_train-v0.1.0/vllm-0.1.dev7509+gcc487699a.d20250821-cp312-cp312-linux_x86_64.whl --with transformers==4.53.3 -- python -m examples.flash_rl.main_dapo_flashrl \
data.train_data="['$DATA_DIR/dapo-math-17k-cleaned.parquet']" \
data.val_data="['$DATA_DIR/aime-2024-cleaned.parquet']" \
@@ -50,8 +52,8 @@ uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.0.5b_int8 --
trainer.algorithm.dynamic_sampling.max_sample_batches=$DYNAMIC_SAMPLING_MAX_SAMPLE_BATCHES \
trainer.algorithm.use_kl_loss=$USE_KL_LOSS \
trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \
trainer.algorithm.use_tis=true \
trainer.algorithm.tis_imp_ratio_cap=2.0 \
trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \
trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
trainer.policy.model.path="Qwen/Qwen2.5-0.5B-Instruct" \
trainer.placement.colocate_all=true \
trainer.strategy=fsdp2 \
@@ -33,7 +33,7 @@ MAX_RESPONSE_LENGTH=20480
MAX_PROMPT_LENGTH=2048

TIS_IMP_RATIO_CAP=8.0
USE_TIS=true
TIS_TYPE=token
LOGPROBS=0

CKPT_PATH="$HOME/ckpts/dapo_32b_ckpt"
@@ -57,8 +57,8 @@ uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.int8 --with
generator.eval_sampling_params.top_p=$EVAL_TOP_P \
trainer.algorithm.use_kl_loss=$USE_KL_LOSS \
trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \
trainer.algorithm.use_tis=$USE_TIS \
trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \
trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \
trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
trainer.policy.model.path="Qwen/Qwen2.5-32B" \
trainer.placement.colocate_all=true \
trainer.strategy=fsdp2 \
8 changes: 4 additions & 4 deletions skyrl-train/examples/fully_async/async_run_gsm8k.sh
@@ -28,19 +28,19 @@ set -x
: "${MAX_STALENESS_STEPS:=4}"
: "${NUM_PARALLEL_GENERATION_WORKERS:=$(( MINI_BATCH_SIZE * (MAX_STALENESS_STEPS + 1) ))}"

USE_TIS=true
TIS_TYPE=token
TIS_IMP_RATIO_CAP=2.0

RUN_NAME=gsm8k-async-qwen2.5_1.5B-useTIS_${USE_TIS}-maxStale${MAX_STALENESS_STEPS}-numCon${NUM_PARALLEL_GENERATION_WORKERS}-${NUM_POLICY_GPUS}train${NUM_INFERENCE_GPUS}gen
RUN_NAME=gsm8k-async-qwen2.5_1.5B-TIS_TYPE_${TIS_TYPE}-maxStale${MAX_STALENESS_STEPS}-numCon${NUM_PARALLEL_GENERATION_WORKERS}-${NUM_POLICY_GPUS}train${NUM_INFERENCE_GPUS}gen

uv run --isolated --extra $INFERENCE_BACKEND -m examples.fully_async.main_async \
data.train_data="['$DATA_DIR/train.parquet']" \
data.val_data="['$DATA_DIR/validation.parquet']" \
trainer.fully_async.max_staleness_steps=${MAX_STALENESS_STEPS} \
trainer.fully_async.num_parallel_generation_workers=${NUM_PARALLEL_GENERATION_WORKERS} \
trainer.algorithm.advantage_estimator="grpo" \
trainer.algorithm.use_tis=$USE_TIS \
trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \
trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \
trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
trainer.policy.model.path="Qwen/Qwen2.5-1.5B-Instruct" \
trainer.placement.colocate_all=false \
trainer.strategy=fsdp2 \
@@ -53,7 +53,7 @@ MEGATRON_ETP=1

# TIS parameters
TIS_IMP_RATIO_CAP=2.0
USE_TIS=true
TIS_TYPE=token

uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \
data.train_data="['$TRAIN_FILE']" \
@@ -83,8 +83,8 @@ uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \
trainer.policy.megatron_config.context_parallel_size=$MEGATRON_CP \
trainer.policy.megatron_config.expert_model_parallel_size=$MEGATRON_EP \
trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \
trainer.algorithm.use_tis=$USE_TIS \
trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \
trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \
trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
trainer.epochs=20 \
trainer.algorithm.eps_clip_low=$CLIP_RATIO_LOW \
trainer.algorithm.eps_clip_high=$CLIP_RATIO_HIGH \
@@ -55,8 +55,13 @@ LORA_RANK=32
LORA_ALPHA=64

# TIS parameters
TIS_IMP_RATIO_CAP=2.0
USE_TIS=true
TIS_IMP_RATIO_CAP=3.0

# rollout correction parameters
TIS_RATIO_TYPE="sequence"
sequence_mask_metric="geometric"
geo_mask_high=1.05
geo_mask_low=0.95

uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \
data.train_data="['$TRAIN_FILE']" \
@@ -88,8 +93,11 @@ uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \
trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \
trainer.policy.model.lora.rank=$LORA_RANK \
trainer.policy.model.lora.alpha=$LORA_ALPHA \
trainer.algorithm.use_tis=$USE_TIS \
trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \
trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_RATIO_TYPE \
trainer.algorithm.off_policy_correction.sequence_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
trainer.algorithm.off_policy_correction.sequence_mask_metric=$sequence_mask_metric \
trainer.algorithm.off_policy_correction.geo_mask_high=$geo_mask_high \
trainer.algorithm.off_policy_correction.geo_mask_low=$geo_mask_low \
trainer.epochs=20 \
trainer.algorithm.eps_clip_low=$CLIP_RATIO_LOW \
trainer.algorithm.eps_clip_high=$CLIP_RATIO_HIGH \
6 changes: 3 additions & 3 deletions skyrl-train/examples/megatron/run_megatron_dapo_qwen3_4b.sh
@@ -50,7 +50,7 @@ MEGATRON_ETP=null

# TIS parameters
TIS_IMP_RATIO_CAP=2.0
USE_TIS=true
TIS_TYPE=token

uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \
data.train_data="['$TRAIN_FILE']" \
@@ -80,8 +80,8 @@ uv run --isolated --extra mcore -m examples.algorithms.dapo.main_dapo \
trainer.policy.megatron_config.context_parallel_size=$MEGATRON_CP \
trainer.policy.megatron_config.expert_model_parallel_size=$MEGATRON_EP \
trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \
trainer.algorithm.use_tis=$USE_TIS \
trainer.algorithm.tis_imp_ratio_cap=$TIS_IMP_RATIO_CAP \
trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \
trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
trainer.epochs=20 \
trainer.algorithm.eps_clip_low=$CLIP_RATIO_LOW \
trainer.algorithm.eps_clip_high=$CLIP_RATIO_HIGH \