import math
import sys
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
from omegaconf import OmegaConf
from verl.workers.config import PolicyLossConfig, RouterReplayConfig
from trinity.algorithm import ALGORITHM_TYPE
from trinity.common.config import Config, SynchronizerConfig, set_if_none
from trinity.common.constants import EXPLORER_NAME
from trinity.common.patch import kimi_vl_monkey_patch_decorator
from trinity.utils.log import get_logger
logger = get_logger(__name__)
@dataclass
class Data:
train_batch_size: int = 1024 # kept to pass RayPPOTrainer._validate_config
trust_remote_code: bool = False
@dataclass
class FusedKernelOptions:
impl_backend: Optional[str] = None
@dataclass
class ActorModel:
path: str = ""
external_lib: Optional[str] = None
override_config: Dict[str, Any] = field(default_factory=dict)
enable_gradient_checkpointing: bool = True
use_remove_padding: bool = True
use_fused_kernels: bool = False
fused_kernel_options: FusedKernelOptions = field(default_factory=FusedKernelOptions)
custom_chat_template: Optional[str] = None
enable_activation_offload: bool = False
use_shm: bool = False
    trust_remote_code: bool = False  # Whether to allow loading models that ship remote code
# lora configs
    lora_rank: int = 0  # LoRA rank; defaults to 0. If lora_rank > 0, the LoRA module is enabled in the trainer
lora_alpha: int = 32
target_modules: Optional[str] = "all-linear"
exclude_modules: Optional[str] = None
lora_adapter_path: Optional[str] = None
# rope configs
rope_scaling: Optional[dict] = None
rope_theta: Optional[float] = None
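    # Sketch (hypothetical values) of enabling LoRA through the YAML merged by
    # `load_config` below; lora_rank > 0 enables the LoRA module in the trainer.
    # Note that `synchronize_config` overwrites these fields from
    # `config.model.lora_configs` when that list is set:
    #
    #   actor_rollout_ref:
    #     model:
    #       lora_rank: 16
    #       lora_alpha: 32
    #       target_modules: all-linear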
@dataclass
class Optim:
    # For the actor, most of these fields are populated from algorithm.optimizer.
    # For the critic, set trainer_config.critic.optim instead.
optimizer: str = "adam"
optimizer_impl: str = "torch.optim"
lr: float = 1e-6
lr_warmup_steps: int = -1
lr_warmup_steps_ratio: float = 0.0
min_lr_ratio: Optional[float] = 0.0
    warmup_style: Optional[str] = None  # deprecated!
lr_scheduler_type: str = "constant"
total_training_steps: int = -1 # ! DO NOT SET, use trainer.total_steps
betas: List[float] = field(default_factory=lambda: [0.9, 0.999])
clip_grad: float = 1.0
    lr_warmup_init: Optional[float] = None  # resolved in synchronize_config: 0.0 (critic) or min_lr_ratio * lr (actor)
    lr_decay_steps: Optional[int] = None  # resolved to total_training_steps in synchronize_config
    lr_decay_style: Optional[str] = None  # resolved to "constant" (critic) or lr_scheduler_type (actor)
    min_lr: Optional[float] = None  # resolved to 0.0 (critic) or min_lr_ratio * lr (actor)
weight_decay: float = 0.01
weight_decay_incr_style: str = "constant"
lr_wsd_decay_style: str = "exponential"
lr_wsd_decay_steps: Optional[int] = None
use_checkpoint_opt_param_scheduler: bool = False
override_optimizer_config: Optional[dict] = None
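    # Sketch (hypothetical values): the critic optimizer can be overridden in the
    # YAML merged by `load_config`, while actor fields come from `algorithm.optimizer`:
    #
    #   critic:
    #     optim:
    #       lr: 1.0e-5
    #       weight_decay: 0.0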
@dataclass
class WrapPolicy:
min_num_params: int = 0
@dataclass
class FSDPConfig:
_target_: str = "verl.workers.config.FSDPEngineConfig" # DO NOT SET
param_offload: bool = False
optimizer_offload: bool = False
offload_policy: bool = False
reshard_after_forward: bool = True
wrap_policy: WrapPolicy = field(default_factory=WrapPolicy)
fsdp_size: int = -1
forward_prefetch: bool = False
model_dtype: Optional[str] = None
dtype: str = "bfloat16"
mixed_precision: dict = field(default_factory=dict)
@dataclass
class Checkpoint:
load_contents: List[str] = field(default_factory=lambda: ["model", "optimizer", "extra"])
save_contents: List[str] = field(default_factory=lambda: ["model", "optimizer", "extra"])
async_save: bool = False # TODO: testing async save
@dataclass
class OverrideTransformerConfig:
recompute_granularity: Optional[str] = "full"
recompute_modules: List[str] = field(default_factory=lambda: ["core_attn"])
recompute_method: Optional[str] = "uniform"
recompute_num_layers: Optional[int] = 1
@dataclass
class MegatronConfig:
param_offload: bool = False
grad_offload: bool = False
optimizer_offload: bool = False
tensor_model_parallel_size: int = 1
expert_model_parallel_size: int = 1
expert_tensor_parallel_size: Optional[int] = None
pipeline_model_parallel_size: int = 1
virtual_pipeline_model_parallel_size: Optional[int] = None
context_parallel_size: int = 1
sequence_parallel: bool = True
use_distributed_optimizer: bool = True
use_dist_checkpointing: bool = False
dist_checkpointing_path: Optional[str] = None
seed: int = 42
override_ddp_config: dict = field(default_factory=dict)
override_transformer_config: OverrideTransformerConfig = field(
default_factory=OverrideTransformerConfig
)
use_mbridge: bool = False
dtype: str = "bfloat16"
use_remove_padding: bool = True
@dataclass
class ProfileConfig:
use_profile: bool = False
profile_ranks: Optional[List[int]] = None
step_start: int = -1
step_end: int = -1
save_path: Optional[str] = None
@dataclass
class Actor:
strategy: Optional[str] = None
ppo_mini_batch_size: int = 256
ppo_micro_batch_size: Optional[int] = None
ppo_micro_batch_size_per_gpu: int = 1
use_dynamic_bsz: Optional[bool] = None
ppo_max_token_len_per_gpu: Optional[int] = None
fix_actor_microbatch_loss_scale: Optional[bool] = None # EXPERIMENTAL
grad_clip: Optional[float] = None
ppo_epochs: int = 1
shuffle: bool = False
ulysses_sequence_parallel_size: Optional[int] = None
entropy_from_logits_with_chunking: bool = False
entropy_checkpointing: bool = False
checkpoint: Checkpoint = field(default_factory=Checkpoint)
optim: Optim = field(default_factory=Optim)
fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
megatron: MegatronConfig = field(default_factory=MegatronConfig)
profile: ProfileConfig = field(default_factory=ProfileConfig)
data_loader_seed: Optional[int] = None
load_weight: bool = True
policy_loss: PolicyLossConfig = field(default_factory=PolicyLossConfig)
profiler: dict = field(default_factory=dict)
router_replay: RouterReplayConfig = field(default_factory=RouterReplayConfig)
# do not set
loss_agg_mode: str = "token-mean"
loss_scale_factor: Optional[float] = None
clip_ratio: float = 0.2
clip_ratio_low: Optional[float] = None
clip_ratio_high: Optional[float] = None
entropy_coeff: float = 0.001
use_kl_loss: bool = False
@dataclass
class Ref:
strategy: Optional[str] = None
fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
log_prob_micro_batch_size: Optional[int] = None
log_prob_micro_batch_size_per_gpu: int = 1
log_prob_use_dynamic_bsz: Optional[bool] = None
log_prob_max_token_len_per_gpu: Optional[int] = None
ulysses_sequence_parallel_size: Optional[int] = None
entropy_from_logits_with_chunking: bool = False
entropy_checkpointing: bool = False
checkpoint: Checkpoint = field(
default_factory=lambda: Checkpoint(load_contents=["model"], save_contents=["model"])
)
megatron: MegatronConfig = field(default_factory=MegatronConfig)
profile: ProfileConfig = field(default_factory=ProfileConfig)
load_weight: bool = True
profiler: dict = field(default_factory=dict)
router_replay: RouterReplayConfig = field(default_factory=RouterReplayConfig)
@dataclass
class _ValKwargs:
do_sample: bool = False
@dataclass
class _MultiTurn:
enable: bool = False
@dataclass
class Rollout:
# do not set
val_kwargs: _ValKwargs = field(default_factory=_ValKwargs)
multi_turn: _MultiTurn = field(default_factory=_MultiTurn)
temperature: float = 1.0
    n: int = 1  # set > 1 for GRPO
log_prob_use_dynamic_bsz: Optional[bool] = None
log_prob_micro_batch_size: Optional[int] = None
log_prob_micro_batch_size_per_gpu: Optional[int] = None
log_prob_max_token_len_per_gpu: Optional[int] = None
@dataclass
class ActorRolloutRef:
hybrid_engine: bool = True
model: ActorModel = field(default_factory=ActorModel)
actor: Actor = field(default_factory=Actor)
ref: Ref = field(default_factory=Ref)
rollout: Rollout = field(default_factory=Rollout)
nccl_timeout: float = 600 # ! DO NOT SET, it will be set by `config.synchronizer.sync_timeout`
synchronizer: Optional[SynchronizerConfig] = None
explorer_name: str = EXPLORER_NAME
@dataclass
class CriticModel:
path: str = ""
tokenizer_path: str = ""
override_config: Dict[str, str] = field(default_factory=dict)
external_lib: Optional[str] = None
    trust_remote_code: bool = False  # Whether to allow loading models that ship remote code
enable_gradient_checkpointing: bool = True
use_remove_padding: bool = True
fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
# rope configs
rope_scaling: Optional[dict] = None
rope_theta: Optional[float] = None
@dataclass
class Critic:
enable: bool = False
strategy: Optional[str] = None
optim: Optim = field(default_factory=Optim)
model: CriticModel = field(default_factory=CriticModel)
ppo_mini_batch_size: int = 0
ppo_micro_batch_size: Optional[int] = None
ppo_micro_batch_size_per_gpu: int = 1
forward_micro_batch_size: Optional[int] = None
forward_micro_batch_size_per_gpu: Optional[int] = None
use_dynamic_bsz: Optional[bool] = None
ppo_max_token_len_per_gpu: Optional[int] = None
forward_max_token_len_per_gpu: Optional[int] = None
ulysses_sequence_parallel_size: Optional[int] = None
ppo_epochs: int = 1
shuffle: bool = False
grad_clip: Optional[float] = None
cliprange_value: float = 0.0
checkpoint: Checkpoint = field(default_factory=Checkpoint)
rollout_n: int = 1
loss_agg_mode: str = "token-mean"
megatron: MegatronConfig = field(default_factory=MegatronConfig)
profile: ProfileConfig = field(default_factory=ProfileConfig)
data_loader_seed: Optional[int] = None
load_weight: bool = True
nccl_timeout: float = 600 # ! DO NOT SET, it will be set by `config.synchronizer.sync_timeout`
ray_namespace: str = "" # automatically generated
profiler: dict = field(default_factory=dict)
@dataclass
class _RewardModel:
input_tokenizer: Optional[str] = None
path: str = ""
external_lib: Optional[str] = None
use_remove_padding: bool = False
fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
@dataclass
class RewardModel:
enable: bool = False
strategy: Optional[str] = None
model: _RewardModel = field(default_factory=_RewardModel)
micro_batch_size_per_gpu: int = 1
max_length: Optional[int] = None
ulysses_sequence_parallel_size: int = 1
use_dynamic_bsz: bool = False
forward_max_token_len_per_gpu: int = 0
reward_manager: str = "naive"
use_reward_loop: bool = True
@dataclass
class CustomRewardFunction:
path: Optional[str] = None
name: str = "compute_score"
@dataclass
class KL_Ctrl:
type: str = "fixed"
kl_coef: float = 0.001
horizon: float = 10000
target_kl: float = 0.1
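    # Assumption based on verl's KL controllers: with type "fixed", kl_coef is used
    # as a constant coefficient; horizon and target_kl only apply to the adaptive
    # controller.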
@dataclass
class RolloutCorrection:
rollout_is: Optional[str] = None
rollout_is_threshold: float = 2.0
rollout_rs: Optional[str] = None
rollout_rs_threshold: Optional[float] = None
rollout_rs_threshold_lower: Optional[float] = None
rollout_token_veto_threshold: Optional[float] = None
    # Because rollout and training in Trinity run separately,
    # rollout_is_batch_normalize defaults to False.
bypass_mode: bool = True
loss_type: str = "ppo_clip"
rollout_is_batch_normalize: bool = False
@dataclass
class Algorithm:
rollout_correction: RolloutCorrection = field(default_factory=RolloutCorrection)
# ! DO NOT SET gamma or lam below; they are kept here merely for compatibility with verl,
# and their values will be overwritten by those in AlgorithmConfig.advantage_fn_args
# if they are really needed (e.g., for GAE advantage/returns computation)
gamma: float = 1.0
lam: float = 1.0
adv_estimator: str = "gae"
norm_adv_by_std_in_grpo: bool = True
use_kl_in_reward: bool = False
kl_penalty: str = "kl"
kl_ctrl: KL_Ctrl = field(default_factory=KL_Ctrl)
@dataclass
class Trainer:
balance_batch: bool = True
total_epochs: int = 30
total_training_steps: Optional[
int
] = None # ! DO NOT SET, use trainer.total_steps in global_config
project_name: str = ""
group_name: str = ""
experiment_name: str = ""
logger: List[str] = field(default_factory=list)
val_generations_to_log_to_wandb: int = 0
nnodes: int = 0
n_gpus_per_node: int = 0
save_freq: int = 0
resume_mode: str = "auto"
resume_from_path: str = ""
test_freq: int = 0
critic_warmup: int = 0
default_hdfs_dir: Optional[str] = None
remove_previous_ckpt_in_save: bool = False # deprecated
del_local_ckpt_after_load: bool = False
default_local_dir: str = ""
val_before_train: bool = False
training_rollout_mode: str = "parallel"
enable_exp_buffer: bool = True
sync_freq: int = 0
max_actor_ckpt_to_keep: Optional[int] = None
max_critic_ckpt_to_keep: Optional[int] = None
device: str = "cuda" # default to cuda
@dataclass
class veRLConfig:
data: Data = field(default_factory=Data)
actor_rollout_ref: ActorRolloutRef = field(default_factory=ActorRolloutRef)
critic: Critic = field(default_factory=Critic)
reward_model: RewardModel = field(default_factory=RewardModel)
custom_reward_function: CustomRewardFunction = field(default_factory=CustomRewardFunction)
algorithm: Algorithm = field(default_factory=Algorithm)
trainer: Trainer = field(default_factory=Trainer)
global_profiler: dict = field(default_factory=dict)
synchronizer: Optional[SynchronizerConfig] = None
enable_preview: bool = True
@kimi_vl_monkey_patch_decorator
def _check_parallel_config(
self,
obj: Union[Actor, Ref, Critic],
component_name: str,
model_config: Union[ActorModel, CriticModel],
fsdp_config: FSDPConfig,
world_size: int,
sp_attr: str = "ulysses_sequence_parallel_size",
) -> None:
if obj.strategy.startswith("fsdp"):
# check sequence parallelism
sp_size = getattr(obj, sp_attr, 1)
if world_size % sp_size != 0:
raise ValueError(
f"The number of trainer GPUs ({world_size}) must be "
f"divisible by `{component_name}.{sp_attr}` ({sp_size}). "
f"Please change `trainer.{sp_attr}` or "
f"`{component_name}.{sp_attr}` in verl config to a reasonable value."
)
try:
import transformers
hf_config = transformers.AutoConfig.from_pretrained(
model_config.path, trust_remote_code=model_config.trust_remote_code
)
num_attention_heads = hf_config.num_attention_heads
except Exception:
num_attention_heads = None
if num_attention_heads and num_attention_heads % sp_size != 0:
                raise ValueError(
                    f"The number of attention heads ({num_attention_heads}) must be "
                    f"divisible by `{component_name}.{sp_attr}` ({sp_size}). "
                    f"Please change `trainer.{sp_attr}` or "
                    f"`{component_name}.{sp_attr}` in verl config to a reasonable value."
                )
fsdp_size = fsdp_config.fsdp_size
if fsdp_size <= 0 or fsdp_size >= world_size:
fsdp_size = fsdp_config.fsdp_size = world_size
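            # e.g. (hypothetical sizes): with world_size = 8, fsdp_size in {-1, 8, 16}
            # collapses to a single FSDP group of 8 ranks, while fsdp_size = 4 passes
            # the divisibility check below and shards across two groups of 4.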
if world_size % fsdp_size != 0:
raise ValueError(
f"The number of GPUs ({world_size}) must be "
f"divisible by `{component_name}.fsdp_config.fsdp_size` ({fsdp_size}). "
f"Please change `{component_name}.fsdp_config.fsdp_size` in verl config "
f"to a reasonable value."
)
else:
# TODO: add check for megatron strategy
pass
def _adjust_token_len_if_needed(
self,
obj,
config: Config,
component_name: str,
token_len_attr: str = "ppo_max_token_len_per_gpu",
sp_attr: str = "ulysses_sequence_parallel_size",
) -> None:
"""
        Helper to adjust the token length per GPU if the current setting is too small.
Ensures: token_len * seq_parallel >= config.model.max_model_len
"""
current_token_len = getattr(obj, token_len_attr)
seq_parallel = getattr(obj, sp_attr)
required_min = config.model.max_model_len # type: ignore
if current_token_len * seq_parallel < required_min:
new_token_len = math.ceil(required_min / seq_parallel)
setattr(obj, token_len_attr, new_token_len)
logger.warning(
f"{component_name}.{token_len_attr} is automatically set to {new_token_len} "
f"to match model.max_model_len ({config.model.max_model_len}). If you face OOM issues, "
"please set `model.max_model_len` to a smaller value."
)
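    # Worked example (hypothetical numbers): with model.max_model_len = 32768 and
    # ulysses_sequence_parallel_size = 4, a ppo_max_token_len_per_gpu of 4096 fails
    # 4096 * 4 >= 32768 and is raised to ceil(32768 / 4) = 8192.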
def synchronize_config(self, config: Config) -> None: # noqa: C901
"""Synchronize config."""
# Trainer Config
self.trainer.nnodes = config.cluster.trainer_node_num
self.trainer.n_gpus_per_node = config.cluster.trainer_gpu_num_per_node
self.trainer.total_training_steps = config.trainer.total_steps or sys.maxsize
self.trainer.sync_freq = config.synchronizer.sync_interval
self.trainer.save_freq = config.trainer.save_interval
self.trainer.project_name = config.project
self.trainer.group_name = config.group
self.trainer.experiment_name = config.name
self.trainer.default_local_dir = config.checkpoint_job_dir
if config.trainer.max_checkpoints_to_keep is not None:
self.trainer.max_actor_ckpt_to_keep = config.trainer.max_checkpoints_to_keep
self.trainer.max_critic_ckpt_to_keep = config.trainer.max_checkpoints_to_keep
if not config.continue_from_checkpoint:
self.trainer.resume_mode = "disable"
else:
self.trainer.resume_mode = "auto"
# kept to pass RayPPOTrainer._validate_config
self.data.train_batch_size = config.buffer.train_batch_size
self.data.trust_remote_code = config.model.trust_remote_code
self.synchronizer = config.synchronizer
self.actor_rollout_ref.nccl_timeout = config.synchronizer.sync_timeout
self.actor_rollout_ref.synchronizer = config.synchronizer
self.actor_rollout_ref.explorer_name = config.explorer.name
algorithm = ALGORITHM_TYPE.get(config.algorithm.algorithm_type)
self.critic.enable = algorithm.use_critic
self.critic.nccl_timeout = config.synchronizer.sync_timeout
self.critic.ray_namespace = config.synchronizer.ray_namespace
# Actor / Rollout Config
actor_config = self.actor_rollout_ref.actor
rollout_config = self.actor_rollout_ref.rollout
actor_model_config = self.actor_rollout_ref.model
actor_optim = actor_config.optim
actor_model_config.path = config.model.model_path
for attr in ["trust_remote_code", "custom_chat_template", "rope_scaling", "rope_theta"]:
setattr(actor_model_config, attr, getattr(config.model, attr))
actor_optim.total_training_steps = self.trainer.total_training_steps
actor_config.ppo_mini_batch_size = config.buffer.train_batch_size
rollout_config.temperature = (
config.buffer.explorer_input.tasksets[0].rollout_args.temperature
if config.buffer.explorer_input.tasksets
else 1.0
)
rollout_config.n = config.algorithm.repeat_times
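        # Each pair below maps an actor field to the trainer field that seeds it when
        # the actor field is None; ("name",) * 2 is shorthand for the same name on
        # both sides.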
for actor_attr, trainer_attr in [
("grad_clip",) * 2,
("use_dynamic_bsz",) * 2,
("fix_actor_microbatch_loss_scale",) * 2,
("ulysses_sequence_parallel_size",) * 2,
("ppo_max_token_len_per_gpu", "max_token_len_per_gpu"),
("strategy", "trainer_strategy"),
]:
set_if_none(actor_config, actor_attr, getattr(config.trainer, trainer_attr))
self._check_parallel_config(
obj=actor_config,
component_name="actor",
model_config=actor_model_config,
fsdp_config=actor_config.fsdp_config,
world_size=config.cluster.trainer_gpu_num,
)
self._adjust_token_len_if_needed(
obj=self.actor_rollout_ref.actor,
config=config,
component_name="actor",
)
# Ref Config
ref_config = self.actor_rollout_ref.ref
for ref_attr, trainer_attr in [
("log_prob_use_dynamic_bsz", "use_dynamic_bsz"),
("log_prob_max_token_len_per_gpu", "max_token_len_per_gpu"),
("ulysses_sequence_parallel_size",) * 2,
("strategy", "trainer_strategy"),
]:
set_if_none(ref_config, ref_attr, getattr(config.trainer, trainer_attr))
self._check_parallel_config(
obj=ref_config,
component_name="ref",
model_config=actor_model_config,
fsdp_config=ref_config.fsdp_config,
world_size=config.cluster.trainer_gpu_num,
)
self._adjust_token_len_if_needed(
obj=self.actor_rollout_ref.ref,
config=config,
component_name="ref",
token_len_attr="log_prob_max_token_len_per_gpu",
)
# Critic config
critic_optim = self.critic.optim
self.critic.model.path = config.model.critic_model_path
self.critic.model.tokenizer_path = config.model.critic_model_path
self.critic.model.rope_scaling = config.model.rope_scaling
self.critic.model.rope_theta = config.model.rope_theta
self.critic.ppo_mini_batch_size = config.buffer.train_batch_size
self.critic.rollout_n = config.algorithm.repeat_times
critic_optim.total_training_steps = self.trainer.total_training_steps
for critic_attr, trainer_attr in [
("grad_clip",) * 2,
("use_dynamic_bsz",) * 2,
("ulysses_sequence_parallel_size",) * 2,
("strategy", "trainer_strategy"),
("ppo_max_token_len_per_gpu", "max_token_len_per_gpu"),
]:
set_if_none(self.critic, critic_attr, getattr(config.trainer, trainer_attr))
self._check_parallel_config(
obj=self.critic,
component_name="critic",
model_config=self.critic.model,
fsdp_config=self.critic.model.fsdp_config,
world_size=config.cluster.trainer_gpu_num,
)
self._adjust_token_len_if_needed(
obj=self.critic,
config=config,
component_name="critic",
)
set_if_none(
self.critic, "forward_max_token_len_per_gpu", self.critic.ppo_max_token_len_per_gpu
)
# LoRA related config
if config.model.lora_configs is not None:
lora_config = config.model.lora_configs[0]
for attr in ["lora_rank", "lora_alpha", "target_modules", "exclude_modules"]:
setattr(actor_model_config, attr, getattr(lora_config, attr))
if not lora_config.is_dummy:
actor_model_config.lora_adapter_path = lora_config.path
if actor_config.strategy not in ["fsdp", "fsdp2"]:
logger.warning(
f"Lora is only supported for fsdp and fsdp2, but got {actor_config.strategy} instead, changed to fsdp."
)
actor_config.strategy = "fsdp"
if self.critic.strategy not in ["fsdp", "fsdp2"]:
logger.warning(
f"Lora is only supported for fsdp and fsdp2, but got {self.critic.strategy} instead, changed to fsdp."
)
self.critic.strategy = "fsdp"
# Algorithm related config
optim_config = config.algorithm.optimizer
for field_name in optim_config.__dataclass_fields__:
field_value = getattr(optim_config, field_name)
            if field_name == "optimizer_type":
                actor_optim.optimizer = field_value
elif hasattr(actor_optim, field_name):
setattr(actor_optim, field_name, field_value)
# ensure megatron optimizer config compatibility
set_if_none(actor_optim, "lr_warmup_init", optim_config.min_lr_ratio * optim_config.lr)
set_if_none(actor_optim, "lr_decay_steps", self.trainer.total_training_steps)
set_if_none(actor_optim, "lr_decay_style", optim_config.lr_scheduler_type)
set_if_none(actor_optim, "min_lr", optim_config.min_lr_ratio * optim_config.lr)
set_if_none(critic_optim, "lr_warmup_init", 0.0)
set_if_none(critic_optim, "lr_decay_steps", self.trainer.total_training_steps)
set_if_none(critic_optim, "lr_decay_style", "constant")
set_if_none(critic_optim, "min_lr", 0.0)
# fix optimizer type for fsdp
if config.trainer.trainer_strategy.startswith("fsdp"):
optim_map = {
"adam": "AdamW",
"adamw": "AdamW",
"sgd": "SGD",
}
actor_optim.optimizer = optim_map.get(actor_optim.optimizer, actor_optim.optimizer)
critic_optim.optimizer = optim_map.get(critic_optim.optimizer, critic_optim.optimizer)
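        # e.g. a megatron-style "adam" becomes torch.optim's "AdamW" under fsdp;
        # names not in the map pass through unchanged.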
actor_config.use_kl_loss = config.algorithm.kl_loss_fn != "none"
self.algorithm.use_kl_in_reward = config.algorithm.kl_penalty_fn != "none"
# TODO (yanxi): it seems that adv_estimator now only affects whether use_critic is set to
# True or False in RayPPOTrainer.__init__() (and hence in VerlPPOTrainerWrapper).
# Need to double check whether this is indeed the case,
# and see if adv_estimator can be removed completely.
if config.algorithm.algorithm_type == "dpo": # for DPO
logger.warning("DPO micro batch size is doubled for computing loss.")
actor_config.ppo_micro_batch_size_per_gpu *= 2
ref_config.log_prob_micro_batch_size_per_gpu *= 2
        # Backfill rollout log-prob settings from the actor when unset (currently only relevant for LoRA)
for rollout_attr, actor_attr in [
("log_prob_use_dynamic_bsz", "use_dynamic_bsz"),
("log_prob_micro_batch_size", "ppo_micro_batch_size"),
("log_prob_micro_batch_size_per_gpu", "ppo_micro_batch_size_per_gpu"),
("log_prob_max_token_len_per_gpu", "ppo_max_token_len_per_gpu"),
]:
set_if_none(rollout_config, rollout_attr, getattr(actor_config, actor_attr))
# TODO: check other fields
self.enable_preview = config.trainer.enable_preview
def load_config(config_path: str) -> veRLConfig:
schema = OmegaConf.structured(veRLConfig)
yaml_config = OmegaConf.load(config_path)
try:
config = OmegaConf.merge(schema, yaml_config)
return OmegaConf.to_object(config)
except Exception as e:
raise ValueError(f"Invalid configuration: {e}") from e
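

# Minimal usage sketch (hypothetical path): `load_config` merges the YAML file onto
# the structured veRLConfig schema, so unknown keys or mismatched types raise a
# ValueError.
#
#   verl_config = load_config("examples/my_verl_config.yaml")
#   verl_config.synchronize_config(trinity_config)  # trinity_config: trinity Config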