
Commit 7ef71a2

Add rope_scaling and rope_theta to config

1 parent ba33438 · commit 7ef71a2

File tree: 4 files changed (+30 −2 lines changed)


trinity/common/config.py

Lines changed: 12 additions & 2 deletions

@@ -437,6 +437,10 @@ class ModelConfig:
     fully_sharded_loras: bool = False
     max_cpu_loras: Optional[int] = None
 
+    # rope config
+    rope_scaling: Optional[dict] = None
+    rope_theta: Optional[float] = None
+
 
 @dataclass
 class InferenceModelConfig:
@@ -498,6 +502,10 @@ class InferenceModelConfig:
     lora_modules: Optional[List[Dict]] = None
     lora_kwargs: Optional[dict] = field(default_factory=dict)
 
+    # ! DO NOT SET, rope config
+    rope_scaling: Optional[dict] = None
+    rope_theta: Optional[float] = None
+
 
 @dataclass
 class AlgorithmConfig:
@@ -1190,12 +1198,14 @@ def check_and_update(self) -> Config:  # noqa: C901
             "max_response_tokens",
             "min_response_tokens",
         ]
-        for args in ["model_path"] + rollout_args + length_args:
+        rope_args = ["rope_scaling", "rope_theta"]
+        model_args = rollout_args + length_args + rope_args
+        for args in ["model_path"] + model_args:
             setattr(self.explorer.rollout_model, args, getattr(self.model, args))
         for aux_model in self.explorer.auxiliary_models:
             if not aux_model.model_path:
                 raise ValueError("auxiliary model's model_path is required.")
-            for args in rollout_args + length_args:
+            for args in model_args:
                 set_if_none(aux_model, args, getattr(self.model, args))
 
         # for lora configs
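For orientation, the updated check_and_update simply adds the two rope fields to the set of attributes copied from the top-level ModelConfig onto the rollout model and, when unset, onto auxiliary models. Below is a minimal, self-contained sketch of that propagation pattern; the stand-in dataclasses, the placeholder model path, and the set_if_none helper are simplified stand-ins for Trinity's actual classes, not the real ones.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelCfg:  # simplified stand-in for trinity's ModelConfig
    model_path: str = "example/model-path"  # placeholder, not a real checkpoint
    rope_scaling: Optional[dict] = None
    rope_theta: Optional[float] = None


@dataclass
class InferenceCfg:  # simplified stand-in for InferenceModelConfig
    model_path: Optional[str] = None
    rope_scaling: Optional[dict] = None
    rope_theta: Optional[float] = None


def set_if_none(obj, name, value):
    """Fill an attribute only when the user left it unset."""
    if getattr(obj, name) is None:
        setattr(obj, name, value)


model = ModelCfg(rope_scaling={"rope_type": "yarn", "factor": 4.0}, rope_theta=1_000_000.0)
rollout = InferenceCfg()
aux = InferenceCfg(rope_theta=500_000.0)  # an auxiliary model with its own setting

model_args = ["rope_scaling", "rope_theta"]  # in the diff: rollout_args + length_args + rope_args
for args in ["model_path"] + model_args:
    setattr(rollout, args, getattr(model, args))  # rollout model always mirrors ModelConfig
for args in model_args:
    set_if_none(aux, args, getattr(model, args))  # auxiliary models keep explicit values

print(rollout.rope_theta)  # 1000000.0
print(aux.rope_theta)      # 500000.0 (not overridden)
```

The asymmetry mirrors the diff: the rollout model always takes the top-level settings, while auxiliary models only inherit values they did not set themselves.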

trinity/common/models/vllm_model.py

Lines changed: 6 additions & 0 deletions

@@ -77,6 +77,11 @@ def __init__(
         max_model_len = config.max_model_len
         self.enable_lora = config.enable_lora
         self.default_lora_path = config.lora_kwargs.pop("default_lora_path", None)
+        rope_kwargs = {
+            key: getattr(config, key)
+            for key in ["rope_scaling", "rope_theta"]
+            if getattr(config, key) is not None
+        }
         engine_args = vllm.AsyncEngineArgs(
             model=config.model_path,
             enforce_eager=config.enforce_eager,
@@ -101,6 +106,7 @@ def __init__(
             disable_log_stats=True,
             enable_lora=config.enable_lora,
             logprobs_mode="processed_logprobs",
+            **rope_kwargs,
             **config.lora_kwargs,
         )
         if get_vllm_version() > parse_version("0.10.0"):
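The dict comprehension keeps vLLM's own defaults intact: rope_scaling and rope_theta are forwarded to AsyncEngineArgs only when they were explicitly configured, so a None value never masks the rope settings stored in the checkpoint's config.json. A standalone sketch of the same filtering pattern follows; SimpleNamespace is a stand-in for Trinity's InferenceModelConfig, and no vLLM import is needed to run it.

```python
from types import SimpleNamespace

# Stand-in for the InferenceModelConfig instance passed into __init__.
config = SimpleNamespace(rope_scaling=None, rope_theta=1_000_000.0)

# Forward only the rope options that were explicitly set.
rope_kwargs = {
    key: getattr(config, key)
    for key in ["rope_scaling", "rope_theta"]
    if getattr(config, key) is not None
}

assert rope_kwargs == {"rope_theta": 1_000_000.0}

# In the real code these land in the engine arguments, roughly:
#   vllm.AsyncEngineArgs(model=config.model_path, ..., **rope_kwargs, **config.lora_kwargs)
```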

trinity/common/verl_config.py

Lines changed: 6 additions & 0 deletions

@@ -40,6 +40,10 @@ class ActorModel:
     lora_alpha: int = 32
     target_modules: Optional[str] = "all-linear"
 
+    # rope configs
+    rope_scaling: Optional[dict] = None
+    rope_theta: Optional[float] = None
+
 
 @dataclass
 class Optim:
@@ -412,6 +416,8 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
        # Actor / Rollout Config
        self.actor_rollout_ref.model.path = config.model.model_path
        self.actor_rollout_ref.model.custom_chat_template = config.model.custom_chat_template
+        self.actor_rollout_ref.model.rope_scaling = config.model.rope_scaling
+        self.actor_rollout_ref.model.rope_theta = config.model.rope_theta
        self.actor_rollout_ref.actor.optim.total_training_steps = self.trainer.total_training_steps
        self.actor_rollout_ref.actor.ppo_mini_batch_size = config.buffer.train_batch_size
        self.actor_rollout_ref.rollout.temperature = (
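synchronize_config then mirrors the same two fields onto veRL's ActorModel, so the FSDP trainer and the vLLM rollout engine are configured with identical positional-encoding parameters. The snippet below only illustrates the kind of values these fields carry; the exact rope_scaling schema (for example "rope_type" vs. the older "type" key, YaRN vs. linear scaling) depends on the model family and the transformers version in use.

```python
# Illustrative rope override for extending a model's context window.
rope_scaling = {
    "rope_type": "yarn",                        # scaling strategy (schema varies by model/version)
    "factor": 4.0,                              # roughly 4x the original context length
    "original_max_position_embeddings": 32768,  # context the checkpoint was trained with
}
rope_theta = 1_000_000.0  # RoPE base frequency override

# After check_and_update + synchronize_config, both sides carry the same values:
#   explorer.rollout_model.rope_scaling  == config.model.rope_scaling   (vLLM side)
#   actor_rollout_ref.model.rope_scaling == config.model.rope_scaling   (veRL trainer side)
```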

trinity/trainer/verl/fsdp_workers.py

Lines changed: 6 additions & 0 deletions

@@ -257,6 +257,12 @@ def _build_model_optimizer(  # noqa: C901
            local_path, trust_remote_code=trust_remote_code, attn_implementation="flash_attention_2"
        )
 
+        # patch for rope
+        if self.config.model.rope_scaling is not None:
+            actor_model_config.rope_scaling = OmegaConf.to_container(self.config.model.rope_scaling)
+        if self.config.model.rope_theta is not None:
+            actor_model_config.rope_theta = self.config.model.rope_theta
+
        # patch for kimi-vl
        if getattr(actor_model_config, "model_type", None) == "kimi_vl":
            actor_model_config.text_config.topk_method = "greedy"
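On the trainer side the override is applied to the Hugging Face model config before the actor model is instantiated, so FSDP builds the network with the patched rope parameters rather than the ones in config.json. OmegaConf.to_container is used because the nested rope_scaling value arrives as an OmegaConf DictConfig, while downstream transformers code generally expects plain built-in types. A small sketch of that conversion, assuming omegaconf is installed and using SimpleNamespace as a stand-in for the config returned by AutoConfig.from_pretrained:

```python
from types import SimpleNamespace

from omegaconf import OmegaConf

# What config.model looks like after OmegaConf parsing (values are illustrative).
model_cfg = OmegaConf.create(
    {"rope_scaling": {"rope_type": "linear", "factor": 2.0}, "rope_theta": 1_000_000.0}
)

# Stand-in for actor_model_config = AutoConfig.from_pretrained(local_path, ...).
actor_model_config = SimpleNamespace(rope_scaling=None, rope_theta=10_000.0)

# Convert the DictConfig to a built-in dict before attaching it to the HF config.
if model_cfg.rope_scaling is not None:
    actor_model_config.rope_scaling = OmegaConf.to_container(model_cfg.rope_scaling)
if model_cfg.rope_theta is not None:
    actor_model_config.rope_theta = model_cfg.rope_theta

print(type(actor_model_config.rope_scaling))  # <class 'dict'>, ready for transformers
```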
