14 changes: 12 additions & 2 deletions trinity/common/config.py
@@ -437,6 +437,10 @@ class ModelConfig:
fully_sharded_loras: bool = False
max_cpu_loras: Optional[int] = None

# rope config
rope_scaling: Optional[dict] = None
rope_theta: Optional[float] = None


@dataclass
class InferenceModelConfig:
@@ -498,6 +502,10 @@ class InferenceModelConfig:
lora_modules: Optional[List[Dict]] = None
lora_kwargs: Optional[dict] = field(default_factory=dict)

# ! DO NOT SET, rope config
rope_scaling: Optional[dict] = None
rope_theta: Optional[float] = None


@dataclass
class AlgorithmConfig:
@@ -1190,12 +1198,14 @@ def check_and_update(self) -> Config:  # noqa: C901
"max_response_tokens",
"min_response_tokens",
]
for args in ["model_path"] + rollout_args + length_args:
rope_args = ["rope_scaling", "rope_theta"]
model_args = rollout_args + length_args + rope_args
for args in ["model_path"] + model_args:
setattr(self.explorer.rollout_model, args, getattr(self.model, args))
for aux_model in self.explorer.auxiliary_models:
if not aux_model.model_path:
raise ValueError("auxiliary model's model_path is required.")
for args in rollout_args + length_args:
for args in model_args:
set_if_none(aux_model, args, getattr(self.model, args))

# for lora configs
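For context (not part of the diff): a minimal sketch of how the two new fields might be set, assuming ModelConfig can be constructed from model_path alone; the model path and the rope_scaling keys/values below are illustrative and must match what the model's RoPE implementation in transformers/vLLM expects.

from trinity.common.config import ModelConfig

model_cfg = ModelConfig(model_path="Qwen/Qwen2.5-7B-Instruct")  # hypothetical checkpoint
model_cfg.rope_theta = 1_000_000.0  # override the base RoPE frequency
model_cfg.rope_scaling = {"rope_type": "yarn", "factor": 4.0}  # key names depend on the model
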
6 changes: 6 additions & 0 deletions trinity/common/models/vllm_model.py
@@ -77,6 +77,11 @@ def __init__(
max_model_len = config.max_model_len
self.enable_lora = config.enable_lora
self.default_lora_path = config.lora_kwargs.pop("default_lora_path", None)
rope_kwargs = {
key: getattr(config, key)
for key in ["rope_scaling", "rope_theta"]
if getattr(config, key) is not None
}
engine_args = vllm.AsyncEngineArgs(
model=config.model_path,
enforce_eager=config.enforce_eager,
@@ -101,6 +106,7 @@ def __init__(
disable_log_stats=True,
enable_lora=config.enable_lora,
logprobs_mode="processed_logprobs",
**rope_kwargs,
**config.lora_kwargs,
)
if get_vllm_version() > parse_version("0.10.0"):
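A small illustration of the filtering above (not part of the diff; _Cfg is a stand-in for InferenceModelConfig): only fields that are explicitly set are forwarded to vLLM, so an unset field never overrides what the engine reads from the checkpoint's own config.

from dataclasses import dataclass
from typing import Optional

@dataclass
class _Cfg:  # stand-in for InferenceModelConfig, illustration only
    rope_scaling: Optional[dict] = None
    rope_theta: Optional[float] = None

cfg = _Cfg(rope_theta=1_000_000.0)  # rope_scaling left unset
rope_kwargs = {
    key: getattr(cfg, key)
    for key in ["rope_scaling", "rope_theta"]
    if getattr(cfg, key) is not None
}
assert rope_kwargs == {"rope_theta": 1_000_000.0}  # no "rope_scaling" key is passed
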
6 changes: 6 additions & 0 deletions trinity/common/verl_config.py
@@ -40,6 +40,10 @@ class ActorModel:
lora_alpha: int = 32
target_modules: Optional[str] = "all-linear"

# rope configs
rope_scaling: Optional[dict] = None
rope_theta: Optional[float] = None


@dataclass
class Optim:
@@ -412,6 +416,8 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
# Actor / Rollout Config
self.actor_rollout_ref.model.path = config.model.model_path
self.actor_rollout_ref.model.custom_chat_template = config.model.custom_chat_template
self.actor_rollout_ref.model.rope_scaling = config.model.rope_scaling
self.actor_rollout_ref.model.rope_theta = config.model.rope_theta
self.actor_rollout_ref.actor.optim.total_training_steps = self.trainer.total_training_steps
self.actor_rollout_ref.actor.ppo_mini_batch_size = config.buffer.train_batch_size
self.actor_rollout_ref.rollout.temperature = (
6 changes: 6 additions & 0 deletions trinity/trainer/verl/fsdp_workers.py
@@ -257,6 +257,12 @@ def _build_model_optimizer(  # noqa: C901
local_path, trust_remote_code=trust_remote_code, attn_implementation="flash_attention_2"
)

# patch for rope
if self.config.model.rope_scaling is not None:
actor_model_config.rope_scaling = OmegaConf.to_container(self.config.model.rope_scaling)
if self.config.model.rope_theta is not None:
actor_model_config.rope_theta = self.config.model.rope_theta

# patch for kimi-vl
if getattr(actor_model_config, "model_type", None) == "kimi_vl":
actor_model_config.text_config.topk_method = "greedy"
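A minimal sketch of what this patch does to the HuggingFace model config (assumptions: a transformers AutoConfig for a checkpoint that exposes rope_scaling / rope_theta, e.g. a Qwen2-style model; the model name and values are illustrative).

from omegaconf import OmegaConf
from transformers import AutoConfig

trainer_model_cfg = OmegaConf.create(
    {"rope_scaling": {"rope_type": "yarn", "factor": 4.0}, "rope_theta": 1_000_000.0}
)
actor_model_config = AutoConfig.from_pretrained("Qwen/Qwen2.5-7B-Instruct")  # hypothetical

if trainer_model_cfg.rope_scaling is not None:
    # convert the OmegaConf node into a plain dict so the HF config stays serializable
    actor_model_config.rope_scaling = OmegaConf.to_container(trainer_model_cfg.rope_scaling)
if trainer_model_cfg.rope_theta is not None:
    actor_model_config.rope_theta = trainer_model_cfg.rope_theta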