from typing import Any, Dict, List, Optional

from omegaconf import OmegaConf
+from verl.workers.config import PolicyLossConfig, RouterReplayConfig

+from trinity.algorithm import ALGORITHM_TYPE
from trinity.common.config import Config, SynchronizerConfig, set_if_none
from trinity.common.constants import EXPLORER_NAME
from trinity.utils.log import get_logger
@@ -41,6 +43,8 @@ class ActorModel:
    lora_rank: int = 0  # The rank of the LoRA model, default to 0. If lora_rank > 0, LoRA module is enabled in trainer
    lora_alpha: int = 32
    target_modules: Optional[str] = "all-linear"
+    exclude_modules: Optional[str] = None
+    lora_adapter_path: Optional[str] = None

    # rope configs
    rope_scaling: Optional[dict] = None
@@ -51,14 +55,15 @@ class ActorModel:
class Optim:
    # For actor, most fields are set in algorithm.optimizer
    # For critic, you can set trainer_config.critic.optim
+    optimizer: str = "adam"
+    optimizer_impl: str = "torch.optim"
    lr: float = 1e-6
    lr_warmup_steps: int = -1
    lr_warmup_steps_ratio: float = 0.0
    min_lr_ratio: Optional[float] = 0.0
-    warmup_style: str = "constant"
+    lr_scheduler_type: str = "constant"
    total_training_steps: int = -1  # ! DO NOT SET, use trainer.total_steps
    betas: List[float] = field(default_factory=lambda: [0.9, 0.999])
-    optimizer: str = "adam"
    clip_grad: float = 1.0
    lr_warmup_init: float = 0.0
    lr_decay_steps: Optional[int] = None
@@ -69,6 +74,7 @@ class Optim:
    lr_wsd_decay_style: str = "exponential"
    lr_wsd_decay_steps: Optional[int] = None
    use_checkpoint_opt_param_scheduler: bool = False
+    override_optimizer_config: Optional[dict] = None


@dataclass
@@ -78,6 +84,7 @@ class WrapPolicy:

@dataclass
class FSDPConfig:
+    _target_: str = "verl.workers.config.FSDPEngineConfig"  # DO NOT SET
    param_offload: bool = False
    optimizer_offload: bool = False
    offload_policy: bool = False
@@ -92,15 +99,15 @@ class FSDPConfig:
class Checkpoint:
    load_contents: List[str] = field(default_factory=lambda: ["model", "optimizer", "extra"])
    save_contents: List[str] = field(default_factory=lambda: ["model", "optimizer", "extra"])
-    async_save: bool = False  # do not set, async save has bug in verl megatron training
+    async_save: bool = False  # TODO: testing async save


@dataclass
class OverrideTransformerConfig:
-    recompute_granularity: Optional[str] = None
+    recompute_granularity: Optional[str] = "full"
    recompute_modules: List[str] = field(default_factory=lambda: ["core_attn"])
-    recompute_method: Optional[str] = None
-    recompute_num_layers: Optional[int] = None
+    recompute_method: Optional[str] = "uniform"
+    recompute_num_layers: Optional[int] = 1


@dataclass
@@ -124,6 +131,8 @@ class MegatronConfig:
        default_factory=OverrideTransformerConfig
    )
    use_mbridge: bool = False
+    dtype: str = "bfloat16"
+    use_remove_padding: bool = True


@dataclass
@@ -157,6 +166,9 @@ class Actor:
    profile: ProfileConfig = field(default_factory=ProfileConfig)
    data_loader_seed: Optional[int] = None
    load_weight: bool = True
+    policy_loss: PolicyLossConfig = field(default_factory=PolicyLossConfig)
+    profiler: dict = field(default_factory=dict)
+    router_replay: RouterReplayConfig = field(default_factory=RouterReplayConfig)
    # do not set
    loss_agg_mode: str = "token-mean"
    clip_ratio: float = 0.2
@@ -182,6 +194,8 @@ class Ref:
    megatron: MegatronConfig = field(default_factory=MegatronConfig)
    profile: ProfileConfig = field(default_factory=ProfileConfig)
    load_weight: bool = True
+    profiler: dict = field(default_factory=dict)
+    router_replay: RouterReplayConfig = field(default_factory=RouterReplayConfig)


@dataclass
@@ -214,6 +228,7 @@ class ActorRolloutRef:
    actor: Actor = field(default_factory=Actor)
    ref: Ref = field(default_factory=Ref)
    rollout: Rollout = field(default_factory=Rollout)
+    nccl_timeout: float = 600  # ! DO NOT SET, it will be set by `config.synchronizer.sync_timeout`
    synchronizer: Optional[SynchronizerConfig] = None
    explorer_name: str = EXPLORER_NAME

@@ -229,9 +244,14 @@ class CriticModel:
    use_remove_padding: bool = True
    fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)

+    # rope configs
+    rope_scaling: Optional[dict] = None
+    rope_theta: Optional[float] = None
+

@dataclass
class Critic:
+    enable: bool = False
    strategy: Optional[str] = None
    optim: Optim = field(default_factory=Optim)
    model: CriticModel = field(default_factory=CriticModel)
@@ -255,7 +275,9 @@ class Critic:
    profile: ProfileConfig = field(default_factory=ProfileConfig)
    data_loader_seed: Optional[int] = None
    load_weight: bool = True
+    nccl_timeout: float = 600  # ! DO NOT SET, it will be set by `config.synchronizer.sync_timeout`
    ray_namespace: str = ""  # automatically generated
+    profiler: dict = field(default_factory=dict)


@dataclass
@@ -278,6 +300,7 @@ class RewardModel:
    use_dynamic_bsz: bool = False
    forward_max_token_len_per_gpu: int = 0
    reward_manager: str = "naive"
+    use_reward_loop: bool = True


@dataclass
@@ -294,8 +317,24 @@ class KL_Ctrl:
    target_kl: float = 0.1


+@dataclass
+class RolloutCorrection:
+    rollout_is: Optional[str] = None
+    rollout_is_threshold: float = 2.0
+    rollout_rs: Optional[str] = None
+    rollout_rs_threshold: Optional[float] = None
+    rollout_rs_threshold_lower: Optional[float] = None
+    rollout_token_veto_threshold: Optional[float] = None
+    # Because rollout and training in Trinity run separately,
+    # rollout_is_batch_normalize defaults to True
+    bypass_mode: bool = True
+    loss_type: str = "ppo_clip"
+    rollout_is_batch_normalize: bool = False
+
+
@dataclass
class Algorithm:
+    rollout_correction: RolloutCorrection = field(default_factory=RolloutCorrection)
    # ! DO NOT SET gamma or lam below; they are kept here merely for compatibility with verl,
    # and their values will be overwritten by those in AlgorithmConfig.advantage_fn_args
    # if they are really needed (e.g., for GAE advantage/returns computation)
@@ -349,6 +388,7 @@ class veRLConfig:
    custom_reward_function: CustomRewardFunction = field(default_factory=CustomRewardFunction)
    algorithm: Algorithm = field(default_factory=Algorithm)
    trainer: Trainer = field(default_factory=Trainer)
+    global_profiler: dict = field(default_factory=dict)
    synchronizer: Optional[SynchronizerConfig] = None
    enable_preview: bool = True

@@ -426,8 +466,12 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
        )  # kept to pass RayPPOTrainer._validate_config

        self.synchronizer = config.synchronizer
+        self.actor_rollout_ref.nccl_timeout = config.synchronizer.sync_timeout
        self.actor_rollout_ref.synchronizer = config.synchronizer
        self.actor_rollout_ref.explorer_name = config.explorer.name
+        algorithm = ALGORITHM_TYPE.get(config.algorithm.algorithm_type)
+        self.critic.enable = algorithm.use_critic
+        self.critic.nccl_timeout = config.synchronizer.sync_timeout
        self.critic.ray_namespace = config.synchronizer.ray_namespace

        # Actor / Rollout Config
@@ -507,6 +551,8 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
        set_if_none(self.critic, "strategy", config.trainer.trainer_strategy)
        self.critic.model.path = config.model.critic_model_path
        self.critic.model.tokenizer_path = config.model.critic_model_path
+        self.critic.model.rope_scaling = config.model.rope_scaling
+        self.critic.model.rope_theta = config.model.rope_theta
        self.critic.ppo_mini_batch_size = config.buffer.train_batch_size
        self.critic.rollout_n = self.actor_rollout_ref.rollout.n
        self.critic.optim.total_training_steps = self.trainer.total_training_steps
@@ -542,11 +588,12 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901

        # LoRA related config
        if config.model.lora_configs is not None:
-            self.actor_rollout_ref.model.lora_rank = config.model.lora_configs[0].lora_rank
-            self.actor_rollout_ref.model.lora_alpha = config.model.lora_configs[0].lora_alpha
-            self.actor_rollout_ref.model.target_modules = config.model.lora_configs[
-                0
-            ].target_modules
+            lora_config = config.model.lora_configs[0]
+            actor_model_config = self.actor_rollout_ref.model
+            for attr in ["lora_rank", "lora_alpha", "target_modules", "exclude_modules"]:
+                setattr(actor_model_config, attr, getattr(lora_config, attr))
+            if not lora_config.is_dummy:
+                actor_model_config.lora_adapter_path = lora_config.path
            if self.actor_rollout_ref.actor.strategy not in ["fsdp", "fsdp2"]:
                logger.warning(
                    f"Lora is only supported for fsdp and fsdp2, but got {self.actor_rollout_ref.actor.strategy} instead, changed to fsdp."
@@ -565,6 +612,17 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
                setattr(self.actor_rollout_ref.actor.optim, "optimizer", field_value)
            elif hasattr(self.actor_rollout_ref.actor.optim, field_name):
                setattr(self.actor_rollout_ref.actor.optim, field_name, field_value)
+        # fix optimizer type for fsdp
+        if config.trainer.trainer_strategy.startswith("fsdp"):
+            optim_map = {
+                "adam": "AdamW",
+                "adamw": "AdamW",
+                "sgd": "SGD",
+            }
+            actor_optim = self.actor_rollout_ref.actor.optim
+            actor_optim.optimizer = optim_map.get(actor_optim.optimizer, actor_optim.optimizer)
+            critic_optim = self.critic.optim
+            critic_optim.optimizer = optim_map.get(critic_optim.optimizer, critic_optim.optimizer)
        self.actor_rollout_ref.actor.use_kl_loss = config.algorithm.kl_loss_fn != "none"
        self.algorithm.use_kl_in_reward = config.algorithm.kl_penalty_fn != "none"
        # TODO (yanxi): it seems that adv_estimator now only affects whether use_critic is set to
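For readers checking the optimizer-name normalization added to synchronize_config above, it can be exercised in isolation. The following is a minimal sketch using only the standard library; the trimmed Optim stand-in and the helper name normalize_fsdp_optimizer are illustrative and not part of the patch:

from dataclasses import dataclass


@dataclass
class Optim:
    # trimmed stand-in for the Optim dataclass in this patch
    optimizer: str = "adam"
    optimizer_impl: str = "torch.optim"


def normalize_fsdp_optimizer(optim: Optim) -> Optim:
    # mirrors the optim_map added above: lower-case aliases are rewritten,
    # presumably to match torch.optim class names; unknown names pass through
    optim_map = {"adam": "AdamW", "adamw": "AdamW", "sgd": "SGD"}
    optim.optimizer = optim_map.get(optim.optimizer, optim.optimizer)
    return optim


print(normalize_fsdp_optimizer(Optim()).optimizer)                       # AdamW
print(normalize_fsdp_optimizer(Optim(optimizer="sgd")).optimizer)        # SGD
print(normalize_fsdp_optimizer(Optim(optimizer="adafactor")).optimizer)  # adafactor (unchanged)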