Commit c654d1a

merge reshard-1 (#10606)
1 parent ddcb722 commit c654d1a

File tree

6 files changed: +212 additions, -174 deletions


llm/alignment/rl/run_rl.py

Lines changed: 22 additions & 14 deletions
@@ -32,7 +32,7 @@
     TrainingArguments,
 )
 from paddlenlp.rl.utils.offload_utils import offload_tensor_to_cpu
-from paddlenlp.rl.utils.reshard_utils import init_rollout_env
+from paddlenlp.rl.utils.reshard_utils import ReshardController
 from paddlenlp.rl.utils.timer_utils import timers_scope_runtimer
 from paddlenlp.trainer import (
     EarlyStoppingCallback,
@@ -81,6 +81,7 @@ def create_actor_models(
     data_args: DataArgument,
     training_args: TrainingArguments,
     common_config: Dict,
+    reshard_controller: ReshardController = None,
 ):
     with timers_scope_runtimer("Actor model loading time"):
         # actor model
@@ -103,7 +104,7 @@ def create_actor_models(
     actor_model_config.use_sparse_head_and_loss_fn = False
     actor_model_config.seq_length = data_args.max_length
     actor_model_config.max_sequence_length = data_args.max_length
-    print(f"Loading Actor model with config:\n\t{actor_model_config}\n")
+    logger.info(f"Loading Actor model with config:\n\t{actor_model_config}\n")
 
     if not training_args.autotuner_benchmark:
         actor_model = AutoModelForCausalLM.from_pretrained(
@@ -113,18 +114,16 @@ def create_actor_models(
         actor_model = AutoModelForCausalLM.from_config(actor_model_config)
 
     with timers_scope_runtimer("Actor eval model loading time"):
-        if (
-            training_args.rollout_tensor_parallel_degree != training_args.tensor_parallel_degree
-            or training_args.pipeline_parallel_degree > 1
-        ):
+        if reshard_controller is not None:
+            reshard_controller.set_rollout_env("[create actor eval model]")
             actor_eval_model_config = copy.deepcopy(actor_model_config)
             actor_eval_model_config.use_fused_head_and_loss_fn = False
-            with init_rollout_env(training_args.rollout_tensor_parallel_degree):
-                hcg = fleet.get_hybrid_communicate_group()
-                actor_eval_model_config.tensor_parallel_degree = hcg.get_model_parallel_world_size()
-                actor_eval_model_config.tensor_parallel_rank = hcg.get_model_parallel_rank()
-                # TODO(gongenlei): lazy load lazy guard
-                actor_eval_model = AutoModelForCausalLM.from_config(actor_eval_model_config)
+            hcg = fleet.get_hybrid_communicate_group()
+            actor_eval_model_config.tensor_parallel_degree = hcg.get_model_parallel_world_size()
+            actor_eval_model_config.tensor_parallel_rank = hcg.get_model_parallel_rank()
+            # TODO(gongenlei): lazy load lazy guard
+            actor_eval_model = AutoModelForCausalLM.from_config(actor_eval_model_config)
+            reshard_controller.set_train_env("[after create actor eval model]")
         else:
             actor_eval_model = None
 
@@ -171,7 +170,7 @@ def create_reward_models(
     LlmMetaConfig.set_llm_config(reward_model_config, training_args)
     reward_model_config.max_position_embeddings = data_args.max_length
     reward_model_config.use_sparse_head_and_loss_fn = False
-    print(f"Loading Reward model with config:\n\t{reward_model_config}\n")
+    logger.info(f"Loading Reward model with config:\n\t{reward_model_config}\n")
 
     config = copy.deepcopy(reward_model_config)
     if training_args.eval_mode is not None:
@@ -323,8 +322,16 @@ def main():
         max_sequence_length=data_args.max_length,
     )
 
+    if (
+        training_args.rollout_tensor_parallel_degree != training_args.tensor_parallel_degree
+        or training_args.pipeline_parallel_degree > 1
+    ):
+        reshard_controller = ReshardController(tensor_parallel_degree=training_args.rollout_tensor_parallel_degree)
+    else:
+        reshard_controller = None
+
     actor_model, actor_eval_model, reference_model, actor_tokenizer = create_actor_models(
-        model_args, data_args, training_args, common_config
+        model_args, data_args, training_args, common_config, reshard_controller
     )
 
     if not training_args.use_rm_server and model_args.reward_model_name_or_path is not None:
@@ -387,6 +394,7 @@ def compute_metrics(eval_preds):
         ),  # NOTE: enforce prompt padding to max_prompt_len when using balance_batch
         compute_metrics=compute_metrics,  # TODO: only used for grpo (kk datasets)
         generation_config=generation_config,
+        reshard_controller=reshard_controller,
     )
 
     # TODO(gongenlei) resume_from_checkpoint is not ready
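
The run_rl.py changes above swap the old init_rollout_env context manager for a ReshardController object that is constructed once in main() (only when the rollout tensor-parallel degree differs from the training one, or pipeline parallelism is enabled) and passed down to the trainers. Below is a minimal, runnable sketch of that usage pattern; ReshardControllerSketch and the example arguments are hypothetical stand-ins, and only the constructor argument and the set_rollout_env/set_train_env method names are taken from the diff.

# Minimal sketch of the pattern added in this commit. ReshardControllerSketch is
# a hypothetical stand-in; the real controller lives in
# paddlenlp.rl.utils.reshard_utils and manages Paddle's hybrid communicate groups.
from types import SimpleNamespace


class ReshardControllerSketch:
    def __init__(self, tensor_parallel_degree: int):
        self.tensor_parallel_degree = tensor_parallel_degree

    def set_rollout_env(self, tag: str = "") -> None:
        # The real code would switch this process to the rollout parallel topology.
        print(f"enter rollout env (tp={self.tensor_parallel_degree}) {tag}")

    def set_train_env(self, tag: str = "") -> None:
        # The real code would restore the training parallel topology.
        print(f"restore train env {tag}")


# Assumed example arguments; mirrors the construction added in main().
training_args = SimpleNamespace(
    rollout_tensor_parallel_degree=2,
    tensor_parallel_degree=4,
    pipeline_parallel_degree=1,
)

if (
    training_args.rollout_tensor_parallel_degree != training_args.tensor_parallel_degree
    or training_args.pipeline_parallel_degree > 1
):
    reshard_controller = ReshardControllerSketch(
        tensor_parallel_degree=training_args.rollout_tensor_parallel_degree
    )
else:
    reshard_controller = None

if reshard_controller is not None:
    reshard_controller.set_rollout_env("[create actor eval model]")
    # ... build the eval model against the rollout topology ...
    reshard_controller.set_train_env("[after create actor eval model]")

Using an object rather than a with-block lets the rollout environment be entered in one place and restored in another (as the eval-model creation here and the post-rollout hooks below do), which the nested context-manager form could not express as directly.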

paddlenlp/rl/trainer/ppo_trainer.py

Lines changed: 15 additions & 4 deletions
@@ -81,6 +81,7 @@
 )
 from ..utils.infer_utils import infer_guard
 from ..utils.offload_utils import reload_and_offload_scope, reload_tensor_to_gpu
+from ..utils.reshard_utils import ReshardController
 from ..utils.timer_utils import TimerScope, TimerScopeManualLabel
 from .actor_trainer import ActorReferenceTrainer
 from .critic_trainer import CriticTrainer
@@ -232,6 +233,7 @@ def __init__(
         optimizers: Tuple[paddle.optimizer.Optimizer, paddle.optimizer.lr.LRScheduler] = (None, None),
         preprocess_logits_for_metrics: Optional[Callable[[paddle.Tensor, paddle.Tensor], paddle.Tensor]] = None,
         generation_config: Optional[GenerationConfig] = None,
+        reshard_controller: Optional[ReshardController] = None,
     ):
         """
         Args:
@@ -282,6 +284,7 @@ def __init__(
             preprocess_logits_for_metrics,
         )
 
+        self.reshard_controller = reshard_controller
         trainer_agrs = {
             # "model": None,
             "criterion": criterion,
@@ -300,6 +303,7 @@ def __init__(
             model=actor_model,
             model_eval=actor_model_eval,
             tokenizer=actor_tokenizer,
+            reshard_controller=reshard_controller,
             **trainer_agrs,
         )
 
@@ -379,6 +383,7 @@ def create_actor_trainer(
         callbacks: Optional[List[TrainerCallback]] = None,
         optimizers: Tuple[paddle.optimizer.Optimizer, paddle.optimizer.lr.LRScheduler] = (None, None),
         preprocess_logits_for_metrics: Optional[Callable[[paddle.Tensor, paddle.Tensor], paddle.Tensor]] = None,
+        reshard_controller: Optional[ReshardController] = None,
     ):
         policy_training_args = copy.deepcopy(args)
         lr_scheduler = self.get_scheduler(policy_training_args)
@@ -394,6 +399,7 @@ def create_actor_trainer(
             callbacks,
             [None, lr_scheduler],
             preprocess_logits_for_metrics,
+            reshard_controller,
         )
         actor_trainer.set_eval_model(model_eval)
         actor_trainer.timers = self.timers
@@ -688,6 +694,8 @@ def prediction_step(
         }
         generated_seq = self.actor_trainer.generate_sequences(prompt_only_batch, do_eval=True)[0]["input_ids"]
 
+        if self.reshard_controller is not None:
+            self.reshard_controller.set_train_env("[after prediction_step]")
         if not self.args.use_rm_server:
             if self._model_config.sequence_parallel:
                 # pad to max_sequence_length
@@ -1386,7 +1394,6 @@ def train(
             self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
             # step 1-1: rollout data with actor model (eval) and reward model
             self.set_eval()
-
             data_trans_group = getattr(self.actor_trainer, "_data_trans_group", None)
             prompt_only_batch = data_group_split(prompt_only_batch, group=data_trans_group)
 
@@ -1415,6 +1422,7 @@ def train(
                 RolloutStages.ACTOR_MODEL_ENABLE_DISABLE,
                 minus_names=[RolloutStages.GENERATE],
             )
+
             timer_scope_actor_model.start()
             with reload_and_offload_scope(self, self.actor_model):
                 timer_scope_rollout = TimerScope(self.timers, RolloutStages.GENERATE)
@@ -1438,6 +1446,8 @@ def train(
                 self.timers and (dist.get_world_size() > 1) and dist.barrier()
                 timer_scope_rollout.stop()
             timer_scope_actor_model.stop()
+            if self.reshard_controller is not None:
+                self.reshard_controller.set_train_env("[after rollout]")
 
             # step 2-1: truncate data
             truncate_input_ids = [
@@ -1469,19 +1479,22 @@ def train(
                 ),
             }
 
+            batch = data_group_merge(batch, group=data_trans_group)
+
             # step 2-2: balance batches based on batch tokens
             if self.args.balance_batch:
                 batch = self._balance_batch(batch)
 
+            # step 2-3: compute logprob for rollout data
             with self.autocast_smart_context_manager():
-                # step 2-3: compute logprob for rollout data
                 with TimerScope(self.timers, RolloutStages.ROLLOUT_LOGPROB):
                     with reload_and_offload_scope(self, self.reference_model):
                         with TimerScope(self.timers, RolloutStages.ROLLOUT_REF_LOGPROB):
                             batch["ref_log_probs"] = self.reference_trainer.compute_logprob(**batch)
 
                     with reload_and_offload_scope(self, self.actor_model):
                         with TimerScope(self.timers, RolloutStages.ROLLOUT_OLD_LOGPROB):
+                            self.actor_trainer.model.eval()
                             batch["log_probs"] = self.actor_trainer.compute_logprob(**batch)
 
             # step 2-2: compute reward for rollout data
@@ -1629,8 +1642,6 @@ def train(
             else:
                 batch = batch
 
-            batch = data_group_merge(batch, group=data_trans_group)
-
             # step 3: train actor model and critic model with rollout data
             self.set_train()
             with TimerScope(self.timers, ActorStages.MODEL_ENABLE_DISABLE, minus_names=[ActorStages.RL_STEP]):
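
In ppo_trainer.py the controller is toggled with explicit set_rollout_env/set_train_env pairs (after rollout, after prediction_step). If an exception-safe variant were wanted, the pair could be wrapped in a small context manager; the helper below is purely illustrative and is not part of this commit or of PaddleNLP.

from contextlib import contextmanager


@contextmanager
def rollout_env(controller, tag=""):
    # Hypothetical wrapper around the set_rollout_env/set_train_env pairs used
    # in this commit; a None controller makes the whole block a no-op.
    if controller is not None:
        controller.set_rollout_env(tag)
    try:
        yield
    finally:
        if controller is not None:
            controller.set_train_env(f"[after {tag}]")


# Sketch of how it could be used inside a trainer:
#     with rollout_env(self.reshard_controller, "rollout"):
#         ...generate sequences with the eval model...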

paddlenlp/rl/trainer/rl_trainer.py

Lines changed: 12 additions & 7 deletions
@@ -49,7 +49,7 @@
 from ...utils.env import TRAINER_STATE_NAME
 from ..models.ppo_model_utils import create_loss
 from ..utils.comm_utils import create_data_trans_group
-from ..utils.reshard_utils import init_rollout_env
+from ..utils.reshard_utils import ReshardController
 
 # ########## patches for Trianer ##########
 
@@ -537,6 +537,7 @@ def __init__(
         callbacks: Optional[List[TrainerCallback]] = None,
         optimizers: Tuple[paddle.optimizer.Optimizer, paddle.optimizer.lr.LRScheduler] = (None, None),
         preprocess_logits_for_metrics: Optional[Callable[[paddle.Tensor, paddle.Tensor], paddle.Tensor]] = None,
+        reshard_controller: Optional[ReshardController] = None,
     ):
         super().__init__(
             model,
@@ -565,6 +566,7 @@ def __init__(
         self.ema_beta = getattr(args, "ema_beta", 0.992)
         # if self.timers:
         #     self.timers.log = types.MethodType(new_timer_log, self.timers)
+        self.reshard_controller = reshard_controller
 
     def create_criterion(self):
         """
@@ -595,12 +597,15 @@ def set_eval_model(self, model):
         dp_group = hcg.get_data_parallel_group()
         global_rank = dist.get_rank()
         old_dp_workers = self.args.world_size // (max(sd_group.nranks, 1) * max(dp_group.nranks, 1))
-        with init_rollout_env(self.args.rollout_tensor_parallel_degree):
-            hcg = fleet.get_hybrid_communicate_group()
-            tensor_parallel_degree = hcg.get_model_parallel_world_size()
-            tensor_parallel_rank = hcg.get_model_parallel_rank()
-            eval_tp_size = max(tensor_parallel_degree, 1)
-            eval_tp_rank = max(tensor_parallel_rank, 0)
+        if self.reshard_controller is not None:
+            self.reshard_controller.set_rollout_env("[set eval model]")
+        hcg = fleet.get_hybrid_communicate_group()
+        tensor_parallel_degree = hcg.get_model_parallel_world_size()
+        tensor_parallel_rank = hcg.get_model_parallel_rank()
+        if self.reshard_controller is not None:
+            self.reshard_controller.set_train_env("[after set eval model]")
+        eval_tp_size = max(tensor_parallel_degree, 1)
+        eval_tp_rank = max(tensor_parallel_rank, 0)
         group_nums = self.args.logical_process_index // old_dp_workers * eval_tp_size + eval_tp_rank
         self._data_trans_group = create_data_trans_group(global_rank, group_nums)
         # just for compatible with old code

paddlenlp/rl/utils/comm_utils.py

Lines changed: 13 additions & 10 deletions
@@ -25,7 +25,7 @@
 from ...trainer.trainer import Trainer, logger
 from ...utils.nested import flatten_list, nested_broadcast_tensor_with_empty
 from ..models.ppo_model_utils import make_position_ids_from_input_ids
-from .reshard_utils import init_reshard_mappings, init_rollout_env, reshard_to_rollout
+from .reshard_utils import init_reshard_mappings, reshard_to_rollout
 
 global_dev_id = 0 if paddle.get_device() == "cpu" else int(paddle.get_device().split(":")[1])
 
@@ -622,15 +622,18 @@ def export_evaluate_model(self: Trainer, train_model, eval_model, **kwargs):
     if not hasattr(self, "global_meta_dict") or self.global_meta_dict is None:
         self.global_meta_dict = init_reshard_mappings(train_model, self.args, pp_rank, pp_group)
 
-    with init_rollout_env(self.args.rollout_tensor_parallel_degree):
-        hcg = fleet.get_hybrid_communicate_group()
-        tensor_parallel_degree = hcg.get_model_parallel_world_size()
-        tensor_parallel_rank = hcg.get_model_parallel_rank()
-        eval_tp_size = max(tensor_parallel_degree, 1)
-        eval_tp_rank = max(tensor_parallel_rank, 0)
-        reshard_to_rollout(
-            train_model, eval_model, self.global_meta_dict, pp_rank, pp_group, hcg.get_model_parallel_group(), tp_group
-        )
+    if getattr(self, "reshard_controller", None) is not None:
+        self.reshard_controller.set_rollout_env("[export_evaluate_model]")
+    hcg = fleet.get_hybrid_communicate_group()
+    tensor_parallel_degree = hcg.get_model_parallel_world_size()
+    tensor_parallel_rank = hcg.get_model_parallel_rank()
+    eval_tp_size = max(tensor_parallel_degree, 1)
+    eval_tp_rank = max(tensor_parallel_rank, 0)
+    reshard_to_rollout(
+        train_model, eval_model, self.global_meta_dict, pp_rank, pp_group, hcg.get_model_parallel_group(), tp_group
+    )
+    if getattr(self, "reshard_controller", None) is not None:
+        self.reshard_controller.set_train_env("[after export_evaluate_model]")
 
     old_dp_workers = self.args.world_size // (max(sd_group.nranks, 1) * max(dp_group.nranks, 1))
     group_nums = self.args.logical_process_index // old_dp_workers * eval_tp_size + eval_tp_rank
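
Both set_eval_model (rl_trainer.py) and export_evaluate_model (comm_utils.py) derive the data-trans group index from the tensor-parallel size and rank reported while the rollout environment is active. A small worked example of that arithmetic follows; every number in it is assumed for illustration and does not come from the commit.

# Worked example of the group_nums formula above; all values are assumed.
world_size = 16
sd_nranks, dp_nranks = 1, 2          # sharding / data-parallel group sizes
old_dp_workers = world_size // (max(sd_nranks, 1) * max(dp_nranks, 1))  # -> 8

eval_tp_size = max(2, 1)             # rollout tensor-parallel degree
eval_tp_rank = max(1, 0)             # this rank's position in the rollout TP group
logical_process_index = 9

group_nums = logical_process_index // old_dp_workers * eval_tp_size + eval_tp_rank
print(group_nums)  # -> 3; ranks computing the same value are presumably grouped
                   #    together by create_data_trans_group for data exchange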
