 )
 from ..utils.infer_utils import infer_guard
 from ..utils.offload_utils import reload_and_offload_scope, reload_tensor_to_gpu
+from ..utils.reshard_utils import ReshardController
 from ..utils.timer_utils import TimerScope, TimerScopeManualLabel
 from .actor_trainer import ActorReferenceTrainer
 from .critic_trainer import CriticTrainer
@@ -232,6 +233,7 @@ def __init__(
         optimizers: Tuple[paddle.optimizer.Optimizer, paddle.optimizer.lr.LRScheduler] = (None, None),
         preprocess_logits_for_metrics: Optional[Callable[[paddle.Tensor, paddle.Tensor], paddle.Tensor]] = None,
         generation_config: Optional[GenerationConfig] = None,
+        reshard_controller: Optional[ReshardController] = None,
     ):
         """
         Args:
@@ -282,6 +284,7 @@ def __init__(
             preprocess_logits_for_metrics,
         )

+        self.reshard_controller = reshard_controller
         trainer_agrs = {
             # "model": None,
             "criterion": criterion,
@@ -300,6 +303,7 @@ def __init__(
             model=actor_model,
             model_eval=actor_model_eval,
             tokenizer=actor_tokenizer,
+            reshard_controller=reshard_controller,
             **trainer_agrs,
         )

@@ -379,6 +383,7 @@ def create_actor_trainer(
         callbacks: Optional[List[TrainerCallback]] = None,
         optimizers: Tuple[paddle.optimizer.Optimizer, paddle.optimizer.lr.LRScheduler] = (None, None),
         preprocess_logits_for_metrics: Optional[Callable[[paddle.Tensor, paddle.Tensor], paddle.Tensor]] = None,
+        reshard_controller: Optional[ReshardController] = None,
     ):
         policy_training_args = copy.deepcopy(args)
         lr_scheduler = self.get_scheduler(policy_training_args)
@@ -394,6 +399,7 @@ def create_actor_trainer(
             callbacks,
             [None, lr_scheduler],
             preprocess_logits_for_metrics,
+            reshard_controller,
         )
         actor_trainer.set_eval_model(model_eval)
         actor_trainer.timers = self.timers
@@ -688,6 +694,8 @@ def prediction_step(
         }
         generated_seq = self.actor_trainer.generate_sequences(prompt_only_batch, do_eval=True)[0]["input_ids"]

+        if self.reshard_controller is not None:
+            self.reshard_controller.set_train_env("[after prediction_step]")
         if not self.args.use_rm_server:
             if self._model_config.sequence_parallel:
                 # pad to max_sequence_length
@@ -1386,7 +1394,6 @@ def train(
                 self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
                 # step 1-1: rollout data with actor model (eval) and reward model
                 self.set_eval()
-
                 data_trans_group = getattr(self.actor_trainer, "_data_trans_group", None)
                 prompt_only_batch = data_group_split(prompt_only_batch, group=data_trans_group)

@@ -1415,6 +1422,7 @@ def train(
                     RolloutStages.ACTOR_MODEL_ENABLE_DISABLE,
                     minus_names=[RolloutStages.GENERATE],
                 )
+
                 timer_scope_actor_model.start()
                 with reload_and_offload_scope(self, self.actor_model):
                     timer_scope_rollout = TimerScope(self.timers, RolloutStages.GENERATE)
@@ -1438,6 +1446,8 @@ def train(
                     self.timers and (dist.get_world_size() > 1) and dist.barrier()
                     timer_scope_rollout.stop()
                 timer_scope_actor_model.stop()
+                if self.reshard_controller is not None:
+                    self.reshard_controller.set_train_env("[after rollout]")

                 # step 2-1: truncate data
                 truncate_input_ids = [
@@ -1469,19 +1479,22 @@ def train(
                     ),
                 }

+                batch = data_group_merge(batch, group=data_trans_group)
+
                 # step 2-2: balance batches based on batch tokens
                 if self.args.balance_batch:
                     batch = self._balance_batch(batch)

+                # step 2-3: compute logprob for rollout data
                 with self.autocast_smart_context_manager():
-                    # step 2-3: compute logprob for rollout data
                     with TimerScope(self.timers, RolloutStages.ROLLOUT_LOGPROB):
                         with reload_and_offload_scope(self, self.reference_model):
                             with TimerScope(self.timers, RolloutStages.ROLLOUT_REF_LOGPROB):
                                 batch["ref_log_probs"] = self.reference_trainer.compute_logprob(**batch)

                         with reload_and_offload_scope(self, self.actor_model):
                             with TimerScope(self.timers, RolloutStages.ROLLOUT_OLD_LOGPROB):
+                                self.actor_trainer.model.eval()
                                 batch["log_probs"] = self.actor_trainer.compute_logprob(**batch)

                 # step 2-2: compute reward for rollout data
@@ -1629,8 +1642,6 @@ def train(
                 else:
                     batch = batch

-                batch = data_group_merge(batch, group=data_trans_group)
-
                 # step 3: train actor model and critic model with rollout data
                 self.set_train()
                 with TimerScope(self.timers, ActorStages.MODEL_ENABLE_DISABLE, minus_names=[ActorStages.RL_STEP]):
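
Context for the hooks above (not part of the diff): the controller threaded through `__init__` and `create_actor_trainer` is expected to restore the training-time parameter layout after rollout generation, which is why `set_train_env(...)` is called right after generation in both `prediction_step` and `train`. Below is a minimal Python sketch of that call pattern only; the class body and the `set_rollout_env` counterpart are illustrative assumptions, not the real `ReshardController` from `..utils.reshard_utils`, of which this diff only shows `set_train_env(tag)`.

# Illustrative stand-in for ReshardController; its implementation is not
# shown in this diff, only the set_train_env(tag) calls added above.
class ToyReshardController:
    def __init__(self):
        self.layout = "train"

    def set_rollout_env(self, tag):
        # Hypothetical counterpart: switch parameters to the generation layout.
        print(f"{tag} -> rollout layout")
        self.layout = "rollout"

    def set_train_env(self, tag):
        # Mirrors the calls added in this diff: switch back to the training
        # layout before log-prob computation and optimizer steps.
        print(f"{tag} -> train layout")
        self.layout = "train"


controller = ToyReshardController()
controller.set_rollout_env("[before rollout]")
# ... generate_sequences(...) would run here ...
controller.set_train_env("[after rollout]")  # as in the diff above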