
Commit b317708

Update veRL to 0.7.0 (#471)
1 parent c5effdb commit b317708

File tree

17 files changed, +1617 -616 lines


.github/workflows/docker/docker-compose.yaml

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
services:
  trinity-node-1:
-    image: trinity-rft-unittest:20251225
+    image: trinity-rft-unittest:20260115
    pull_policy: never
    command: bash -c "source /opt/venv/bin/activate && uv pip install -e .[dev] && ray start --head --dashboard-host 0.0.0.0 --include-dashboard true --block"
    environment:
@@ -30,7 +30,7 @@ services:
        capabilities: [gpu]

  trinity-node-2:
-    image: trinity-rft-unittest:20251225
+    image: trinity-rft-unittest:20260115
    pull_policy: never
    command: bash -c "source /opt/venv/bin/activate && uv pip install -e .[dev] && ray start --address=trinity-node-1:6379 --block"
    environment:

pyproject.toml

Lines changed: 4 additions & 2 deletions
@@ -21,7 +21,7 @@ classifiers = [
]
requires-python = ">=3.10,<3.13"
dependencies = [
-    "verl==0.5.0",
+    "verl==0.7.0",
    "ray[default]>=2.50.0",
    "tensordict",
    "wandb",
@@ -79,7 +79,9 @@ megatron = [
    # if you found "undefined symbol" error in transformer engine
    # reinstall it with --no-build-isolation and `--no-cache-dir` flag
    # "transformer_engine[pytorch]==2.8.0",
-    "mbridge>=0.13.0",
+
+    # Install mbridge from main branch (unreleased version)
+    "mbridge @ git+https://github.com/ISEEKYAN/mbridge.git@20e9ffbbe72ae7b1df83bfe1bc3c11f7382f2612",
]
tinker = [
    "tinker; python_version >= '3.11'",

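Note: the mbridge requirement above is pinned to a specific git commit because the needed version is unreleased; pip and uv resolve `name @ git+URL@rev` specifiers directly from the repository. A small illustrative snippet (not part of the commit) to confirm the pins after installing the `megatron` extra:

import importlib.metadata as md

for pkg in ("verl", "mbridge"):
    try:
        # `verl` should report 0.7.0 after this change; `mbridge` comes from the pinned commit.
        print(pkg, md.version(pkg))
    except md.PackageNotFoundError:
        print(pkg, "is not installed")
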
tests/explorer/workflow_test.py

Lines changed: 4 additions & 0 deletions
@@ -731,12 +731,16 @@ async def mock_get_api_server_url_remote():
    async def mock_get_model_version_remote():
        return 1

+    async def mock_get_api_key_remote():
+        return "dummy_api_key"
+
    async def mock_get_model_config_remote():
        return InferenceModelConfig(model_path="dummy_model")

    model = MagicMock()
    model.get_api_server_url.remote = MagicMock(side_effect=mock_get_api_server_url_remote)
    model.get_model_version.remote = MagicMock(side_effect=mock_get_model_version_remote)
+    model.get_api_key.remote = MagicMock(side_effect=mock_get_api_key_remote)
    model.get_model_config.remote = MagicMock(side_effect=mock_get_model_config_remote)

    runner = WorkflowRunner(
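
Note: the new `get_api_key` mock follows the same pattern as the existing remote-call mocks. A self-contained sketch of that pattern (illustrative only, outside the test suite): setting an async function as `side_effect` makes the `MagicMock` return a coroutine, so the caller can `await model.get_api_key.remote()` just like a real Ray actor handle.

import asyncio
from unittest.mock import MagicMock

async def mock_get_api_key_remote():
    return "dummy_api_key"

model = MagicMock()
model.get_api_key.remote = MagicMock(side_effect=mock_get_api_key_remote)

async def main():
    # Calling the mock invokes the async side_effect and returns its coroutine.
    assert await model.get_api_key.remote() == "dummy_api_key"

asyncio.run(main())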

trinity/common/config.py

Lines changed: 21 additions & 7 deletions
@@ -94,14 +94,15 @@ class OptimizerConfig:
    lr_warmup_steps: int = -1
    lr_warmup_steps_ratio: float = 0.0
    min_lr_ratio: Optional[float] = 0.0
-    warmup_style: str = "constant"
+    warmup_style: Optional[str] = None  # deprecated !
+    lr_scheduler_type: str = "constant"
    optimizer_type: str = "adam"
    betas: List[float] = field(default_factory=lambda: [0.9, 0.999])
    weight_decay: float = 0.01
    clip_grad: float = 1.0
    lr_warmup_init: float = 0.0
    lr_decay_steps: Optional[int] = None
-    lr_decay_style: str = "constant"
+    lr_decay_style: str = "constant"  # duplicated with lr_scheduler_type in veRL
    min_lr: float = 0.0


@@ -116,6 +117,8 @@ class LoRAConfig:
    lora_alpha: int = 32
    lora_dtype: str = "auto"
    target_modules: str = "all-linear"
+    exclude_modules: Optional[str] = None
+    is_dummy: bool = False  # DO NOT SET, automatically set


@Experimental
@@ -1167,6 +1170,14 @@ def check_and_set(name, registry, args_attr):
        # override loss_agg_mode in policy_loss_fn_args
        self.algorithm.policy_loss_fn_args["loss_agg_mode"] = self.algorithm.loss_agg_mode  # type: ignore [index]

+        optim_config = self.algorithm.optimizer
+        if optim_config.warmup_style is not None:
+            optim_config.lr_scheduler_type = optim_config.warmup_style
+            logger.warning(
+                "`warmup_style` is deprecated. Please use `lr_scheduler_type` instead. "
+                f"And `lr_scheduler_type` is set to {optim_config.lr_scheduler_type}."
+            )
+
    def _check_model(self) -> None:
        model = self.model
        if not model.critic_model_path:
@@ -1363,16 +1374,19 @@ def _check_explorer(self) -> None:
            self.explorer.rollout_model.enable_lora = True
            if len(self.model.lora_configs) > 1:
                raise ValueError("Only one lora adapter is supported for now.")
-            if self.model.lora_configs[0].path is None:
+            lora_config = self.model.lora_configs[0]
+            if lora_config.path is None:
                logger.info("Creating dummy lora, since no lora_path is provided.")
                lora_path = create_dummy_lora(
                    model_path=self.model.model_path,
                    checkpoint_job_dir=self.checkpoint_job_dir,
-                    lora_rank=self.model.lora_configs[0].lora_rank,
-                    lora_alpha=self.model.lora_configs[0].lora_alpha,
-                    target_modules=self.model.lora_configs[0].target_modules,
+                    lora_rank=lora_config.lora_rank,
+                    lora_alpha=lora_config.lora_alpha,
+                    target_modules=lora_config.target_modules,
+                    exclude_modules=lora_config.exclude_modules,
                )
-                self.model.lora_configs[0].path = lora_path
+                lora_config.path = lora_path
+                lora_config.is_dummy = True
            self.explorer.rollout_model.lora_modules = [
                {
                    "lora_int_id": i + 1,

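Note: the `warmup_style` field is kept only as a deprecated alias; when it is set, the check above copies it into the new `lr_scheduler_type` field that veRL 0.7.0 expects and emits a warning. A minimal standalone sketch of that fallback (not Trinity code; `print` stands in for `logger.warning`):

from dataclasses import dataclass
from typing import Optional

@dataclass
class OptimizerConfig:
    warmup_style: Optional[str] = None   # deprecated alias
    lr_scheduler_type: str = "constant"  # name used by veRL >= 0.7.0

def apply_warmup_style_fallback(cfg: OptimizerConfig) -> OptimizerConfig:
    # Old configs that still set `warmup_style` keep working unchanged.
    if cfg.warmup_style is not None:
        cfg.lr_scheduler_type = cfg.warmup_style
        print("`warmup_style` is deprecated. Please use `lr_scheduler_type` instead.")
    return cfg

cfg = apply_warmup_style_fallback(OptimizerConfig(warmup_style="cosine"))
assert cfg.lr_scheduler_type == "cosine"
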
trinity/common/models/vllm_worker.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
class WorkerExtension:
    def apply_patches(self):
        """Apply necessary patches to vLLM."""
-        from verl.utils.vllm_utils import patch_vllm_moe_model_weight_loader
+        from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader

        patch_vllm_moe_model_weight_loader(self.model_runner.model)
        patch_vllm_prompt_logprobs(self.model_runner)
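
Note: verl 0.7.0 moved `patch_vllm_moe_model_weight_loader` from `verl.utils.vllm_utils` to `verl.utils.vllm.patch`, and the commit simply switches to the new path. If a deployment ever needed to run against both layouts, a hypothetical fallback import (not part of this commit) could look like this:

try:
    from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader  # verl >= 0.7.0
except ImportError:
    from verl.utils.vllm_utils import patch_vllm_moe_model_weight_loader  # verl 0.5.x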

trinity/common/verl_config.py

Lines changed: 69 additions & 11 deletions
@@ -4,7 +4,9 @@
from typing import Any, Dict, List, Optional

from omegaconf import OmegaConf
+from verl.workers.config import PolicyLossConfig, RouterReplayConfig

+from trinity.algorithm import ALGORITHM_TYPE
from trinity.common.config import Config, SynchronizerConfig, set_if_none
from trinity.common.constants import EXPLORER_NAME
from trinity.utils.log import get_logger
@@ -41,6 +43,8 @@ class ActorModel:
    lora_rank: int = 0  # The rank of the LoRA model, default to 0. If lora_rank > 0, LoRA module is enabled in trainer
    lora_alpha: int = 32
    target_modules: Optional[str] = "all-linear"
+    exclude_modules: Optional[str] = None
+    lora_adapter_path: Optional[str] = None

    # rope configs
    rope_scaling: Optional[dict] = None
@@ -51,14 +55,15 @@ class ActorModel:
class Optim:
    # For actor, most fields are set in algorithm.optimizer
    # For critic, you can set trainer_config.critic.optim
+    optimizer: str = "adam"
+    optimizer_impl: str = "torch.optim"
    lr: float = 1e-6
    lr_warmup_steps: int = -1
    lr_warmup_steps_ratio: float = 0.0
    min_lr_ratio: Optional[float] = 0.0
-    warmup_style: str = "constant"
+    lr_scheduler_type: str = "constant"
    total_training_steps: int = -1  # ! DO NOT SET, use trainer.total_steps
    betas: List[float] = field(default_factory=lambda: [0.9, 0.999])
-    optimizer: str = "adam"
    clip_grad: float = 1.0
    lr_warmup_init: float = 0.0
    lr_decay_steps: Optional[int] = None
@@ -69,6 +74,7 @@ class Optim:
    lr_wsd_decay_style: str = "exponential"
    lr_wsd_decay_steps: Optional[int] = None
    use_checkpoint_opt_param_scheduler: bool = False
+    override_optimizer_config: Optional[dict] = None


@dataclass
@@ -78,6 +84,7 @@ class WrapPolicy:

@dataclass
class FSDPConfig:
+    _target_: str = "verl.workers.config.FSDPEngineConfig"  # DO NOT SET
    param_offload: bool = False
    optimizer_offload: bool = False
    offload_policy: bool = False
@@ -92,15 +99,15 @@ class FSDPConfig:
class Checkpoint:
    load_contents: List[str] = field(default_factory=lambda: ["model", "optimizer", "extra"])
    save_contents: List[str] = field(default_factory=lambda: ["model", "optimizer", "extra"])
-    async_save: bool = False  # do not set, async save has bug in verl megatron training
+    async_save: bool = False  # TODO: testing async save


@dataclass
class OverrideTransformerConfig:
-    recompute_granularity: Optional[str] = None
+    recompute_granularity: Optional[str] = "full"
    recompute_modules: List[str] = field(default_factory=lambda: ["core_attn"])
-    recompute_method: Optional[str] = None
-    recompute_num_layers: Optional[int] = None
+    recompute_method: Optional[str] = "uniform"
+    recompute_num_layers: Optional[int] = 1


@dataclass
@@ -124,6 +131,8 @@ class MegatronConfig:
        default_factory=OverrideTransformerConfig
    )
    use_mbridge: bool = False
+    dtype: str = "bfloat16"
+    use_remove_padding: bool = True


@dataclass
@@ -157,6 +166,9 @@ class Actor:
    profile: ProfileConfig = field(default_factory=ProfileConfig)
    data_loader_seed: Optional[int] = None
    load_weight: bool = True
+    policy_loss: PolicyLossConfig = field(default_factory=PolicyLossConfig)
+    profiler: dict = field(default_factory=dict)
+    router_replay: RouterReplayConfig = field(default_factory=RouterReplayConfig)
    # do not set
    loss_agg_mode: str = "token-mean"
    clip_ratio: float = 0.2
@@ -182,6 +194,8 @@ class Ref:
    megatron: MegatronConfig = field(default_factory=MegatronConfig)
    profile: ProfileConfig = field(default_factory=ProfileConfig)
    load_weight: bool = True
+    profiler: dict = field(default_factory=dict)
+    router_replay: RouterReplayConfig = field(default_factory=RouterReplayConfig)


@dataclass
@@ -214,6 +228,7 @@ class ActorRolloutRef:
    actor: Actor = field(default_factory=Actor)
    ref: Ref = field(default_factory=Ref)
    rollout: Rollout = field(default_factory=Rollout)
+    nccl_timeout: float = 600  # ! DO NOT SET, it will be set by `config.synchronizer.sync_timeout`
    synchronizer: Optional[SynchronizerConfig] = None
    explorer_name: str = EXPLORER_NAME

@@ -229,9 +244,14 @@ class CriticModel:
    use_remove_padding: bool = True
    fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)

+    # rope configs
+    rope_scaling: Optional[dict] = None
+    rope_theta: Optional[float] = None
+

@dataclass
class Critic:
+    enable: bool = False
    strategy: Optional[str] = None
    optim: Optim = field(default_factory=Optim)
    model: CriticModel = field(default_factory=CriticModel)
@@ -255,7 +275,9 @@ class Critic:
    profile: ProfileConfig = field(default_factory=ProfileConfig)
    data_loader_seed: Optional[int] = None
    load_weight: bool = True
+    nccl_timeout: float = 600  # ! DO NOT SET, it will be set by `config.synchronizer.sync_timeout`
    ray_namespace: str = ""  # automatically generated
+    profiler: dict = field(default_factory=dict)


@dataclass
@@ -278,6 +300,7 @@ class RewardModel:
    use_dynamic_bsz: bool = False
    forward_max_token_len_per_gpu: int = 0
    reward_manager: str = "naive"
+    use_reward_loop: bool = True


@dataclass
@@ -294,8 +317,24 @@ class KL_Ctrl:
    target_kl: float = 0.1


+@dataclass
+class RolloutCorrection:
+    rollout_is: Optional[str] = None
+    rollout_is_threshold: float = 2.0
+    rollout_rs: Optional[str] = None
+    rollout_rs_threshold: Optional[float] = None
+    rollout_rs_threshold_lower: Optional[float] = None
+    rollout_token_veto_threshold: Optional[float] = None
+    # Because rollout and training in Trinity runs separately,
+    # rollout_is_batch_normalize is default to True
+    bypass_mode: bool = True
+    loss_type: str = "ppo_clip"
+    rollout_is_batch_normalize: bool = False
+
+
@dataclass
class Algorithm:
+    rollout_correction: RolloutCorrection = field(default_factory=RolloutCorrection)
    # ! DO NOT SET gamma or lam below; they are kept here merely for compatibility with verl,
    # and their values will be overwritten by those in AlgorithmConfig.advantage_fn_args
    # if they are really needed (e.g., for GAE advantage/returns computation)
@@ -349,6 +388,7 @@ class veRLConfig:
    custom_reward_function: CustomRewardFunction = field(default_factory=CustomRewardFunction)
    algorithm: Algorithm = field(default_factory=Algorithm)
    trainer: Trainer = field(default_factory=Trainer)
+    global_profiler: dict = field(default_factory=dict)
    synchronizer: Optional[SynchronizerConfig] = None
    enable_preview: bool = True

@@ -426,8 +466,12 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
        )  # kept to pass RayPPOTrainer._validate_config

        self.synchronizer = config.synchronizer
+        self.actor_rollout_ref.nccl_timeout = config.synchronizer.sync_timeout
        self.actor_rollout_ref.synchronizer = config.synchronizer
        self.actor_rollout_ref.explorer_name = config.explorer.name
+        algorithm = ALGORITHM_TYPE.get(config.algorithm.algorithm_type)
+        self.critic.enable = algorithm.use_critic
+        self.critic.nccl_timeout = config.synchronizer.sync_timeout
        self.critic.ray_namespace = config.synchronizer.ray_namespace

        # Actor / Rollout Config
@@ -507,6 +551,8 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
        set_if_none(self.critic, "strategy", config.trainer.trainer_strategy)
        self.critic.model.path = config.model.critic_model_path
        self.critic.model.tokenizer_path = config.model.critic_model_path
+        self.critic.model.rope_scaling = config.model.rope_scaling
+        self.critic.model.rope_theta = config.model.rope_theta
        self.critic.ppo_mini_batch_size = config.buffer.train_batch_size
        self.critic.rollout_n = self.actor_rollout_ref.rollout.n
        self.critic.optim.total_training_steps = self.trainer.total_training_steps
@@ -542,11 +588,12 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901

        # LoRA related config
        if config.model.lora_configs is not None:
-            self.actor_rollout_ref.model.lora_rank = config.model.lora_configs[0].lora_rank
-            self.actor_rollout_ref.model.lora_alpha = config.model.lora_configs[0].lora_alpha
-            self.actor_rollout_ref.model.target_modules = config.model.lora_configs[
-                0
-            ].target_modules
+            lora_config = config.model.lora_configs[0]
+            actor_model_config = self.actor_rollout_ref.model
+            for attr in ["lora_rank", "lora_alpha", "target_modules", "exclude_modules"]:
+                setattr(actor_model_config, attr, getattr(lora_config, attr))
+            if not lora_config.is_dummy:
+                actor_model_config.lora_adapter_path = lora_config.path
            if self.actor_rollout_ref.actor.strategy not in ["fsdp", "fsdp2"]:
                logger.warning(
                    f"Lora is only supported for fsdp and fsdp2, but got {self.actor_rollout_ref.actor.strategy} instead, changed to fsdp."
@@ -565,6 +612,17 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
                setattr(self.actor_rollout_ref.actor.optim, "optimizer", field_value)
            elif hasattr(self.actor_rollout_ref.actor.optim, field_name):
                setattr(self.actor_rollout_ref.actor.optim, field_name, field_value)
+        # fix optimizer type for fsdp
+        if config.trainer.trainer_strategy.startswith("fsdp"):
+            optim_map = {
+                "adam": "AdamW",
+                "adamw": "AdamW",
+                "sgd": "SGD",
+            }
+            actor_optim = self.actor_rollout_ref.actor.optim
+            actor_optim.optimizer = optim_map.get(actor_optim.optimizer, actor_optim.optimizer)
+            critic_optim = self.critic.optim
+            critic_optim.optimizer = optim_map.get(critic_optim.optimizer, critic_optim.optimizer)
        self.actor_rollout_ref.actor.use_kl_loss = config.algorithm.kl_loss_fn != "none"
        self.algorithm.use_kl_in_reward = config.algorithm.kl_penalty_fn != "none"
        # TODO (yanxi): it seems that adv_estimator now only affects whether use_critic is set to
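
Note: the block at the end of `synchronize_config` normalizes optimizer names for the FSDP strategies. Trinity configs use lowercase names ("adam", "adamw", "sgd"), while the new verl 0.7.0 FSDP optimizer config appears to expect torch.optim class-style names such as "AdamW" (hence the mapping); unmapped names pass through unchanged. A standalone sketch of that mapping (illustrative only):

OPTIM_MAP = {"adam": "AdamW", "adamw": "AdamW", "sgd": "SGD"}

def normalize_optimizer_name(name: str) -> str:
    # Unknown optimizers are left as-is so custom values still reach verl untouched.
    return OPTIM_MAP.get(name, name)

assert normalize_optimizer_name("adam") == "AdamW"
assert normalize_optimizer_name("sgd") == "SGD"
assert normalize_optimizer_name("Muon") == "Muon"  # hypothetical unmapped name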

trinity/manager/config_manager.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ def _expert_verl_training_part(self):
302302
def _expert_verl_actor_part(self):
303303
st.subheader("Actor Model Config")
304304

305-
self.get_configs("actor_lr", "actor_warmup_style", "actor_lr_warmup_steps_ratio")
305+
self.get_configs("actor_lr", "actor_lr_scheduler_type", "actor_lr_warmup_steps_ratio")
306306

307307
self.get_configs("actor_grad_clip", "actor_ulysses_sequence_parallel_size")
308308

@@ -324,7 +324,7 @@ def _expert_verl_critic_part(self):
324324
"critic_ppo_micro_batch_size_per_gpu", "critic_ulysses_sequence_parallel_size"
325325
)
326326

327-
self.get_configs("critic_lr", "critic_warmup_style", "critic_lr_warmup_steps_ratio")
327+
self.get_configs("critic_lr", "critic_lr_scheduler_type", "critic_lr_warmup_steps_ratio")
328328

329329
self.get_configs("critic_grad_clip", "critic_cliprange_value")
330330
self.get_configs("critic_load_checkpoint", "critic_save_checkpoint")
@@ -490,7 +490,7 @@ def _generate_verl_config(self):
490490
"optim": {
491491
"lr": st.session_state["critic_lr"],
492492
"lr_warmup_steps_ratio": st.session_state["critic_lr_warmup_steps_ratio"],
493-
"warmup_style": st.session_state["critic_warmup_style"],
493+
"lr_scheduler_type": st.session_state["critic_lr_scheduler_type"],
494494
},
495495
"model": {
496496
"override_config": {},
@@ -550,7 +550,7 @@ def _gen_algorithm_config(self):
550550
optimizer_config = {
551551
"lr": st.session_state["actor_lr"],
552552
"lr_warmup_steps_ratio": st.session_state["actor_lr_warmup_steps_ratio"],
553-
"warmup_style": st.session_state["actor_warmup_style"],
553+
"lr_scheduler_type": st.session_state["actor_lr_scheduler_type"],
554554
}
555555
algorithm_config["optimizer"] = optimizer_config
556556
return algorithm_config
