Commit e094c99

update verl to 0.7.0
1 parent 39dd1d4 commit e094c99

File tree

12 files changed, +1641 -595 lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ classifiers = [
 ]
 requires-python = ">=3.10,<3.13"
 dependencies = [
-    "verl==0.5.0",
+    "verl==0.7.0",
     "ray[default]>=2.50.0",
     "tensordict",
     "wandb",

trinity/common/config.py

Lines changed: 10 additions & 5 deletions
@@ -116,6 +116,8 @@ class LoRAConfig:
     lora_alpha: int = 32
     lora_dtype: str = "auto"
     target_modules: str = "all-linear"
+    exclude_modules: Optional[str] = None
+    is_dummy: bool = False  # DO NOT SET, automatically set
 
 
 @Experimental
@@ -1356,16 +1358,19 @@ def _check_explorer(self) -> None:
             self.explorer.rollout_model.enable_lora = True
             if len(self.model.lora_configs) > 1:
                 raise ValueError("Only one lora adapter is supported for now.")
-            if self.model.lora_configs[0].path is None:
+            lora_config = self.model.lora_configs[0]
+            if lora_config.path is None:
                 logger.info("Creating dummy lora, since no lora_path is provided.")
                 lora_path = create_dummy_lora(
                     model_path=self.model.model_path,
                     checkpoint_job_dir=self.checkpoint_job_dir,
-                    lora_rank=self.model.lora_configs[0].lora_rank,
-                    lora_alpha=self.model.lora_configs[0].lora_alpha,
-                    target_modules=self.model.lora_configs[0].target_modules,
+                    lora_rank=lora_config.lora_rank,
+                    lora_alpha=lora_config.lora_alpha,
+                    target_modules=lora_config.target_modules,
+                    exclude_modules=lora_config.exclude_modules,
                 )
-                self.model.lora_configs[0].path = lora_path
+                lora_config.path = lora_path
+                lora_config.is_dummy = True
             self.explorer.rollout_model.lora_modules = [
                 {
                     "lora_int_id": i + 1,

trinity/common/verl_config.py

Lines changed: 55 additions & 7 deletions
@@ -4,7 +4,9 @@
 from typing import Any, Dict, List, Optional
 
 from omegaconf import OmegaConf
+from verl.workers.config import PolicyLossConfig, RouterReplayConfig
 
+from trinity.algorithm import ALGORITHM_TYPE
 from trinity.common.config import Config, SynchronizerConfig, set_if_none
 from trinity.common.constants import EXPLORER_NAME
 from trinity.utils.log import get_logger
@@ -41,6 +43,8 @@ class ActorModel:
     lora_rank: int = 0  # The rank of the LoRA model, default to 0. If lora_rank > 0, LoRA module is enabled in trainer
     lora_alpha: int = 32
     target_modules: Optional[str] = "all-linear"
+    exclude_modules: Optional[str] = None
+    lora_adapter_path: Optional[str] = None
 
     # rope configs
     rope_scaling: Optional[dict] = None
@@ -51,14 +55,15 @@
 class Optim:
     # For actor, most fields are set in algorithm.optimizer
     # For critic, you can set trainer_config.critic.optim
+    optimizer: str = "AdamW"
+    optimizer_impl: str = "torch.optim"
     lr: float = 1e-6
     lr_warmup_steps: int = -1
     lr_warmup_steps_ratio: float = 0.0
     min_lr_ratio: Optional[float] = 0.0
     warmup_style: str = "constant"
     total_training_steps: int = -1  # ! DO NOT SET, use trainer.total_steps
     betas: List[float] = field(default_factory=lambda: [0.9, 0.999])
-    optimizer: str = "adam"
     clip_grad: float = 1.0
     lr_warmup_init: float = 0.0
     lr_decay_steps: Optional[int] = None
@@ -69,6 +74,7 @@ class Optim:
     lr_wsd_decay_style: str = "exponential"
     lr_wsd_decay_steps: Optional[int] = None
     use_checkpoint_opt_param_scheduler: bool = False
+    override_optimizer_config: Optional[dict] = None
 
 
 @dataclass
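The optimizer / optimizer_impl pair replaces the old lower-case optimizer: "adam" field. As a rough sketch (not verl's actual code) of how such a pair can be resolved to a concrete torch optimizer:

    import importlib

    import torch

    def build_optimizer(params, optimizer_impl="torch.optim", optimizer="AdamW",
                        lr=1e-6, betas=(0.9, 0.999)):
        # Resolve e.g. ("torch.optim", "AdamW") to torch.optim.AdamW and instantiate it.
        optimizer_cls = getattr(importlib.import_module(optimizer_impl), optimizer)
        return optimizer_cls(params, lr=lr, betas=betas)

    opt = build_optimizer(torch.nn.Linear(8, 8).parameters())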
@@ -78,6 +84,7 @@ class WrapPolicy:
 
 @dataclass
 class FSDPConfig:
+    _target_: str = "verl.workers.config.FSDPEngineConfig"  # DO NOT SET
     param_offload: bool = False
     optimizer_offload: bool = False
     offload_policy: bool = False
@@ -92,7 +99,7 @@ class FSDPConfig:
 class Checkpoint:
     load_contents: List[str] = field(default_factory=lambda: ["model", "optimizer", "extra"])
     save_contents: List[str] = field(default_factory=lambda: ["model", "optimizer", "extra"])
-    async_save: bool = False  # do not set, async save has bug in verl megatron training
+    async_save: bool = False  # TODO: testing async save
 
 
 @dataclass
@@ -124,6 +131,8 @@ class MegatronConfig:
         default_factory=OverrideTransformerConfig
     )
     use_mbridge: bool = False
+    dtype: str = "bfloat16"
+    use_remove_padding: bool = True
 
 
 @dataclass
@@ -157,6 +166,9 @@ class Actor:
     profile: ProfileConfig = field(default_factory=ProfileConfig)
     data_loader_seed: Optional[int] = None
     load_weight: bool = True
+    policy_loss: PolicyLossConfig = field(default_factory=PolicyLossConfig)
+    profiler: dict = field(default_factory=dict)
+    router_replay: RouterReplayConfig = field(default_factory=RouterReplayConfig)
     # do not set
     loss_agg_mode: str = "token-mean"
     clip_ratio: float = 0.2
@@ -182,6 +194,8 @@ class Ref:
     megatron: MegatronConfig = field(default_factory=MegatronConfig)
     profile: ProfileConfig = field(default_factory=ProfileConfig)
     load_weight: bool = True
+    profiler: dict = field(default_factory=dict)
+    router_replay: RouterReplayConfig = field(default_factory=RouterReplayConfig)
 
 
 @dataclass
@@ -214,6 +228,7 @@ class ActorRolloutRef:
     actor: Actor = field(default_factory=Actor)
     ref: Ref = field(default_factory=Ref)
     rollout: Rollout = field(default_factory=Rollout)
+    nccl_timeout: float = 600  # ! DO NOT SET, it will be set by `config.synchronizer.sync_timeout`
     synchronizer: Optional[SynchronizerConfig] = None
     explorer_name: str = EXPLORER_NAME
 
@@ -232,6 +247,7 @@ class CriticModel:
 
 @dataclass
 class Critic:
+    enable: bool = False
     strategy: Optional[str] = None
     optim: Optim = field(default_factory=Optim)
     model: CriticModel = field(default_factory=CriticModel)
@@ -255,7 +271,9 @@ class Critic:
     profile: ProfileConfig = field(default_factory=ProfileConfig)
     data_loader_seed: Optional[int] = None
     load_weight: bool = True
+    nccl_timeout: float = 600  # ! DO NOT SET, it will be set by `config.synchronizer.sync_timeout`
     ray_namespace: str = ""  # automatically generated
+    profiler: dict = field(default_factory=dict)
 
 
 @dataclass
@@ -278,6 +296,7 @@ class RewardModel:
     use_dynamic_bsz: bool = False
     forward_max_token_len_per_gpu: int = 0
     reward_manager: str = "naive"
+    use_reward_loop: bool = True
 
 
 @dataclass
@@ -294,8 +313,24 @@ class KL_Ctrl:
     target_kl: float = 0.1
 
 
+@dataclass
+class RolloutCorrection:
+    rollout_is: Optional[str] = None
+    rollout_is_threshold: float = 2.0
+    rollout_rs: Optional[str] = None
+    rollout_rs_threshold: Optional[float] = None
+    rollout_rs_threshold_lower: Optional[float] = None
+    rollout_token_veto_threshold: Optional[float] = None
+    # Because rollout and training in Trinity runs separately,
+    # rollout_is_batch_normalize is default to True
+    bypass_mode: bool = True
+    loss_type: str = "ppo_clip"
+    rollout_is_batch_normalize: bool = False
+
+
 @dataclass
 class Algorithm:
+    rollout_correction: RolloutCorrection = field(default_factory=RolloutCorrection)
     # ! DO NOT SET gamma or lam below; they are kept here merely for compatibility with verl,
     # and their values will be overwritten by those in AlgorithmConfig.advantage_fn_args
     # if they are really needed (e.g., for GAE advantage/returns computation)
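RolloutCorrection configures how training compensates for the gap between the rollout policy and the current policy. A minimal sketch of the idea behind token-level importance-sampling weights clipped at rollout_is_threshold (illustrative only, not verl's implementation):

    import torch

    def rollout_is_weights(log_prob, rollout_log_prob, response_mask, threshold=2.0):
        # log_prob:         (batch, resp_len) token log-probs under the current policy
        # rollout_log_prob: (batch, resp_len) token log-probs recorded at rollout time
        weights = torch.exp(log_prob - rollout_log_prob)  # pi_theta / pi_rollout per token
        weights = torch.clamp(weights, max=threshold)     # cap extreme ratios
        return weights * response_mask                    # zero out padding tokens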
@@ -349,6 +384,7 @@ class veRLConfig:
     custom_reward_function: CustomRewardFunction = field(default_factory=CustomRewardFunction)
     algorithm: Algorithm = field(default_factory=Algorithm)
     trainer: Trainer = field(default_factory=Trainer)
+    global_profiler: dict = field(default_factory=dict)
     synchronizer: Optional[SynchronizerConfig] = None
     enable_preview: bool = True
 
@@ -423,8 +459,12 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
         )  # kept to pass RayPPOTrainer._validate_config
 
         self.synchronizer = config.synchronizer
+        self.actor_rollout_ref.nccl_timeout = config.synchronizer.sync_timeout
         self.actor_rollout_ref.synchronizer = config.synchronizer
         self.actor_rollout_ref.explorer_name = config.explorer.name
+        algorithm = ALGORITHM_TYPE.get(config.algorithm.algorithm_type)
+        self.critic.enable = algorithm.use_critic
+        self.critic.nccl_timeout = config.synchronizer.sync_timeout
         self.critic.ray_namespace = config.synchronizer.ray_namespace
 
         # Actor / Rollout Config
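critic.enable is now derived from the selected algorithm rather than set by hand. A hedged sketch of the lookup; the registry key "ppo" and the printed value are assumptions for illustration:

    from trinity.algorithm import ALGORITHM_TYPE

    algorithm = ALGORITHM_TYPE.get("ppo")  # hypothetical registry key
    print(algorithm.use_critic)            # True enables the critic worker, False leaves it disabled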
@@ -539,11 +579,12 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
 
         # LoRA related config
         if config.model.lora_configs is not None:
-            self.actor_rollout_ref.model.lora_rank = config.model.lora_configs[0].lora_rank
-            self.actor_rollout_ref.model.lora_alpha = config.model.lora_configs[0].lora_alpha
-            self.actor_rollout_ref.model.target_modules = config.model.lora_configs[
-                0
-            ].target_modules
+            lora_config = config.model.lora_configs[0]
+            actor_model_config = self.actor_rollout_ref.model
+            for attr in ["lora_rank", "lora_alpha", "target_modules", "exclude_modules"]:
+                setattr(actor_model_config, attr, getattr(lora_config, attr))
+            if not lora_config.is_dummy:
+                actor_model_config.lora_adapter_path = lora_config.path
             if self.actor_rollout_ref.actor.strategy not in ["fsdp", "fsdp2"]:
                 logger.warning(
                     f"Lora is only supported for fsdp and fsdp2, but got {self.actor_rollout_ref.actor.strategy} instead, changed to fsdp."
@@ -559,6 +600,13 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
         for field_name in config.algorithm.optimizer.__dataclass_fields__:
             field_value = getattr(config.algorithm.optimizer, field_name)
             if field_name == "optimizer_type":
+                if config.trainer.trainer_strategy.startswith("fsdp"):
+                    optim_map = {
+                        "adam": "AdamW",
+                        "adamw": "AdamW",
+                        "sgd": "SGD",
+                    }
+                    field_value = optim_map.get(field_value, field_value)
                 setattr(self.actor_rollout_ref.actor.optim, "optimizer", field_value)
             elif hasattr(self.actor_rollout_ref.actor.optim, field_name):
                 setattr(self.actor_rollout_ref.actor.optim, field_name, field_value)

trinity/trainer/verl/dp_actor.py

Lines changed: 36 additions & 9 deletions
@@ -15,7 +15,7 @@
 # limitations under the License.
 """
 Single Process Actor.
-Modified from https://github.com/volcengine/verl/blob/v0.5.0/verl/workers/actor/dp_actor.py
+Modified from https://github.com/volcengine/verl/blob/v0.7.0/verl/workers/actor/dp_actor.py
 """
 
 import logging
@@ -67,9 +67,8 @@ def update_policy(self, data: DataProto):  # noqa: C901
         # make sure we are in training mode
         self.actor_module.train()
 
-        temperature = data.meta_info[
-            "temperature"
-        ]  # temperature must be in the data.meta_info to avoid silent error
+        # temperature must be in the data.meta_info to avoid silent error
+        temperature = data.meta_info["temperature"]
         select_keys = [
             "input_ids",
             "position_ids",
@@ -80,13 +79,17 @@ def update_policy(self, data: DataProto):  # noqa: C901
         select_keys.extend(self.policy_loss_fn.select_keys)
         if not isinstance(self.kl_loss_fn, DummyKLFn):
             select_keys.append("ref_log_prob")
+        # rollout_is_weights will be used in policy loss
+        # rollout_log_probs is equal to old_log_prob in Trinity
        select_keys = list(set(select_keys))
 
         has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
         non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else []
 
         data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys)
 
+        # Split to make minibatch iterator for updating the actor
+        # See PPO paper for details. https://arxiv.org/abs/1707.06347
         mini_batches = data.split(self.config.ppo_mini_batch_size)
 
         # EXPERIMENTAL: apply loss scale fix
@@ -119,12 +122,11 @@ def update_policy(self, data: DataProto):  # noqa: C901
                 self.actor_optimizer.zero_grad()
 
                 for micro_batch in micro_batches:
+                    micro_batch = micro_batch.to(get_device_id())
                     micro_batch_metrics = {}
-                    model_inputs = {
-                        **micro_batch.batch.to(get_device_id()),
-                        **micro_batch.non_tensor_batch,
-                    }
+                    model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
                     response_mask = model_inputs["response_mask"]
+                    loss_mode = self.config.policy_loss.get("loss_mode", "vanilla")
 
                     # all return: (bsz, response_length)
                     calculate_entropy = self.entropy_loss_fn != DummyEntropyLossFn
@@ -141,6 +143,23 @@ def update_policy(self, data: DataProto):  # noqa: C901
                         src_metrics=pg_loss_metrics, prefix="actor", dst_metrics=micro_batch_metrics
                     )
 
+                    # TODO: to be check
+                    # Skip if using bypass_mode loss (metrics already computed in pg_metrics)
+                    rollout_log_prob = model_inputs.get("rollout_log_probs", None)
+                    if loss_mode != "bypass_mode" and rollout_log_prob is not None:
+                        # Compute metrics using CURRENT policy π_θ vs π_rollout
+                        # Tracks evolving off-policy gap as π_θ updates during mini-batch training
+                        from verl.trainer.ppo.rollout_corr_helper import (
+                            compute_rollout_corr_metrics_from_logprobs,
+                        )
+
+                        rollout_corr_metrics = compute_rollout_corr_metrics_from_logprobs(
+                            log_prob=log_prob,
+                            rollout_log_prob=rollout_log_prob,
+                            response_mask=response_mask,
+                        )
+                        micro_batch_metrics.update(rollout_corr_metrics)
+
                     # compute entropy loss from entropy
                     entropy_loss, entropy_loss_metrics = self.entropy_loss_fn(  # type: ignore
                         entropy=entropy,
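compute_rollout_corr_metrics_from_logprobs comes from verl 0.7.0 and its exact outputs are not shown in this diff. The sketch below only illustrates the kind of off-policy gap statistics such a helper can report; the metric names are made up:

    import torch

    def offpolicy_gap_metrics(log_prob, rollout_log_prob, response_mask):
        # Compare current-policy and rollout-time log-probs over non-padding tokens.
        mask = response_mask.bool()
        diff = (log_prob - rollout_log_prob)[mask]
        return {
            "rollout_corr/mean_abs_logprob_diff": diff.abs().mean().item(),
            "rollout_corr/mean_is_ratio": diff.exp().mean().item(),
        }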
@@ -185,7 +204,15 @@ def update_policy(self, data: DataProto):  # noqa: C901
 
                     loss = policy_loss * loss_scale
                     micro_batch_metrics["actor/final_loss"] = loss.detach().item()
-                    loss.backward()
+                    if "actor/kl_loss" in micro_batch_metrics:
+                        micro_batch_metrics["actor/kl_loss"] *= loss_scale
+                    if "actor/pg_loss" in micro_batch_metrics:
+                        micro_batch_metrics["actor/pg_loss"] *= loss_scale
+
+                    if self.scaler is not None:
+                        self.scaler.scale(loss).backward()
+                    else:
+                        loss.backward()
 
                     append_to_dict(metrics, micro_batch_metrics)
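The new self.scaler branch follows the standard torch.cuda.amp.GradScaler pattern. A self-contained sketch of that generic PyTorch pattern (assumes a CUDA device; this is not Trinity's trainer loop):

    import torch

    model = torch.nn.Linear(8, 1).cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scaler = torch.cuda.amp.GradScaler()

    x = torch.randn(4, 8, device="cuda")
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = model(x).pow(2).mean()

    scaler.scale(loss).backward()  # scale the loss so fp16 gradients do not underflow
    scaler.step(optimizer)         # unscales gradients, skips the step on inf/nan
    scaler.update()
    optimizer.zero_grad()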