diff --git a/skyrl-train/docs/configuration/config.rst b/skyrl-train/docs/configuration/config.rst
index 48487c4c1..3963c6a45 100644
--- a/skyrl-train/docs/configuration/config.rst
+++ b/skyrl-train/docs/configuration/config.rst
@@ -75,6 +75,49 @@ General Training Configuration
 .. tip:: If you're facing issues with tuning the right values for ``micro_train_batch_size_per_gpu``, ``policy_mini_batch_size`` and ``micro_forward_batch_size_per_gpu``, see ``utils/utils.py::validate_batch_sizes`` for details on constraints.
 
+
+RoPE Configuration
+------------------
+
+.. code-block:: yaml
+
+    # RoPE (Rotary Position Embedding) configuration
+    rope_parameters:
+      rope_type: null  # ['default', 'linear', 'dynamic', 'yarn', 'longrope', 'llama3']
+      rope_theta: null
+      factor: null
+      original_max_position_embeddings: null
+      attention_factor: null
+      beta_fast: null
+      beta_slow: null
+      short_factor: null
+      long_factor: null
+      low_freq_factor: null
+      high_freq_factor: null
+
+    # Note: rope_scaling and rope_theta are deprecated, use rope_parameters instead.
+    rope_scaling: null
+    rope_theta: null
+
+- ``rope_parameters``: Configuration for Rotary Position Embedding (RoPE). This lets you configure different RoPE scaling strategies for extending the context length. See the `Hugging Face RoPE utils documentation <https://huggingface.co/docs/transformers/main/en/internal/rope_utils>`_ for more details.
+  - ``rope_type``: The sub-variant of RoPE to use. Can be one of [``default``, ``linear``, ``dynamic``, ``yarn``, ``longrope``, ``llama3``], with ``default`` being the original RoPE implementation.
+  - ``rope_theta``: The base period of the RoPE embeddings.
+  - ``factor``: (optional) Scaling factor for RoPE, used with all rope types except ``default``. For most types, setting this to ``x`` allows the model to handle sequences up to ``x`` times longer than the original maximum length.
+  - ``original_max_position_embeddings``: (optional) The model's original maximum position embeddings before scaling. Used with the ``dynamic``, ``longrope``, and ``llama3`` rope types.
+  - ``attention_factor``: (optional) RoPE attention scaling factor, used with the ``yarn`` and ``longrope`` rope types. If unset, a default value is inferred from ``factor``.
+  - ``beta_fast``: (optional) RoPE parameter for ``yarn``. Controls the fast boundary for extrapolation. Defaults to ``32`` if unset.
+  - ``beta_slow``: (optional) RoPE parameter for ``yarn``. Controls the slow boundary for interpolation. Defaults to ``1`` if unset.
+  - ``short_factor``: (optional) Only for ``longrope``. List of scaling factors applied to short contexts; its length must equal the hidden size divided by the number of attention heads divided by 2.
+  - ``long_factor``: (optional) Only for ``longrope``. List of scaling factors applied to long contexts; its length must equal the hidden size divided by the number of attention heads divided by 2.
+  - ``low_freq_factor``: (optional) Only for ``llama3``. Scaling factor applied to the low-frequency RoPE components.
+  - ``high_freq_factor``: (optional) Only for ``llama3``. Scaling factor applied to the high-frequency RoPE components.
+
+- ``rope_scaling``: (Deprecated) Legacy RoPE scaling configuration. Use ``rope_parameters`` instead.
+- ``rope_theta``: (Deprecated) Legacy RoPE theta configuration. Use ``rope_parameters.rope_theta`` instead.
+
+.. note::
+    The generator can optionally use different RoPE parameters by setting ``generator.rope_parameters`` (which defaults to ``${trainer.rope_parameters}``).
+
 Evaluation Configuration
 ------------------------------
 
 .. code-block:: yaml
diff --git a/skyrl-train/skyrl_train/config/ppo_base_config.yaml b/skyrl-train/skyrl_train/config/ppo_base_config.yaml
index a38d0a688..b1004cb53 100644
--- a/skyrl-train/skyrl_train/config/ppo_base_config.yaml
+++ b/skyrl-train/skyrl_train/config/ppo_base_config.yaml
@@ -190,13 +190,24 @@ trainer:
   dump_data_batch: false
   dump_eval_results: true
 
-  # YaRN:
+  # RoPE (Rotary Position Embedding) configuration
+  rope_parameters:
+    rope_type: null  # ['default', 'linear', 'dynamic', 'yarn', 'longrope', 'llama3']
+    rope_theta: null
+    factor: null
+    original_max_position_embeddings: null
+    attention_factor: null
+    beta_fast: null
+    beta_slow: null
+    short_factor: null
+    long_factor: null
+    low_freq_factor: null
+    high_freq_factor: null
+
+  # Note: rope_scaling and rope_theta are deprecated, use rope_parameters instead. See https://huggingface.co/docs/transformers/main/en/internal/rope_utils for more details
+  rope_scaling: null
   rope_theta: null
-  # rope_scaling:
-  #   rope_type: yarn
-  #   factor: 1.0
-  #   original_max_position_embeddings: 32768
 
 
   step_wise_training: false
 
@@ -297,6 +308,8 @@ generator:
   # rope parameters, can be optionally different from the trainer , useful in some cases like with thinking models.
   rope_scaling: ${trainer.rope_scaling}
   rope_theta: ${trainer.rope_theta}
+  rope_parameters: ${trainer.rope_parameters}
+
 
 environment:
   env_class: "gsm8k"
diff --git a/skyrl-train/skyrl_train/entrypoints/main_base.py b/skyrl-train/skyrl_train/entrypoints/main_base.py
index 0fb2dc4b5..08e82fb38 100644
--- a/skyrl-train/skyrl_train/entrypoints/main_base.py
+++ b/skyrl-train/skyrl_train/entrypoints/main_base.py
@@ -58,6 +58,7 @@ def create_ray_wrapped_inference_engines_from_config(cfg: DictConfig, colocate_p
         "tokenizer": tokenizer,
         "backend": cfg.generator.backend,
         "engine_init_kwargs": cfg.generator.engine_init_kwargs,
+        "rope_parameters": OmegaConf.to_container(cfg.generator.rope_parameters, resolve=True),
     }
 
     # Conditionally add LoRA parameters if LoRA is enabled
@@ -77,11 +78,6 @@ def create_ray_wrapped_inference_engines_from_config(cfg: DictConfig, colocate_p
         )
         engine_kwargs["enforce_eager"] = False
 
-    if (rope_scaling := cfg.generator.get("rope_scaling", None)) is not None:
-        engine_kwargs["rope_scaling"] = rope_scaling
-    if (rope_theta := cfg.generator.get("rope_theta", None)) is not None:
-        engine_kwargs["rope_theta"] = rope_theta
-
     return create_ray_wrapped_inference_engines(**engine_kwargs)
 
 
diff --git a/skyrl-train/skyrl_train/inference_engines/ray_wrapped_inference_engine.py b/skyrl-train/skyrl_train/inference_engines/ray_wrapped_inference_engine.py
index 59f94beac..b222af476 100644
--- a/skyrl-train/skyrl_train/inference_engines/ray_wrapped_inference_engine.py
+++ b/skyrl-train/skyrl_train/inference_engines/ray_wrapped_inference_engine.py
@@ -92,8 +92,7 @@ def create_ray_wrapped_inference_engines(
     max_loras=1,
     fully_sharded_loras=False,
     engine_init_kwargs: Dict[str, Any] = {},
-    rope_scaling: Dict[str, Any] = {},
-    rope_theta: float | None = None,
+    rope_parameters: Dict[str, Any] = {},
 ) -> List[InferenceEngineInterface]:
     """
     Create a list of RayWrappedInferenceEngine instances wrapping Ray actor handles to InferenceEngineInterface instances.
@@ -155,18 +154,26 @@ def create_ray_wrapped_inference_engines(
     }
 
     rope_engine_kwargs = {}
-    if rope_scaling:
-        rope_engine_kwargs["rope_scaling"] = rope_scaling
-        if "max_model_len" not in engine_init_kwargs:
-            rope_factor = rope_scaling.get("factor", None)
-            rope_max_pos = rope_scaling.get("original_max_position_embeddings", None)
-            assert rope_factor is not None, "Please provide rope scaling `factor` to compute model max length"
-            assert (
-                rope_max_pos is not None
-            ), "Please provide rope `original_max_position_embeddings` to compute model max length"
-            rope_engine_kwargs["max_model_len"] = int(rope_factor * rope_max_pos)
-    if rope_theta is not None:
-        rope_engine_kwargs["rope_theta"] = rope_theta
+    if rope_parameters:
+        rope_theta = rope_parameters.get("rope_theta", None)
+        rope_type = rope_parameters.get("rope_type", None)
+
+        # TODO(dev): remove this once vLLM supports the updated rope_parameters; for now we pass the old rope config format (rope_scaling, rope_theta) to vLLM.
+        if rope_type:
+            rope_scaling = rope_parameters.copy()
+            rope_scaling.pop("rope_theta", None)
+            rope_engine_kwargs["rope_scaling"] = rope_scaling
+
+            if "max_model_len" not in engine_init_kwargs:
+                rope_factor = rope_scaling.get("factor", None)
+                rope_max_pos = rope_scaling.get("original_max_position_embeddings", None)
+                assert (
+                    rope_factor is not None and rope_max_pos is not None
+                ), "Both `factor` and `original_max_position_embeddings` must be provided for rope scaling when `max_model_len` is not set."
+                rope_engine_kwargs["max_model_len"] = int(rope_factor * rope_max_pos)
+
+        if rope_theta is not None:
+            rope_engine_kwargs["rope_theta"] = rope_theta
 
     # Launch one actor per DP rank
     for dp_rank in range(data_parallel_size):
diff --git a/skyrl-train/skyrl_train/model_wrapper.py b/skyrl-train/skyrl_train/model_wrapper.py
index 028cb86ac..7963c8894 100644
--- a/skyrl-train/skyrl_train/model_wrapper.py
+++ b/skyrl-train/skyrl_train/model_wrapper.py
@@ -63,8 +63,7 @@ def __init__(
         sequence_parallel_size=1,
         use_sample_packing: bool = False,
         use_torch_compile: bool = False,
-        rope_scaling: Dict[str, Any] = {},
-        rope_theta: float | None = None,
+        rope_parameters: Dict[str, Any] = {},
         **kwargs,
     ) -> None:
         super().__init__()
@@ -111,20 +110,21 @@ def __init__(
         else:
             model_class = AutoModelForCausalLM
 
-        rope_scaling_kwargs = {}
-        if rope_scaling:
-            rope_scaling_kwargs["rope_scaling"] = rope_scaling
-        if rope_theta:
-            rope_scaling_kwargs["rope_theta"] = rope_theta
+        # TODO(dev): check whether there is a more elegant solution than loading the config first and setting rope_parameters on it
+        config = AutoConfig.from_pretrained(
+            pretrain_or_model,
+            trust_remote_code=True,
+        )
+        config.rope_parameters = rope_parameters
 
         self.model = model_class.from_pretrained(
             pretrain_or_model,
+            config=config,
             trust_remote_code=True,
             attn_implementation=self.attn_implementation,
             quantization_config=nf4_config,
             torch_dtype=torch.bfloat16 if bf16 else torch.float32,
             device_map=device_map,
-            **rope_scaling_kwargs,
         )
 
         # gpt oss
@@ -534,6 +534,7 @@ def get_llm_for_sequence_regression(
     device_map=None,
     sequence_parallel_size=1,
     use_sample_packing: bool = False,
+    rope_parameters: Dict[str, Any] = {},
     **kwargs,
 ) -> nn.Module:
     """Get transformer with a sequence classification head on top (linear layer).
@@ -545,6 +546,7 @@ def get_llm_for_sequence_regression(
         use_flash_attention_2 (bool, optional): Whether use Flash Attention 2.0. Defaults to False.
         ds_config (dict, optional): Deepspeed config, used to automatically splitting the model onto multiple gpus during from_pretrained when ZeRO-3 enabled. Defaults to None.
+        rope_parameters (Dict[str, Any], optional): RoPE configuration parameters. Defaults to {}.
 
     Returns:
         nn.Module: pretrained transformer model.
 
@@ -552,6 +554,7 @@
     assert model_type == "critic", f"Only model_type critic is supported, got: {model_type}."
 
     config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
+    config.rope_parameters = rope_parameters
     config._attn_implementation = "flash_attention_2" if use_flash_attention_2 else "eager"
 
     base_class = AutoModel._model_mapping[type(config)]
diff --git a/skyrl-train/skyrl_train/utils/trainer_utils.py b/skyrl-train/skyrl_train/utils/trainer_utils.py
index 169a924f2..b7a038553 100644
--- a/skyrl-train/skyrl_train/utils/trainer_utils.py
+++ b/skyrl-train/skyrl_train/utils/trainer_utils.py
@@ -660,15 +660,43 @@ def build_dataloader(
     return dataloader
 
 
-def get_rope_scaling_config(trainer_cfg: DictConfig) -> dict[str, Any]:
-    if "rope_scaling" not in trainer_cfg:
-        return {}
-    if trainer_cfg.rope_scaling is None:
-        return None
-    return OmegaConf.to_container(trainer_cfg.rope_scaling)
+def get_rope_parameters_config(trainer_cfg: DictConfig) -> dict[str, Any]:
+    rope_scaling = trainer_cfg.get("rope_scaling", None)
+    rope_theta = trainer_cfg.get("rope_theta", None)
+    has_old_config = rope_scaling is not None or rope_theta is not None
+
+    rope_parameters_new = trainer_cfg.get("rope_parameters", None)
+    has_new_config = rope_parameters_new is not None
+
+    if has_old_config and has_new_config:
+        logger.warning(
+            "Both old ('rope_scaling', 'rope_theta') and new ('rope_parameters') RoPE configs are provided. "
+            "Prioritizing the old config for backward compatibility. Please migrate to 'rope_parameters'."
+        )
+    if has_old_config:
+        rope_parameters = {}
+        if rope_scaling is not None:
+            rope_scaling_dict = (
+                OmegaConf.to_container(rope_scaling, resolve=True)
+                if isinstance(rope_scaling, DictConfig)
+                else rope_scaling
+            )
+            if isinstance(rope_scaling_dict, dict):
+                rope_parameters.update(rope_scaling_dict)
+            else:
+                logger.warning(f"Ignoring 'rope_scaling' as it is not a dictionary. Found: {rope_scaling_dict}")
+        if rope_theta is not None:
+            rope_parameters["rope_theta"] = rope_theta
+        return rope_parameters
+
+    elif has_new_config:
+        new_params = OmegaConf.to_container(rope_parameters_new, resolve=True)
+        if isinstance(new_params, dict):
+            return new_params
+        if new_params is not None:
Found: {new_params}") + return {} -def get_rope_theta_config(trainer_cfg: DictConfig) -> int | None: - if "rope_theta" not in trainer_cfg: - return None - return trainer_cfg.rope_theta + else: + return {} diff --git a/skyrl-train/skyrl_train/workers/deepspeed/deepspeed_worker.py b/skyrl-train/skyrl_train/workers/deepspeed/deepspeed_worker.py index 3f2a70225..fdd12896b 100644 --- a/skyrl-train/skyrl_train/workers/deepspeed/deepspeed_worker.py +++ b/skyrl-train/skyrl_train/workers/deepspeed/deepspeed_worker.py @@ -12,7 +12,8 @@ from skyrl_train.model_wrapper import get_llm_for_sequence_regression, HFModelWrapper from skyrl_train.distributed.deepspeed_strategy import DeepspeedStrategy from skyrl_train.utils import get_physical_gpu_id -from skyrl_train.utils.trainer_utils import get_rope_scaling_config, get_rope_theta_config + +from skyrl_train.utils.trainer_utils import get_rope_parameters_config from skyrl_train.utils.utils import str_to_torch_dtype from skyrl_train.workers.worker import ( PolicyWorkerBase, @@ -64,8 +65,7 @@ def init_model(self, model_id_or_path, num_training_steps: int = None): sequence_parallel_size=self.sequence_parallel_size, use_sample_packing=self.cfg.trainer.use_sample_packing, use_torch_compile=self.cfg.trainer.policy.use_torch_compile, - rope_scaling=get_rope_scaling_config(self.cfg.trainer), - rope_theta=get_rope_theta_config(self.cfg.trainer), + rope_parameters=get_rope_parameters_config(self.cfg.trainer), ) # configure optimizer @@ -296,6 +296,7 @@ def init_model(self, model_id_or_path, num_training_steps: int = None): init_value_head=self.cfg.trainer.policy.model.path == self.cfg.trainer.critic.model.path, sequence_parallel_size=self.sequence_parallel_size, use_sample_packing=self.cfg.trainer.use_sample_packing, + rope_parameters=get_rope_parameters_config(self.cfg.trainer), ) # configure optimizer critic_optim = strategy.create_optimizer( @@ -357,8 +358,7 @@ def init_model(self, model_path): ds_config=strategy.get_ds_eval_config(), sequence_parallel_size=self.sequence_parallel_size, use_sample_packing=self.cfg.trainer.use_sample_packing, - rope_scaling=get_rope_scaling_config(self.cfg.trainer), - rope_theta=get_rope_theta_config(self.cfg.trainer), + rope_parameters=get_rope_parameters_config(self.cfg.trainer), ) self._seq_parallel_monkey_patch(model=wrapped_model.model) diff --git a/skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py b/skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py index e3fd0aaf8..d2b3ff23a 100644 --- a/skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py +++ b/skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py @@ -1,7 +1,7 @@ import asyncio from typing import Dict, List -from skyrl_train.utils.trainer_utils import get_rope_scaling_config, get_rope_theta_config +from skyrl_train.utils.trainer_utils import get_rope_parameters_config import ray import torch import torch.distributed @@ -77,8 +77,7 @@ def init_model(self, model_path, num_training_steps: int = None): sequence_parallel_size=self.cfg.trainer.policy.sequence_parallel_size, use_sample_packing=self.cfg.trainer.use_sample_packing, use_torch_compile=self.cfg.trainer.policy.use_torch_compile, - rope_scaling=get_rope_scaling_config(self.cfg.trainer), - rope_theta=get_rope_theta_config(self.cfg.trainer), + rope_parameters=get_rope_parameters_config(self.cfg.trainer), ) # in-place patch self._seq_parallel_monkey_patch(model=wrapped_model.model) @@ -342,6 +341,7 @@ def init_model(self, model_path, num_training_steps: int = None): init_value_head=self.cfg.trainer.policy.model.path == 
             init_value_head=self.cfg.trainer.policy.model.path == self.cfg.trainer.critic.model.path,
             sequence_parallel_size=self.cfg.trainer.critic.sequence_parallel_size,
             use_sample_packing=self.cfg.trainer.use_sample_packing,
+            rope_parameters=get_rope_parameters_config(self.cfg.trainer),
         )
 
         self._seq_parallel_monkey_patch(model=critic, use_parent_class=True)
@@ -404,8 +404,7 @@ def init_model(self, model_path):
             bf16=self.cfg.trainer.bf16,
             sequence_parallel_size=self.cfg.trainer.ref.sequence_parallel_size,
             use_sample_packing=self.cfg.trainer.use_sample_packing,
-            rope_scaling=get_rope_scaling_config(self.cfg.trainer),
-            rope_theta=get_rope_theta_config(self.cfg.trainer),
+            rope_parameters=get_rope_parameters_config(self.cfg.trainer),
         )
 
         self._seq_parallel_monkey_patch(model=wrapped_model.model)
diff --git a/skyrl-train/tests/gpu/gpu_ci/distributed/test_fsdp_strategy.py b/skyrl-train/tests/gpu/gpu_ci/distributed/test_fsdp_strategy.py
index 72d87e505..22fbc7974 100644
--- a/skyrl-train/tests/gpu/gpu_ci/distributed/test_fsdp_strategy.py
+++ b/skyrl-train/tests/gpu/gpu_ci/distributed/test_fsdp_strategy.py
@@ -9,7 +9,7 @@
 from torch.distributed.distributed_c10d import init_process_group
 from skyrl_train.distributed.fsdp_strategy import FSDPStrategy
 from skyrl_train.config.utils import get_default_config
-from skyrl_train.utils.trainer_utils import get_rope_scaling_config, get_rope_theta_config
+from skyrl_train.utils.trainer_utils import get_rope_parameters_config
 from skyrl_train.utils.utils import get_free_port
 
 MODEL_NAME = "llamafactory/tiny-random-Llama-3"
@@ -48,8 +48,7 @@ def test_fsdp1_wrap_policy():
         sequence_parallel_size=cfg.trainer.policy.sequence_parallel_size,
         use_sample_packing=cfg.trainer.use_sample_packing,
         use_torch_compile=cfg.trainer.policy.use_torch_compile,
-        rope_scaling=get_rope_scaling_config(cfg.trainer),
-        rope_theta=get_rope_theta_config(cfg.trainer),
+        rope_parameters=get_rope_parameters_config(cfg.trainer),
     )
 
     model, _, _ = strategy.prepare(
diff --git a/skyrl-train/tests/gpu/gpu_ci/test_model_wrapper.py b/skyrl-train/tests/gpu/gpu_ci/test_model_wrapper.py
index 5839e7c6c..4acdb1e51 100644
--- a/skyrl-train/tests/gpu/gpu_ci/test_model_wrapper.py
+++ b/skyrl-train/tests/gpu/gpu_ci/test_model_wrapper.py
@@ -236,3 +236,80 @@ def test_actor_entropy_consistency_sample_packing():
     assert torch.allclose(
         output_no_pack["entropy"], output_pack["entropy"]
     ), "Entropy with sample packing doesn't match entropy without sample packing"
+
+
+def test_rope_parameters_set_on_config():
+    """Tests that RoPE parameters are properly set on the model config before AutoModel init"""
+
+    # Test with various RoPE parameter configurations
+    test_cases = [
+        {
+            "rope_parameters": {
+                "rope_type": "linear",
+                "factor": 2.0,
+                "rope_theta": 10000.0,
+            },
+        },
+        {
+            "rope_parameters": {
+                "rope_type": "yarn",
+                "factor": 4.0,
+                "rope_theta": 20000.0,
+                "beta_fast": 32,
+                "beta_slow": 1,
+            },
+        },
+        {
+            "rope_parameters": {
+                "rope_type": "dynamic",
+                "factor": 8.0,
+                "original_max_position_embeddings": 2048,
+            },
+        },
+        {
+            "rope_parameters": {},
+        },
+    ]
+
+    for test_case in test_cases:
+        rope_params = test_case["rope_parameters"]
+
+        with (
+            patch("skyrl_train.model_wrapper.AutoConfig") as mock_config_class,
+            patch("skyrl_train.model_wrapper.AutoModelForCausalLM") as mock_model_class,
+        ):
+
+            mock_config = MagicMock()
+            mock_config_class.from_pretrained.return_value = mock_config
+
+            mock_model = MagicMock()
+            mock_model_class.from_pretrained.return_value = mock_model
+
+            # Initialize HFModelWrapper with rope_parameters
+            # We don't need to use the wrapper, just need to trigger initialization to test the mocks
+            _ = HFModelWrapper(
+                pretrain_or_model=MODEL_NAME,
+                use_flash_attention_2=False,
+                bf16=False,
+                rope_parameters=rope_params,
+            )
+
+            # Verify AutoConfig.from_pretrained was called
+            mock_config_class.from_pretrained.assert_called_once_with(
+                MODEL_NAME,
+                trust_remote_code=True,
+            )
+
+            # Verify rope_parameters were set on the config
+            assert hasattr(mock_config, "rope_parameters"), "rope_parameters should be set on config"
+            assert (
+                mock_config.rope_parameters == rope_params
+            ), f"rope_parameters mismatch. Expected {rope_params}, got {mock_config.rope_parameters}"
+
+            # Verify AutoModelForCausalLM.from_pretrained was called with the config
+            mock_model_class.from_pretrained.assert_called_once()
+            call_kwargs = mock_model_class.from_pretrained.call_args
+            assert call_kwargs.kwargs["config"] == mock_config, "Config should be passed to from_pretrained"
+            assert (
+                call_kwargs.kwargs["config"].rope_parameters == rope_params
+            ), "rope_parameters should be preserved in the config passed to from_pretrained"
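
For reference, a minimal configuration sketch of the new schema (not part of the patch above; it mirrors the commented-out YaRN example removed from ppo_base_config.yaml, and the numeric values are illustrative only):

    # Sketch: extend the context ~4x with YaRN via the new rope_parameters config
    trainer:
      rope_parameters:
        rope_type: yarn
        factor: 4.0
        original_max_position_embeddings: 32768
        beta_fast: 32
        beta_slow: 1

With this set and no explicit max_model_len in generator.engine_init_kwargs, create_ray_wrapped_inference_engines derives max_model_len = factor * original_max_position_embeddings (131072 here), and generator.rope_parameters follows ${trainer.rope_parameters} unless overridden.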