44from typing import Any , Dict , List , Optional
55
66from omegaconf import OmegaConf
7+ from verl .workers .config import PolicyLossConfig , RouterReplayConfig
78
9+ from trinity .algorithm import ALGORITHM_TYPE
810from trinity .common .config import Config , SynchronizerConfig , set_if_none
911from trinity .common .constants import EXPLORER_NAME
1012from trinity .utils .log import get_logger
@@ -41,6 +43,8 @@ class ActorModel:
4143 lora_rank : int = 0 # The rank of the LoRA model, default to 0. If lora_rank > 0, LoRA module is enabled in trainer
4244 lora_alpha : int = 32
4345 target_modules : Optional [str ] = "all-linear"
46+ exclude_modules : Optional [str ] = None
47+ lora_adapter_path : Optional [str ] = None
4448
4549 # rope configs
4650 rope_scaling : Optional [dict ] = None
@@ -51,14 +55,15 @@ class ActorModel:
5155class Optim :
5256 # For actor, most fields are set in algorithm.optimizer
5357 # For critic, you can set trainer_config.critic.optim
58+ optimizer : str = "AdamW"
59+ optimizer_impl : str = "torch.optim"
5460 lr : float = 1e-6
5561 lr_warmup_steps : int = - 1
5662 lr_warmup_steps_ratio : float = 0.0
5763 min_lr_ratio : Optional [float ] = 0.0
5864 warmup_style : str = "constant"
5965 total_training_steps : int = - 1 # ! DO NOT SET, use trainer.total_steps
6066 betas : List [float ] = field (default_factory = lambda : [0.9 , 0.999 ])
61- optimizer : str = "adam"
6267 clip_grad : float = 1.0
6368 lr_warmup_init : float = 0.0
6469 lr_decay_steps : Optional [int ] = None
@@ -69,6 +74,7 @@ class Optim:
6974 lr_wsd_decay_style : str = "exponential"
7075 lr_wsd_decay_steps : Optional [int ] = None
7176 use_checkpoint_opt_param_scheduler : bool = False
77+ override_optimizer_config : Optional [dict ] = None
7278
7379
7480@dataclass
@@ -78,6 +84,7 @@ class WrapPolicy:
7884
7985@dataclass
8086class FSDPConfig :
87+ _target_ : str = "verl.workers.config.FSDPEngineConfig" # DO NOT SET
8188 param_offload : bool = False
8289 optimizer_offload : bool = False
8390 offload_policy : bool = False
@@ -92,7 +99,7 @@ class FSDPConfig:
9299class Checkpoint :
93100 load_contents : List [str ] = field (default_factory = lambda : ["model" , "optimizer" , "extra" ])
94101 save_contents : List [str ] = field (default_factory = lambda : ["model" , "optimizer" , "extra" ])
95- async_save : bool = False # do not set, async save has bug in verl megatron training
102+ async_save : bool = False # TODO: testing async save
96103
97104
98105@dataclass
@@ -124,6 +131,8 @@ class MegatronConfig:
124131 default_factory = OverrideTransformerConfig
125132 )
126133 use_mbridge : bool = False
134+ dtype : str = "bfloat16"
135+ use_remove_padding : bool = True
127136
128137
129138@dataclass
@@ -157,6 +166,9 @@ class Actor:
157166 profile : ProfileConfig = field (default_factory = ProfileConfig )
158167 data_loader_seed : Optional [int ] = None
159168 load_weight : bool = True
169+ policy_loss : PolicyLossConfig = field (default_factory = PolicyLossConfig )
170+ profiler : dict = field (default_factory = dict )
171+ router_replay : RouterReplayConfig = field (default_factory = RouterReplayConfig )
160172 # do not set
161173 loss_agg_mode : str = "token-mean"
162174 clip_ratio : float = 0.2
@@ -182,6 +194,8 @@ class Ref:
182194 megatron : MegatronConfig = field (default_factory = MegatronConfig )
183195 profile : ProfileConfig = field (default_factory = ProfileConfig )
184196 load_weight : bool = True
197+ profiler : dict = field (default_factory = dict )
198+ router_replay : RouterReplayConfig = field (default_factory = RouterReplayConfig )
185199
186200
187201@dataclass
@@ -214,6 +228,7 @@ class ActorRolloutRef:
214228 actor : Actor = field (default_factory = Actor )
215229 ref : Ref = field (default_factory = Ref )
216230 rollout : Rollout = field (default_factory = Rollout )
231+ nccl_timeout : float = 600 # ! DO NOT SET, it will be set by `config.synchronizer.sync_timeout`
217232 synchronizer : Optional [SynchronizerConfig ] = None
218233 explorer_name : str = EXPLORER_NAME
219234
@@ -232,6 +247,7 @@ class CriticModel:
232247
233248@dataclass
234249class Critic :
250+ enable : bool = False
235251 strategy : Optional [str ] = None
236252 optim : Optim = field (default_factory = Optim )
237253 model : CriticModel = field (default_factory = CriticModel )
@@ -255,7 +271,9 @@ class Critic:
255271 profile : ProfileConfig = field (default_factory = ProfileConfig )
256272 data_loader_seed : Optional [int ] = None
257273 load_weight : bool = True
274+ nccl_timeout : float = 600 # ! DO NOT SET, it will be set by `config.synchronizer.sync_timeout`
258275 ray_namespace : str = "" # automatically generated
276+ profiler : dict = field (default_factory = dict )
259277
260278
261279@dataclass
@@ -278,6 +296,7 @@ class RewardModel:
278296 use_dynamic_bsz : bool = False
279297 forward_max_token_len_per_gpu : int = 0
280298 reward_manager : str = "naive"
299+ use_reward_loop : bool = True
281300
282301
283302@dataclass
@@ -294,8 +313,24 @@ class KL_Ctrl:
294313 target_kl : float = 0.1
295314
296315
@dataclass
class RolloutCorrection:
    """Correction settings for the distribution mismatch between rollout and training policies.

    Fields left as ``None`` disable the corresponding mechanism.
    """

    # Importance-sampling (IS) mode for rollout correction; None disables IS.
    rollout_is: Optional[str] = None
    # Upper clipping threshold applied to the IS ratio.
    rollout_is_threshold: float = 2.0
    # Rejection-sampling (RS) mode; None disables RS.
    rollout_rs: Optional[str] = None
    # Upper / lower RS thresholds; None means the bound is not applied.
    rollout_rs_threshold: Optional[float] = None
    rollout_rs_threshold_lower: Optional[float] = None
    # Per-token veto threshold; None disables the token veto.
    rollout_token_veto_threshold: Optional[float] = None
    bypass_mode: bool = True
    loss_type: str = "ppo_clip"
    # NOTE(review): the original comment claimed this defaults to True
    # ("Because rollout and training in Trinity run separately,
    # rollout_is_batch_normalize defaults to True"), but the code sets
    # False — confirm which default is actually intended.
    rollout_is_batch_normalize: bool = False
330+
297331@dataclass
298332class Algorithm :
333+ rollout_correction : RolloutCorrection = field (default_factory = RolloutCorrection )
299334 # ! DO NOT SET gamma or lam below; they are kept here merely for compatibility with verl,
300335 # and their values will be overwritten by those in AlgorithmConfig.advantage_fn_args
301336 # if they are really needed (e.g., for GAE advantage/returns computation)
@@ -349,6 +384,7 @@ class veRLConfig:
349384 custom_reward_function : CustomRewardFunction = field (default_factory = CustomRewardFunction )
350385 algorithm : Algorithm = field (default_factory = Algorithm )
351386 trainer : Trainer = field (default_factory = Trainer )
387+ global_profiler : dict = field (default_factory = dict )
352388 synchronizer : Optional [SynchronizerConfig ] = None
353389 enable_preview : bool = True
354390
@@ -423,8 +459,12 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901
423459 ) # kept to pass RayPPOTrainer._validate_config
424460
425461 self .synchronizer = config .synchronizer
462+ self .actor_rollout_ref .nccl_timeout = config .synchronizer .sync_timeout
426463 self .actor_rollout_ref .synchronizer = config .synchronizer
427464 self .actor_rollout_ref .explorer_name = config .explorer .name
465+ algorithm = ALGORITHM_TYPE .get (config .algorithm .algorithm_type )
466+ self .critic .enable = algorithm .use_critic
467+ self .critic .nccl_timeout = config .synchronizer .sync_timeout
428468 self .critic .ray_namespace = config .synchronizer .ray_namespace
429469
430470 # Actor / Rollout Config
@@ -539,11 +579,12 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901
539579
540580 # LoRA related config
541581 if config .model .lora_configs is not None :
542- self .actor_rollout_ref .model .lora_rank = config .model .lora_configs [0 ].lora_rank
543- self .actor_rollout_ref .model .lora_alpha = config .model .lora_configs [0 ].lora_alpha
544- self .actor_rollout_ref .model .target_modules = config .model .lora_configs [
545- 0
546- ].target_modules
582+ lora_config = config .model .lora_configs [0 ]
583+ actor_model_config = self .actor_rollout_ref .model
584+ for attr in ["lora_rank" , "lora_alpha" , "target_modules" , "exclude_modules" ]:
585+ setattr (actor_model_config , attr , getattr (lora_config , attr ))
586+ if not lora_config .is_dummy :
587+ actor_model_config .lora_adapter_path = lora_config .path
547588 if self .actor_rollout_ref .actor .strategy not in ["fsdp" , "fsdp2" ]:
548589 logger .warning (
549590 f"Lora is only supported for fsdp and fsdp2, but got { self .actor_rollout_ref .actor .strategy } instead, changed to fsdp."
@@ -559,6 +600,13 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901
559600 for field_name in config .algorithm .optimizer .__dataclass_fields__ :
560601 field_value = getattr (config .algorithm .optimizer , field_name )
561602 if field_name == "optimizer_type" :
603+ if config .trainer .trainer_strategy .startswith ("fsdp" ):
604+ optim_map = {
605+ "adam" : "AdamW" ,
606+ "adamw" : "AdamW" ,
607+ "sgd" : "SGD" ,
608+ }
609+ field_value = optim_map .get (field_value , field_value )
562610 setattr (self .actor_rollout_ref .actor .optim , "optimizer" , field_value )
563611 elif hasattr (self .actor_rollout_ref .actor .optim , field_name ):
564612 setattr (self .actor_rollout_ref .actor .optim , field_name , field_value )
0 commit comments