39 commits
e7d9d44 done (devpatelio, Nov 20, 2025)
38312c2 rename params to parameters for rope for parity (devpatelio, Nov 20, 2025)
f4ec82d Update skyrl-train/skyrl_train/utils/trainer_utils.py (devpatelio, Nov 20, 2025)
45034fb support old param use (devpatelio, Nov 20, 2025)
ac6d2fc doc updates and remove excess comments (devpatelio, Nov 20, 2025)
5dc16a0 remove comments (devpatelio, Nov 20, 2025)
8954231 remove comments (devpatelio, Nov 20, 2025)
a7105cf Apply suggestions from code review (devpatelio, Nov 20, 2025)
3f6512a done (devpatelio, Nov 20, 2025)
4c39a7e some changes (devpatelio, Nov 25, 2025)
270e0f7 merge changes (devpatelio, Nov 25, 2025)
8c1dd19 fixes for rope config (devpatelio, Nov 25, 2025)
8623973 Pass generator, not trainer rope configuration (they're the same by d… (devpatelio, Nov 25, 2025)
9f8b08b better user logging for clear rope behaviour (devpatelio, Nov 25, 2025)
3c5884e linter (devpatelio, Nov 25, 2025)
210609d update gitignore (devpatelio, Nov 25, 2025)
ee0259e Apply suggestions from code review (SumanthRH, Nov 25, 2025)
eb1c0f2 Update skyrl-train/skyrl_train/entrypoints/main_base.py (devpatelio, Nov 25, 2025)
536e5ef Update skyrl-train/skyrl_train/utils/trainer_utils.py (devpatelio, Nov 25, 2025)
fd18186 Update skyrl-train/skyrl_train/utils/trainer_utils.py (devpatelio, Nov 25, 2025)
32250c6 Update skyrl-train/skyrl_train/utils/trainer_utils.py (devpatelio, Nov 25, 2025)
22bed1c Update skyrl-train/skyrl_train/utils/trainer_utils.py (devpatelio, Nov 25, 2025)
62228b9 add test to model wrapper (devpatelio, Nov 25, 2025)
eaedbf6 linter (devpatelio, Nov 25, 2025)
15ab029 goarugh (devpatelio, Nov 25, 2025)
cbd75bf done (devpatelio, Nov 25, 2025)
a34c01d return empty dict (devpatelio, Nov 25, 2025)
0b8a1b9 done (devpatelio, Nov 25, 2025)
b62f97a Merge branch 'main' into devpatel/skyrl-rope-support (devpatelio, Nov 26, 2025)
63aeb06 some changes (devpatelio, Nov 26, 2025)
82dd877 rm stepweise training (devpatelio, Nov 26, 2025)
bb9774e revert gsm8k (devpatelio, Nov 26, 2025)
a3963a3 revert gsm8k (devpatelio, Nov 26, 2025)
52a8959 piped rope config to critic model calls (devpatelio, Nov 27, 2025)
7ad5745 Update skyrl-train/skyrl_train/utils/trainer_utils.py (devpatelio, Nov 27, 2025)
585c1fd change base config (devpatelio, Dec 8, 2025)
0ccb00a updated docs (devpatelio, Dec 8, 2025)
f418a44 Merge branch 'main' of https://github.com/erictang000/SkyRL into dev_… (erictang000, Dec 8, 2025)
0d4131c x (erictang000, Dec 9, 2025)
42 changes: 42 additions & 0 deletions skyrl-train/docs/configuration/config.rst
@@ -86,6 +86,48 @@ Global LoRA Configuration
- ``target_modules``: Specifies which modules to apply LoRA to. Set to ``"all-linear"`` to apply LoRA to all linear layers, or provide a list of specific module names.
- ``exclude_modules``: List of modules to exclude from LoRA application. Set to ``null`` to exclude none.

RoPE Configuration
------------------

.. code-block:: yaml

# RoPE (Rotary Position Embedding) configuration
rope_parameters:
rope_type: null # ['default', 'linear', 'dynamic', 'yarn', 'longrope', 'llama3']
rope_theta: null
factor: null
original_max_position_embeddings: null
attention_factor: null
beta_fast: null
beta_slow: null
short_factor: null
long_factor: null
low_freq_factor: null
high_freq_factor: null

# Note: rope_scaling and rope_theta are deprecated, use rope_parameters instead.
rope_scaling: null
rope_theta: null

- ``rope_parameters``: Configuration for Rotary Position Embedding (RoPE). This allows you to configure different RoPE scaling strategies for extending context length. See `Hugging Face RoPE utils documentation <https://huggingface.co/docs/transformers/main/en/internal/rope_utils>`_ for more details.
- ``rope_type``: The sub-variant of RoPE to use. Can be one of [``default``, ``linear``, ``dynamic``, ``yarn``, ``longrope``, ``llama3``], with ``default`` being the original RoPE implementation.
- ``rope_theta``: The base period of the RoPE embeddings.
- ``factor``: (optional) Scaling factor for RoPE, used with all rope types except ``default``. For most types, setting this to ``x`` allows the model to handle sequences up to ``x`` times longer than the original maximum length.
- ``original_max_position_embeddings``: (optional) Original max position embeddings before scaling. Used with ``dynamic``, ``longrope``, and ``llama3`` rope types.
- ``attention_factor``: (optional) RoPE attention scaling factor used with ``yarn`` and ``longrope`` rope types. If unset, defaults are inferred from ``factor``.
- ``beta_fast``: (optional) RoPE parameter for ``yarn``. Controls fast boundary for extrapolation. Defaults to ``32`` if unset.
- ``beta_slow``: (optional) RoPE parameter for ``yarn``. Controls slow boundary for interpolation. Defaults to ``1`` if unset.
- ``short_factor``: (optional) Only for ``longrope``. Scaling factors applied to short contexts. Must be a list of numbers with length equal to the hidden size divided by the number of attention heads divided by 2.
- ``long_factor``: (optional) Only for ``longrope``. Scaling factors applied to long contexts. Must have the same length as ``short_factor``.
- ``low_freq_factor``: (optional) Only for ``llama3``. Scaling factor applied to low-frequency RoPE components.
- ``high_freq_factor``: (optional) Only for ``llama3``. Scaling factor applied to high-frequency RoPE components.

- ``rope_scaling``: (Deprecated) Legacy RoPE scaling configuration. Use ``rope_parameters`` instead.
- ``rope_theta``: (Deprecated) Legacy RoPE theta configuration. Use ``rope_parameters.rope_theta`` instead.

.. note::
The generator can optionally use different RoPE parameters by setting ``generator.rope_parameters`` (which defaults to ``${trainer.rope_parameters}``).
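
For example, a minimal sketch of a YaRN setup that stretches a model's original 32k context window by a factor of 4 (the values here are illustrative, not tuned recommendations):

.. code-block:: yaml

    trainer:
      rope_parameters:
        rope_type: yarn
        factor: 4.0
        original_max_position_embeddings: 32768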

Evaluation Configuration
------------------------------
.. code-block:: yaml
20 changes: 18 additions & 2 deletions skyrl-train/skyrl_train/config/ppo_base_config.yaml
@@ -186,15 +186,29 @@ trainer:
dump_data_batch: false
dump_eval_results: true

# YaRN:
# RoPE (Rotary Position Embedding) configuration
rope_parameters:
rope_type: null # ['default', 'linear', 'dynamic', 'yarn', 'longrope', 'llama3']
rope_theta: null
factor: null
original_max_position_embeddings: null
attention_factor: null
beta_fast: null
beta_slow: null
short_factor: null
long_factor: null
low_freq_factor: null
high_freq_factor: null

# Note: rope_scaling and rope_theta are deprecated, use rope_parameters instead. See https://huggingface.co/docs/transformers/main/en/internal/rope_utils for more details

rope_scaling: null
rope_theta: null
# rope_scaling:
# rope_type: yarn
# factor: 1.0
# original_max_position_embeddings: 32768


generator:
model_name: ${trainer.policy.model.path}
model_dtype: "bfloat16" # should match dtype for inference engine
@@ -292,6 +306,8 @@ generator:
# RoPE parameters; can optionally differ from the trainer's, which is useful in some cases such as with thinking models.
rope_scaling: ${trainer.rope_scaling}
rope_theta: ${trainer.rope_theta}
rope_parameters: ${trainer.rope_parameters}


environment:
env_class: "gsm8k"
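Since `generator.rope_parameters` interpolates `${trainer.rope_parameters}`, a generator-specific override is only needed when rollouts should scale differently from training. A sketch of such an override (values illustrative; the YaRN settings are examples, not defaults):

```yaml
# Sketch: the generator stretches context further than the trainer,
# e.g. for thinking models whose rollouts run longer than training sequences.
trainer:
  rope_parameters:
    rope_type: yarn
    factor: 2.0
    original_max_position_embeddings: 32768

generator:
  rope_parameters:
    rope_type: yarn
    factor: 4.0
    original_max_position_embeddings: 32768
```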
6 changes: 1 addition & 5 deletions skyrl-train/skyrl_train/entrypoints/main_base.py
@@ -58,6 +58,7 @@ def create_ray_wrapped_inference_engines_from_config(cfg: DictConfig, colocate_p
"tokenizer": tokenizer,
"backend": cfg.generator.backend,
"engine_init_kwargs": cfg.generator.engine_init_kwargs,
"rope_parameters": OmegaConf.to_container(cfg.generator.rope_parameters, resolve=True),
}

# Conditionally add LoRA parameters if LoRA is enabled
@@ -77,11 +78,6 @@
)
engine_kwargs["enforce_eager"] = False

if (rope_scaling := cfg.generator.get("rope_scaling", None)) is not None:
engine_kwargs["rope_scaling"] = rope_scaling
if (rope_theta := cfg.generator.get("rope_theta", None)) is not None:
engine_kwargs["rope_theta"] = rope_theta

return create_ray_wrapped_inference_engines(**engine_kwargs)


@@ -92,8 +92,7 @@ def create_ray_wrapped_inference_engines(
max_loras=1,
fully_sharded_loras=False,
engine_init_kwargs: Dict[str, Any] = {},
rope_scaling: Dict[str, Any] = {},
rope_theta: float | None = None,
rope_parameters: Dict[str, Any] = {},
) -> List[InferenceEngineInterface]:
"""
Create a list of RayWrappedInferenceEngine instances wrapping Ray actor handles to InferenceEngineInterface instances.
@@ -155,18 +154,26 @@
}

rope_engine_kwargs = {}
if rope_scaling:
rope_engine_kwargs["rope_scaling"] = rope_scaling
if "max_model_len" not in engine_init_kwargs:
rope_factor = rope_scaling.get("factor", None)
rope_max_pos = rope_scaling.get("original_max_position_embeddings", None)
assert rope_factor is not None, "Please provide rope scaling `factor` to compute model max length"
assert (
rope_max_pos is not None
), "Please provide rope `original_max_position_embeddings` to compute model max length"
rope_engine_kwargs["max_model_len"] = int(rope_factor * rope_max_pos)
if rope_theta is not None:
rope_engine_kwargs["rope_theta"] = rope_theta
if rope_parameters:
rope_theta = rope_parameters.get("rope_theta", None)
rope_type = rope_parameters.get("rope_type", None)

# TODO(dev): remove this once vLLM supports the updated rope_parameters format; for now we pass the old rope config format (rope_scaling, rope_theta) to vLLM.
if rope_type:
rope_scaling = rope_parameters.copy()
rope_scaling.pop("rope_theta", None)
rope_engine_kwargs["rope_scaling"] = rope_scaling

if "max_model_len" not in engine_init_kwargs:
rope_factor = rope_scaling.get("factor", None)
rope_max_pos = rope_scaling.get("original_max_position_embeddings", None)
assert (
rope_factor is not None and rope_max_pos is not None
), "Both `factor` and `original_max_position_embeddings` must be provided for rope scaling when `max_model_len` is not set."
rope_engine_kwargs["max_model_len"] = int(rope_factor * rope_max_pos)

if rope_theta is not None:
rope_engine_kwargs["rope_theta"] = rope_theta

# Launch one actor per DP rank
for dp_rank in range(data_parallel_size):
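To make the new branch above concrete, here is a worked sketch of the `max_model_len` derivation and the new-to-legacy conversion it performs (standalone Python; values illustrative):

```python
# Sketch of the rope_parameters handling above (illustrative values).
rope_parameters = {
    "rope_type": "yarn",
    "rope_theta": 1000000.0,
    "factor": 4.0,
    "original_max_position_embeddings": 32768,
}

# vLLM still consumes the legacy split, so rope_theta is separated from the
# scaling dict before the rest is passed as rope_scaling.
rope_scaling = rope_parameters.copy()
rope_theta = rope_scaling.pop("rope_theta", None)

# When max_model_len is not set explicitly, it is derived from the scaling
# factor: 4.0 * 32768 = 131072 tokens.
max_model_len = int(rope_scaling["factor"] * rope_scaling["original_max_position_embeddings"])
assert max_model_len == 131072
assert rope_theta == 1000000.0
```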
19 changes: 11 additions & 8 deletions skyrl-train/skyrl_train/model_wrapper.py
@@ -63,8 +63,7 @@ def __init__(
sequence_parallel_size=1,
use_sample_packing: bool = False,
use_torch_compile: bool = False,
rope_scaling: Dict[str, Any] = {},
rope_theta: float | None = None,
rope_parameters: Dict[str, Any] = {},
**kwargs,
) -> None:
super().__init__()
@@ -111,20 +110,21 @@ def __init__(
else:
model_class = AutoModelForCausalLM

rope_scaling_kwargs = {}
if rope_scaling:
rope_scaling_kwargs["rope_scaling"] = rope_scaling
if rope_theta:
rope_scaling_kwargs["rope_theta"] = rope_theta
# TODO(dev): check whether there is a more elegant solution than loading the config first and setting rope_parameters on it
config = AutoConfig.from_pretrained(
pretrain_or_model,
trust_remote_code=True,
)
config.rope_parameters = rope_parameters

self.model = model_class.from_pretrained(
pretrain_or_model,
config=config,
trust_remote_code=True,
attn_implementation=self.attn_implementation,
quantization_config=nf4_config,
torch_dtype=torch.bfloat16 if bf16 else torch.float32,
device_map=device_map,
**rope_scaling_kwargs,
)

# gpt oss
@@ -534,6 +534,7 @@ def get_llm_for_sequence_regression(
device_map=None,
sequence_parallel_size=1,
use_sample_packing: bool = False,
rope_parameters: Dict[str, Any] = {},
**kwargs,
) -> nn.Module:
"""Get transformer with a sequence classification head on top (linear layer).
@@ -545,13 +546,15 @@
use_flash_attention_2 (bool, optional): Whether use Flash Attention 2.0. Defaults to False.
ds_config (dict, optional): Deepspeed config, used to automatically splitting the model onto
multiple gpus during from_pretrained when ZeRO-3 enabled. Defaults to None.
rope_parameters (Dict[str, Any], optional): RoPE configuration parameters. Defaults to {}.

Returns:
nn.Module: pretrained transformer model.
"""
assert model_type == "critic", f"Only model_type critic is supported, got: {model_type}."

config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
config.rope_parameters = rope_parameters
config._attn_implementation = "flash_attention_2" if use_flash_attention_2 else "eager"

base_class = AutoModel._model_mapping[type(config)]
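The config-first pattern above follows the usual Hugging Face flow: load the `AutoConfig`, attach the resolved RoPE parameters, then build the model from that config. A minimal standalone sketch (the tiny Llama checkpoint is the one used by the test below; whether a given `transformers` version actually consumes `rope_parameters` depends on how recent it is):

```python
# Minimal sketch of the config-first RoPE override used in model_wrapper.py.
from transformers import AutoConfig, AutoModelForCausalLM

model_name = "llamafactory/tiny-random-Llama-3"  # checkpoint from the test file
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
# Attach RoPE settings before instantiation so position embeddings are built
# with the scaled values from the start.
config.rope_parameters = {
    "rope_type": "yarn",
    "factor": 4.0,
    "original_max_position_embeddings": 32768,
}
model = AutoModelForCausalLM.from_pretrained(model_name, config=config, trust_remote_code=True)
```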
44 changes: 33 additions & 11 deletions skyrl-train/skyrl_train/utils/trainer_utils.py
@@ -660,15 +660,37 @@ def build_dataloader(
return dataloader


def get_rope_scaling_config(trainer_cfg: DictConfig) -> dict[str, Any]:
if "rope_scaling" not in trainer_cfg:
return {}
if trainer_cfg.rope_scaling is None:
return None
return OmegaConf.to_container(trainer_cfg.rope_scaling)

def get_rope_parameters_config(trainer_cfg: DictConfig) -> dict[str, Any]:
rope_scaling = trainer_cfg.get("rope_scaling", None)
rope_theta = trainer_cfg.get("rope_theta", None)
has_old_config = rope_scaling is not None or rope_theta is not None

rope_parameters_new = trainer_cfg.get("rope_parameters", None)
has_new_config = rope_parameters_new is not None

if has_old_config and has_new_config:
logger.warning(
"Both old ('rope_scaling', 'rope_theta') and new ('rope_parameters') RoPE configs are provided. "
"Prioritizing the old config for backward compatibility. Please migrate to 'rope_parameters'."
)

def get_rope_theta_config(trainer_cfg: DictConfig) -> int | None:
if "rope_theta" not in trainer_cfg:
return None
return trainer_cfg.rope_theta
if has_old_config:
rope_parameters = {}
if rope_scaling is not None:
rope_scaling_dict = (
OmegaConf.to_container(rope_scaling, resolve=True)
if isinstance(rope_scaling, DictConfig)
else rope_scaling
)
Comment on lines +680 to +684 (Contributor review comment, severity: medium):

This conditional expression to create rope_scaling_dict can be simplified. OmegaConf.to_container handles non-DictConfig inputs correctly by returning them as-is. You can simplify this to an unconditional call, which also makes it more consistent with how rope_parameters_new is handled later in the function.

            rope_scaling_dict = OmegaConf.to_container(rope_scaling, resolve=True)

if isinstance(rope_scaling_dict, dict):
rope_parameters.update(rope_scaling_dict)
else:
logger.warning(f"Ignoring 'rope_scaling' as it is not a dictionary. Found: {rope_scaling_dict}")
if rope_theta is not None:
rope_parameters["rope_theta"] = rope_theta
return rope_parameters

elif has_new_config:
return OmegaConf.to_container(rope_parameters_new, resolve=True) or {}
else:
return {}
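
For reference, a quick sketch of the precedence this function implements (standalone; the expected results follow directly from the logic above and are written as comments):

```python
# Sketch: precedence behaviour of get_rope_parameters_config.
from omegaconf import OmegaConf
from skyrl_train.utils.trainer_utils import get_rope_parameters_config

# Old-style keys win (with a warning) for backward compatibility.
both = OmegaConf.create({
    "rope_scaling": {"rope_type": "yarn", "factor": 4.0},
    "rope_theta": 1000000.0,
    "rope_parameters": {"rope_type": "linear", "factor": 2.0},
})
print(get_rope_parameters_config(both))
# -> {'rope_type': 'yarn', 'factor': 4.0, 'rope_theta': 1000000.0}

# New-style config is returned as a plain dict.
new_only = OmegaConf.create({"rope_parameters": {"rope_type": "yarn", "factor": 4.0}})
print(get_rope_parameters_config(new_only))
# -> {'rope_type': 'yarn', 'factor': 4.0}

# Neither configured -> empty dict.
print(get_rope_parameters_config(OmegaConf.create({})))
# -> {}
```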
10 changes: 5 additions & 5 deletions skyrl-train/skyrl_train/workers/deepspeed/deepspeed_worker.py
@@ -12,7 +12,8 @@
from skyrl_train.model_wrapper import get_llm_for_sequence_regression, HFModelWrapper
from skyrl_train.distributed.deepspeed_strategy import DeepspeedStrategy
from skyrl_train.utils import get_physical_gpu_id
from skyrl_train.utils.trainer_utils import get_rope_scaling_config, get_rope_theta_config

from skyrl_train.utils.trainer_utils import get_rope_parameters_config
from skyrl_train.utils.utils import str_to_torch_dtype
from skyrl_train.workers.worker import (
PolicyWorkerBase,
@@ -63,8 +64,7 @@ def init_model(self, model_id_or_path, num_training_steps: int = None):
sequence_parallel_size=self.sequence_parallel_size,
use_sample_packing=self.cfg.trainer.use_sample_packing,
use_torch_compile=self.cfg.trainer.policy.use_torch_compile,
rope_scaling=get_rope_scaling_config(self.cfg.trainer),
rope_theta=get_rope_theta_config(self.cfg.trainer),
rope_parameters=get_rope_parameters_config(self.cfg.trainer),
)

# configure optimizer
@@ -294,6 +294,7 @@ def init_model(self, model_id_or_path, num_training_steps: int = None):
init_value_head=self.cfg.trainer.policy.model.path == self.cfg.trainer.critic.model.path,
sequence_parallel_size=self.sequence_parallel_size,
use_sample_packing=self.cfg.trainer.use_sample_packing,
rope_parameters=get_rope_parameters_config(self.cfg.trainer),
)
# configure optimizer
critic_optim = strategy.create_optimizer(
@@ -355,8 +356,7 @@ def init_model(self, model_path):
ds_config=strategy.get_ds_eval_config(),
sequence_parallel_size=self.sequence_parallel_size,
use_sample_packing=self.cfg.trainer.use_sample_packing,
rope_scaling=get_rope_scaling_config(self.cfg.trainer),
rope_theta=get_rope_theta_config(self.cfg.trainer),
rope_parameters=get_rope_parameters_config(self.cfg.trainer),
)
self._seq_parallel_monkey_patch(model=wrapped_model.model)

9 changes: 4 additions & 5 deletions skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py
@@ -1,7 +1,7 @@
import asyncio
from typing import Dict, List

from skyrl_train.utils.trainer_utils import get_rope_scaling_config, get_rope_theta_config
from skyrl_train.utils.trainer_utils import get_rope_parameters_config
import ray
import torch
import torch.distributed
@@ -77,8 +77,7 @@ def init_model(self, model_path, num_training_steps: int = None):
sequence_parallel_size=self.cfg.trainer.policy.sequence_parallel_size,
use_sample_packing=self.cfg.trainer.use_sample_packing,
use_torch_compile=self.cfg.trainer.policy.use_torch_compile,
rope_scaling=get_rope_scaling_config(self.cfg.trainer),
rope_theta=get_rope_theta_config(self.cfg.trainer),
rope_parameters=get_rope_parameters_config(self.cfg.trainer),
)
# in-place patch
self._seq_parallel_monkey_patch(model=wrapped_model.model)
@@ -341,6 +340,7 @@ def init_model(self, model_path, num_training_steps: int = None):
init_value_head=self.cfg.trainer.policy.model.path == self.cfg.trainer.critic.model.path,
sequence_parallel_size=self.cfg.trainer.critic.sequence_parallel_size,
use_sample_packing=self.cfg.trainer.use_sample_packing,
rope_parameters=get_rope_parameters_config(self.cfg.trainer),
)
self._seq_parallel_monkey_patch(model=critic, use_parent_class=True)

@@ -403,8 +403,7 @@ def init_model(self, model_path):
bf16=self.cfg.trainer.bf16,
sequence_parallel_size=self.cfg.trainer.ref.sequence_parallel_size,
use_sample_packing=self.cfg.trainer.use_sample_packing,
rope_scaling=get_rope_scaling_config(self.cfg.trainer),
rope_theta=get_rope_theta_config(self.cfg.trainer),
rope_parameters=get_rope_parameters_config(self.cfg.trainer),
)
self._seq_parallel_monkey_patch(model=wrapped_model.model)

@@ -9,7 +9,7 @@
from torch.distributed.distributed_c10d import init_process_group
from skyrl_train.distributed.fsdp_strategy import FSDPStrategy
from skyrl_train.config.utils import get_default_config
from skyrl_train.utils.trainer_utils import get_rope_scaling_config, get_rope_theta_config
from skyrl_train.utils.trainer_utils import get_rope_parameters_config
from skyrl_train.utils.utils import get_free_port

MODEL_NAME = "llamafactory/tiny-random-Llama-3"
@@ -48,8 +48,7 @@ def test_fsdp1_wrap_policy():
sequence_parallel_size=cfg.trainer.policy.sequence_parallel_size,
use_sample_packing=cfg.trainer.use_sample_packing,
use_torch_compile=cfg.trainer.policy.use_torch_compile,
rope_scaling=get_rope_scaling_config(cfg.trainer),
rope_theta=get_rope_theta_config(cfg.trainer),
rope_parameters=get_rope_parameters_config(cfg.trainer),
)

model, _, _ = strategy.prepare(