From 508857d179258cda7f5e46db125f9050e7b308be Mon Sep 17 00:00:00 2001 From: fujinji Date: Mon, 11 Aug 2025 10:44:04 +0800 Subject: [PATCH] optimize generate, model export; support hybridflow; support sp and fused_qkv --- llm/alignment/rl/run_rl.py | 15 +- paddlenlp/datasets/rlhf_datasets/protocol.py | 1 + .../transformers/qwen2/modeling.py | 1 + paddlenlp/rl/trainer/actor_trainer.py | 1 + paddlenlp/rl/trainer/ppo_trainer.py | 159 ++++- paddlenlp/rl/trainer/rl_trainer.py | 20 +- paddlenlp/rl/utils/comm_utils.py | 86 ++- paddlenlp/rl/utils/config_utils.py | 9 + paddlenlp/rl/utils/infer_utils.py | 18 + paddlenlp/rl/utils/offload_utils.py | 5 +- paddlenlp/rl/utils/reshard_utils.py | 271 ++++++-- paddlenlp/rl/utils/timer_utils.py | 1 + paddlenlp/trainer/training_args.py | 17 +- paddlenlp/trainer/unified_checkpoint/utils.py | 2 + paddlenlp/transformers/conversion_utils.py | 29 +- tests/transformers/test_refined_recompute.py | 628 ------------------ 16 files changed, 546 insertions(+), 717 deletions(-) delete mode 100644 tests/transformers/test_refined_recompute.py diff --git a/llm/alignment/rl/run_rl.py b/llm/alignment/rl/run_rl.py index 6983dd8dfd7b..f0d32cc13cf0 100644 --- a/llm/alignment/rl/run_rl.py +++ b/llm/alignment/rl/run_rl.py @@ -55,6 +55,7 @@ def process_args(model_args: ModelArgument, data_args: DataArgument, training_ar training_args.max_src_len = data_args.max_prompt_len training_args.actor_model_name_or_path = model_args.actor_model_name_or_path training_args.max_length = data_args.max_length + # training_args.hybrid_parallel_topo_order = "mp_first" if training_args.use_rm_server: if model_args.reward_server is None: @@ -319,11 +320,21 @@ def main(): max_sequence_length=data_args.max_length, ) + # import random + # random.seed(training_args.seed) + # paddle.seed(training_args.seed) + # print(f"Fu set seed:{training_args.seed}") + gather_in_micro_dp = training_args.gather_in_micro_dp # 修改入口 + use_export_only_rollout = training_args.use_export_only_rollout # 修改入口 + print(f"Fu gather_in_micro_dp:{gather_in_micro_dp}, use_export_only_rollout:{use_export_only_rollout}") if ( training_args.rollout_tensor_parallel_degree != training_args.tensor_parallel_degree or training_args.pipeline_parallel_degree > 1 ): - reshard_controller = ReshardController(tensor_parallel_degree=training_args.rollout_tensor_parallel_degree) + reshard_controller = ReshardController(train_tensor_parallel_degree=training_args.tensor_parallel_degree, + infer_tensor_parallel_degree=training_args.rollout_tensor_parallel_degree, + gather_in_micro_dp=gather_in_micro_dp + ) else: reshard_controller = None @@ -401,6 +412,8 @@ def compute_metrics(eval_preds): compute_metrics=compute_metrics, # TODO: only used for grpo (kk datasets) generation_config=generation_config, reshard_controller=reshard_controller, + gather_in_micro_dp=gather_in_micro_dp, + use_export_only_rollout=use_export_only_rollout, ) # TODO(gongenlei) resume_from_checkpoint is not ready diff --git a/paddlenlp/datasets/rlhf_datasets/protocol.py b/paddlenlp/datasets/rlhf_datasets/protocol.py index d7a6b0566bdb..fafe7908774e 100644 --- a/paddlenlp/datasets/rlhf_datasets/protocol.py +++ b/paddlenlp/datasets/rlhf_datasets/protocol.py @@ -1158,3 +1158,4 @@ def split_batch_into_micro_batches(self, batch_size, pad_token_id=0) -> List["Da micro_batches.append(DataProto.from_single_dict(micro_batch)) return micro_batches + diff --git a/paddlenlp/experimental/transformers/qwen2/modeling.py b/paddlenlp/experimental/transformers/qwen2/modeling.py index 
141bd4df1bc1..400f99c31adc 100644 --- a/paddlenlp/experimental/transformers/qwen2/modeling.py +++ b/paddlenlp/experimental/transformers/qwen2/modeling.py @@ -585,6 +585,7 @@ def concat(tensor_list, axis=-1): model_prefix = self.base_model_prefix + f".layers.{idx}" # logger.info(f"set state for layer {idx}") + unfused_state_dict = {} ln_scale = paddle.to_tensor(state_dict[f"{model_prefix}.input_layernorm.weight"]).cast( self.transformer_block.ln_scales[idx].dtype ) diff --git a/paddlenlp/rl/trainer/actor_trainer.py b/paddlenlp/rl/trainer/actor_trainer.py index c1e03b27939b..92d260712bb8 100644 --- a/paddlenlp/rl/trainer/actor_trainer.py +++ b/paddlenlp/rl/trainer/actor_trainer.py @@ -61,6 +61,7 @@ def compute_logprob(self, batch: DataProto, key) -> DataProto: Raises: None. """ + # print(f"Fu compute logprob using model: {type(self.model)} {id(self.model)}") input_ids = batch.batch["input_ids"] position_ids = batch.batch["position_ids"] prompt = batch.batch.get("prompt", None) diff --git a/paddlenlp/rl/trainer/ppo_trainer.py b/paddlenlp/rl/trainer/ppo_trainer.py index 2ccd738f44df..0e20c43782c4 100644 --- a/paddlenlp/rl/trainer/ppo_trainer.py +++ b/paddlenlp/rl/trainer/ppo_trainer.py @@ -238,6 +238,7 @@ def __init__( preprocess_logits_for_metrics: Optional[Callable[[paddle.Tensor, paddle.Tensor], paddle.Tensor]] = None, generation_config: Optional[GenerationConfig] = None, reshard_controller: Optional[ReshardController] = None, + **kwargs, ): """ Args: @@ -318,6 +319,8 @@ def __init__( ) = self.init_train_num(self.train_dataloader) args.max_steps = self.max_steps + self.gather_in_micro_dp = kwargs.get("gather_in_micro_dp", False) + self.use_export_only_rollout = kwargs.get("use_export_only_rollout", False) self.reshard_controller = reshard_controller trainer_agrs = { @@ -334,6 +337,7 @@ def __init__( "preprocess_logits_for_metrics": preprocess_logits_for_metrics, } + self.actor_trainer = self.create_actor_trainer( model=actor_model, model_eval=actor_model_eval, @@ -401,6 +405,7 @@ def __init__( self.timers.log = types.MethodType(new_timer_log, self.timers) self.generation_config = generation_config + def create_actor_trainer( self, model: Union[PretrainedModel, nn.Layer] = None, @@ -432,6 +437,7 @@ def create_actor_trainer( [None, lr_scheduler], preprocess_logits_for_metrics, reshard_controller, + gather_in_micro_dp=self.gather_in_micro_dp, ) actor_trainer.set_eval_model(model_eval) actor_trainer.timers = self.timers @@ -719,7 +725,7 @@ def prediction_step( inputs = self._prepare_inputs(inputs) data_trans_group = getattr(self.actor_trainer, "_data_trans_group", None) inputs = data_group_split(inputs, group=data_trans_group) - with reload_and_offload_scope(self, self.actor_model, self.reference_model, self.actor_trainer): + with reload_and_offload_scope(self, self.actor_model, self.reference_model, self.actor_trainer, export_only_rollout=True): with infer_guard(self.actor_trainer): prompt_only_batch = DataProto.from_single_dict( { @@ -1349,12 +1355,23 @@ def train( step = -1 for prompt_only_batch_dict in self.prompt_only_dataloader: + timer_batch = TimerScope( + self.timers, + RolloutStages.BATCH, + ) + timer_batch.start() + prompt_only_batch: DataProto = DataProto.from_single_dict(prompt_only_batch_dict) self.control = self.callback_handler.on_step_begin(args, self.state, self.control) # step 1-1: rollout data with actor model (eval) and reward model self.set_eval() data_trans_group = getattr(self.actor_trainer, "_data_trans_group", None) + + # print(f"Fu data check [before data group split] 
prompt_only_batch: {prompt_only_batch.batch['input_ids'].shape}, {prompt_only_batch.batch['input_ids']._md5sum()}") + # print(f"Fu data check [before data group split] prompt_only_batch: {prompt_only_batch.batch['input_ids'][:,-120:-100]}") prompt_only_batch = data_group_split(prompt_only_batch, group=data_trans_group) + # print(f"Fu data check [after data group split] prompt_only_batch: {prompt_only_batch.batch['input_ids'].shape}, {prompt_only_batch.batch['input_ids']._md5sum()}") + # print(f"Fu data check [after data group split] prompt_only_batch: {prompt_only_batch.batch['input_ids'][:,-120:-100]}") eos_token_ids = llm_utils.get_eos_token_id(self.tokenizer, self.generation_config) pad_token_id = self.tokenizer.pad_token_id prompt_only_batch.meta_info = { @@ -1371,7 +1388,7 @@ def train( repeat_times=self.args.rollout_n, interleave=True ) prompt_only_batch_expand.rename("raw_prompt_len", "raw_prompt_len_expand") - + expand_prompt = prompt_only_batch_expand.batch["input_ids"] per_device_rollout_batch_size = self.args.per_device_rollout_batch_size cleanup_batches, indices, label_ids_batches = [], [], [] @@ -1381,9 +1398,9 @@ def train( RolloutStages.ACTOR_MODEL_ENABLE_DISABLE, minus_names=[RolloutStages.GENERATE], ) - + timer_scope_actor_model.start() - with reload_and_offload_scope(self, self.actor_model): + with reload_and_offload_scope(self, self.actor_model, export_only_rollout=True): timer_scope_rollout = TimerScope(self.timers, RolloutStages.GENERATE) timer_scope_rollout.start() with infer_guard(self.actor_trainer): @@ -1392,9 +1409,25 @@ def train( for i in range(0, total_batch_size, per_device_rollout_batch_size): micro_prompt_batch = prompt_only_batch[i : i + per_device_rollout_batch_size] # generate for multi batches and then disable FuseMT model + + # hcg = fleet.get_hybrid_communicate_group() + # rollout_sharding_parallel_group = hcg.get_sharding_parallel_group() + # rollout_data_parallel_group = hcg.get_data_parallel_group() + # rollout_model_parallel_group = hcg.get_model_parallel_group() + # rollout_pipeline_parallel_group = hcg.get_pipeline_parallel_group() + # print(f"Fu rollout dp group:{rollout_data_parallel_group}, rollout tp group:{rollout_model_parallel_group}, rollout sdp group:{rollout_sharding_parallel_group}") #, rollout pp group:{rollout_pipeline_parallel_group} + + # print(f"Fu data check [before generate_sequences] micro_prompt_batch: {micro_prompt_batch.batch['input_ids'].shape}, {micro_prompt_batch.batch['input_ids']._md5sum()}") + # print(f"Fu data check [before generate_sequences] micro_prompt_batch: {micro_prompt_batch.batch['input_ids'][:,300:320]}") generated_batches: List[DataProto] = self.actor_trainer.generate_sequences( micro_prompt_batch ) + + # for idx,generated_batch in enumerate(generated_batches): + # print(f"Fu generated_batch {idx}, lr padding count:{count_lr_padding(generated_batch.batch['input_ids'])}") + # print(f"Fu data check [after generate_sequences] generated_batch: {generated_batch.batch['input_ids'].shape}, {generated_batch.batch['input_ids']._md5sum()}") + # print(f"Fu data check [after generate_sequences] generated_batch: {generated_batch.batch['input_ids'][:,300:320]}") + # NOTE(drownfish19): do process for each micro_batch, prepare for split mode micro_ret = self.remove_pad_tokens_after_generate(generated_batches) micro_cleanup_batches, micro_indices, micro_label_ids_batches = micro_ret @@ -1451,7 +1484,11 @@ def train( } ) - batch = data_group_merge(batch, group=data_trans_group) + batch = data_group_merge(batch, 
group=data_trans_group, pad_token_id=0) #self.tokenizer.pad_token_id + # print(f"Fu data check [after data group merge] batch: {batch.batch['input_ids'].shape}, {batch.batch['input_ids']._md5sum()} lr padding count:{count_lr_padding(batch.batch['input_ids'])}") + # print(f"Fu data check [after data group merge] batch: {batch.batch['input_ids'][:,300:320]}") + + # step 2-2: balance batches based on batch tokens if self.args.balance_batch: batch = self._balance_batch(batch) @@ -1459,12 +1496,12 @@ def train( # step 2-3: compute logprob for rollout data with self.autocast_smart_context_manager(): with TimerScope(self.timers, RolloutStages.ROLLOUT_LOGPROB): - with reload_and_offload_scope(self, self.reference_model): + with reload_and_offload_scope(self, self.reference_model, export_only_rollout=not self.use_export_only_rollout): with TimerScope(self.timers, RolloutStages.ROLLOUT_REF_LOGPROB): ref_log_probs = self.reference_trainer.compute_logprob(batch, key="ref_log_probs") batch = batch.union(ref_log_probs) - with reload_and_offload_scope(self, self.actor_model): + with reload_and_offload_scope(self, self.actor_model, export_only_rollout=not self.use_export_only_rollout): with TimerScope(self.timers, RolloutStages.ROLLOUT_OLD_LOGPROB): log_probs = self.actor_trainer.compute_logprob(batch, key="log_probs") batch = batch.union(log_probs) @@ -1479,6 +1516,7 @@ def train( self, self.critic_model if self.args.rl_algorithm == "ppo" else None, self.reward_model if not self.args.use_rm_server and not self.args.use_rule_reward else None, + export_only_rollout=not self.use_export_only_rollout ): with TimerScope(self.timers, RolloutStages.ROLLOUT_REWARD_VALUE): reward_tensor = self.reward_trainer.compute_reward( @@ -1603,7 +1641,7 @@ def train( # step 3: train actor model and critic model with rollout data self.set_train() with TimerScope(self.timers, ActorStages.MODEL_ENABLE_DISABLE, minus_names=[ActorStages.RL_STEP]): - with reload_and_offload_scope(self, self.actor_model, self.actor_trainer.optimizer): + with reload_and_offload_scope(self, self.actor_model, self.actor_trainer.optimizer, export_only_rollout=not self.use_export_only_rollout): with TimerScope(self.timers, ActorStages.RL_STEP): # timer_info = {} # prepare for each micro_step micro_batches = batch.split_batch_into_micro_batches( @@ -1617,6 +1655,9 @@ def train( with TimerScopeManualLabel( self.timers, get_timer_label(ActorStages.MICRO_STEPS) + f"_{micro_step}" ): + + # print(f"Fu data check [before train step {micro_step}] micro_batch: {micro_batch.batch['input_ids'].shape}, {micro_batch.batch['input_ids']._md5sum()} lr count padding:{count_lr_padding(micro_batch.batch['input_ids'])}") + # print(f"Fu data check [before train step {micro_step}] micro_batch: {micro_batch.batch['input_ids'][:,300:320]}") rl_info = self.actor_trainer.update_actor(micro_batch) paddle.device.cuda.empty_cache() @@ -1642,12 +1683,14 @@ def train( self.control = self.callback_handler.on_substep_end(args, self.state, self.control) step += 1 - + timer_batch.stop() + self._print_timer() self._maybe_log_save_evaluate(rl_info, None, epoch, ignore_keys_for_eval, inputs=micro_batch) paddle.device.cuda.empty_cache() if self.control.should_epoch_stop or self.control.should_training_stop: break + if step < 0: logger.warning( @@ -1935,3 +1978,101 @@ def compute_advantage_normalization(self, batch: DataProto): batch.batch["reward_advantages"] = batch.batch["reward_advantages"] * batch.batch["eos_mask"] return batch + +def remove_pad_and_count_lr(tensor, pad_id=151643): + """ 
+ Takes a 2-D tensor, removes all pad_id elements, and counts the number of consecutive pad_id tokens on the left and right side of each row. + + Args: + tensor (paddle.Tensor): shape = [batch_size, seq_len] + pad_id (int): padding token id + + Returns: + filtered (paddle.Tensor): 1-D tensor of the valid elements with all pad_id removed + left_pad_counts (paddle.Tensor): shape=[batch_size], number of consecutive left-side pads per row + right_pad_counts (paddle.Tensor): shape=[batch_size], number of consecutive right-side pads per row + """ + batch_size, seq_len = tensor.shape + left_pad_counts = [] + right_pad_counts = [] + + # counting consecutive pads is easier in numpy + tensor_np = tensor.numpy() + + for row in tensor_np: + # number of consecutive pads on the left + left_count = 0 + for val in row: + if val == pad_id: + left_count += 1 + else: + break + + # number of consecutive pads on the right + right_count = 0 + for val in row[::-1]: + if val == pad_id: + right_count += 1 + else: + break + + left_pad_counts.append(left_count) + right_pad_counts.append(right_count) + + # convert to tensors + left_pad_counts = paddle.to_tensor(left_pad_counts, dtype='int64') + right_pad_counts = paddle.to_tensor(right_pad_counts, dtype='int64') + + # filter out all pad_id elements and return a 1-D tensor + mask = tensor != pad_id + filtered = paddle.masked_select(tensor, mask) + + # print(f"tensor left padding:{left_pad_counts}, right padding:{right_pad_counts}") + return filtered, left_pad_counts, right_pad_counts + + +def count_lr_padding(tensor, pad_id=151643): + """ + Counts the number of consecutive pad_id tokens on the left and right side of each row. + + Args: + tensor (paddle.Tensor): shape = [batch_size, seq_len] + pad_id (int): padding token id + + Returns: + left_pad_counts (paddle.Tensor): shape=[batch_size], number of consecutive left-side pads per row + right_pad_counts (paddle.Tensor): shape=[batch_size], number of consecutive right-side pads per row + """ + batch_size, seq_len = tensor.shape + left_pad_counts = [] + right_pad_counts = [] + + # counting consecutive pads is easier in numpy + tensor_np = tensor.numpy() + + for row in tensor_np: + # number of consecutive pads on the left + left_count = 0 + for val in row: + if val == pad_id: + left_count += 1 + else: + break + + # number of consecutive pads on the right + right_count = 0 + for val in row[::-1]: + if val == pad_id: + right_count += 1 + else: + break + + left_pad_counts.append(left_count) + right_pad_counts.append(right_count) + + # convert to tensors + left_pad_counts = paddle.to_tensor(left_pad_counts, dtype='int64') + right_pad_counts = paddle.to_tensor(right_pad_counts, dtype='int64') + + return left_pad_counts, right_pad_counts \ No newline at end of file diff --git a/paddlenlp/rl/trainer/rl_trainer.py b/paddlenlp/rl/trainer/rl_trainer.py index 97fd7d757c8b..ccadee791672 100644 --- a/paddlenlp/rl/trainer/rl_trainer.py +++ b/paddlenlp/rl/trainer/rl_trainer.py @@ -538,6 +538,7 @@ def __init__( optimizers: Tuple[paddle.optimizer.Optimizer, paddle.optimizer.lr.LRScheduler] = (None, None), preprocess_logits_for_metrics: Optional[Callable[[paddle.Tensor, paddle.Tensor], paddle.Tensor]] = None, reshard_controller: Optional[ReshardController] = None, + **kwargs ): super().__init__( model, @@ -567,6 +568,7 @@ def __init__( # if self.timers: # self.timers.log = types.MethodType(new_timer_log, self.timers) self.reshard_controller = reshard_controller + self.gather_in_micro_dp = kwargs.get("gather_in_micro_dp", False) def create_criterion(self): """ @@ -600,14 +602,27 @@ def set_eval_model(self, model): if self.reshard_controller is not None: self.reshard_controller.set_rollout_env("[set eval model]") hcg = fleet.get_hybrid_communicate_group() + new_dp, new_sdp = hcg.get_data_parallel_group().nranks, hcg.get_sharding_parallel_group().nranks tensor_parallel_degree = hcg.get_model_parallel_world_size() tensor_parallel_rank = hcg.get_model_parallel_rank() if self.reshard_controller is not None:
self.reshard_controller.set_train_env("[after set eval model]") eval_tp_size = max(tensor_parallel_degree, 1) eval_tp_rank = max(tensor_parallel_rank, 0) - group_nums = self.args.logical_process_index // old_dp_workers * eval_tp_size + eval_tp_rank - self._data_trans_group = create_data_trans_group(global_rank, group_nums) + + gather_in_micro_dp = self.gather_in_micro_dp + if not gather_in_micro_dp: + group_nums = self.args.logical_process_index // old_dp_workers * eval_tp_size + eval_tp_rank + self._data_trans_group = create_data_trans_group(global_rank, group_nums) + else: + # new_dp_workers = self.args.world_size // (max(new_sdp, 1) * max(new_dp, 1)) + # group_nums = self.args.logical_process_index // new_dp_workers + # self._data_trans_group = create_data_trans_group(global_rank, group_nums) + if self.reshard_controller is not None: + self._data_trans_group = self.reshard_controller.micro_dp_group + else: + group_nums = self.args.logical_process_index + self._data_trans_group = create_data_trans_group(global_rank, group_nums) # just for compatible with old code self._policy_model_eval_group = self._data_trans_group @@ -620,6 +635,7 @@ def get_model(self, train=False): return self.model_wrapped model = getattr(self, "_eval_model", None) if model is not None: + # print(f"Fu get trainer's eval model: {id(model)}") return model inner_eval_model = getattr(self, "_inner_eval_model", None) if (self.args.pipeline_parallel_degree > 1 and inner_eval_model is None) or isinstance( diff --git a/paddlenlp/rl/utils/comm_utils.py b/paddlenlp/rl/utils/comm_utils.py index 6b65f63fea73..a1179334cd43 100644 --- a/paddlenlp/rl/utils/comm_utils.py +++ b/paddlenlp/rl/utils/comm_utils.py @@ -206,6 +206,8 @@ class RolloutStages(Enum): REWARD_MODEL_ENABLE_DISABLE = auto() ROLLOUT_REWARD_VALUE = auto() ROLLOUT_ADVANTAGE = auto() + RESHARD = auto() + BATCH = auto() # 一个batch处理用时(rollout->make experience-> train) def get_timer_label(stage: Enum) -> str: @@ -236,6 +238,8 @@ def get_timer_label(stage: Enum) -> str: RolloutStages.ROLLOUT_ADVANTAGE: "rollout", RolloutStages.REWARD_MODEL_ENABLE_DISABLE: "rollout", RolloutStages.ROLLOUT_REWARD_VALUE: "rollout", + RolloutStages.RESHARD: "rollout", + RolloutStages.BATCH:"batch" } # stage prefix = step_prefix.get(stage, "unknown") @@ -294,12 +298,14 @@ def data_group_split(tensors, group): return new_dict elif isinstance(tensors, paddle.Tensor): return tensors.split(group.nranks)[group.rank] + elif isinstance(tensors, DataProto): + return tensors.split_batch_into_micro_batches(batch_size=len(tensors)//group.nranks)[group.rank] else: logger.debug(f"[data_group_split]Can't parse for type {type(tensors)}") return tensors -def data_group_merge(tensors, group): +def data_group_merge(tensors, group, pad_token_id=None): """ Combine data into a new list or dictionary, or perform all_gather_nd operation in the specified group if not None. 
@@ -333,6 +339,25 @@ def data_group_merge(tensors, group): tensor_list = [] all_gather_nd(tensor_list, tensors, group=group, padded=True) return np.concatenate(tensor_list) + elif isinstance(tensors, DataProto): + tensor_list = [] + # all_gather_object(tensor_list, tensors, group=group) + + tensors_to_gather_per_key = defaultdict(list) + for key in tensors.batch.keys(): + tensors_to_gather_per_key[key].append(tensors.batch[key]) + for key in tensors.non_tensor_batch.keys(): + tensors_to_gather_per_key[key].append(tensors.non_tensor_batch[key]) + + global_balanced_batch_dict = {} + # Collect and pad tensors from all workers (across DP and Sharding groups) + for key in tensors_to_gather_per_key.keys(): + tensor_list_from_local_batch = tensors_to_gather_per_key[key] + + global_balanced_batch_dict[key] = gather_tensor_list(group, None)( + DataProto.pad_or_concat_tensor_list + )(tensor_list_from_local_batch, pad_token_id, key) + return DataProto.from_single_dict(global_balanced_batch_dict) else: logger.debug(f"[data_group_merge]Can't parse for type {type(tensors)}") return tensors @@ -621,28 +646,67 @@ def export_evaluate_model(self: Trainer, train_model, eval_model, **kwargs): dp_group = hcg.get_data_parallel_group() pp_rank = hcg.get_stage_id() - if not hasattr(self, "global_meta_dict") or self.global_meta_dict is None: - self.global_meta_dict = init_reshard_mappings(train_model, self.args, pp_rank, pp_group) if getattr(self, "reshard_controller", None) is not None: self.reshard_controller.set_rollout_env("[export_evaluate_model]") hcg = fleet.get_hybrid_communicate_group() tensor_parallel_degree = hcg.get_model_parallel_world_size() + new_dp, new_sdp = hcg.get_data_parallel_group().nranks, hcg.get_sharding_parallel_group().nranks tensor_parallel_rank = hcg.get_model_parallel_rank() + # rollout_dp_group = hcg.get_data_parallel_group() + # print(f"Fu rank:{paddle.distributed.get_rank()}, rollout_dp_group:{rollout_dp_group}, rollout_tp_group:{hcg.get_model_parallel_group()}") eval_tp_size = max(tensor_parallel_degree, 1) eval_tp_rank = max(tensor_parallel_rank, 0) - reshard_to_rollout( - train_model, eval_model, self.global_meta_dict, pp_rank, pp_group, hcg.get_model_parallel_group(), tp_group - ) + + if not hasattr(self, "global_meta_dict") or self.global_meta_dict is None: + self.global_meta_dict = init_reshard_mappings(train_model, self.args, pp_rank, pp_group, hcg.get_model_parallel_group()) + + gather_in_micro_dp = self.gather_in_micro_dp + + # print(f"Fu before reshard") + # print(f"Fu train model weight `train_model.qwen2.layers[0].mlp.down_proj.weight.data._md5sum`: {train_model.qwen2.layers[0].mlp.down_proj.weight.data._md5sum()}") + # print(f"FU eval model weight `eval_model.qwen2.layers[0].mlp.down_proj.weight.data._md5sum`: {eval_model.qwen2.layers[0].mlp.down_proj.weight.data._md5sum()}") + # if not kwargs.get('hybridflow', False): # pre-optimization path + if not gather_in_micro_dp: + reshard_to_rollout( + train_model, eval_model, self.global_meta_dict, pp_rank, pp_group, + rollout_tp_group=hcg.get_model_parallel_group(), + train_tp_group=tp_group, + gather_in_micro_dp=False + ) + else: # optimized path: weights only need to be gathered within the micro dp group + reshard_to_rollout( + train_model, eval_model, self.global_meta_dict, pp_rank, pp_group, + micro_dp_group=self.reshard_controller.micro_dp_group, + train_tp_group=tp_group, + rollout_tp=tensor_parallel_degree, + train_tp=tp_group.nranks, + gather_in_micro_dp=True, + ) if getattr(self, "reshard_controller", None) is not None: self.reshard_controller.set_train_env("[after
export_evaluate_model]") - old_dp_workers = self.args.world_size // (max(sd_group.nranks, 1) * max(dp_group.nranks, 1)) - group_nums = self.args.logical_process_index // old_dp_workers * eval_tp_size + eval_tp_rank + # print(f"Fu after reshard") + # print(f"Fu train model weight `train_model.qwen2.layers[0].mlp.down_proj.weight.data._md5sum`: {train_model.qwen2.layers[0].mlp.down_proj.weight.data._md5sum()}") + # print(f"FU eval model weight `eval_model.qwen2.layers[0].mlp.down_proj.weight.data._md5sum`: {eval_model.qwen2.layers[0].mlp.down_proj.weight.data._md5sum()}") + + if not gather_in_micro_dp: + # rank 0 1 2 3 4 5 6 7 + # train tp id 0 0 0 0 2 2 2 2 + # tp rank 0 1 0 1 0 1 0 1 + # ans 0 1 0 1 2 3 2 3 + old_dp_workers = self.args.world_size // (max(sd_group.nranks, 1) * max(dp_group.nranks, 1)) + group_nums = self.args.logical_process_index // old_dp_workers * eval_tp_size + eval_tp_rank + else: + # rank 0 1 2 3 4 5 6 7 + # ans 0 0 1 1 2 2 3 3 + new_dp_workers = self.args.world_size // (max(new_sdp, 1) * max(new_dp, 1)) + group_nums = self.args.logical_process_index // new_dp_workers if not hasattr(self, "_policy_model_eval_group") or self._policy_model_eval_group is None: self._policy_model_eval_group = create_data_trans_group(paddle.distributed.get_rank(), group_nums) + # print(f"Fu micro dp group/_policy_model_eval_group:{self._policy_model_eval_group}") return None @@ -666,7 +730,7 @@ def create_data_trans_group(global_rank, group_nums): for k, v in all_split_table: split_dict[k] = v - split_ranks = {} + split_ranks = {} # group idx:rank idx list for k, v in all_split_table: if v in split_ranks: split_ranks[v].append(k) @@ -998,7 +1062,7 @@ def gather_tensor(tensor, dp_group=None, sd_group=None): dtype = tensor[0].dtype - if (dp_group is None and sd_group is None) or (dp_group.nranks == 1 and sd_group.nranks == 1): + if (dp_group is None and sd_group is None) or (getattr(dp_group,'nranks',1) == 1 and getattr(sd_group,'nranks',1) == 1): return tensor def map_func(weight): @@ -1009,7 +1073,7 @@ def map_func(weight): tensor = [map_func(i) for i in tensor] sd_gathered_tensor = [] - if sd_group.nranks > 1: + if sd_group is not None and sd_group.nranks > 1: dist.all_gather_object(sd_gathered_tensor, tensor, group=sd_group) dp_gathered_tensor = [] diff --git a/paddlenlp/rl/utils/config_utils.py b/paddlenlp/rl/utils/config_utils.py index 9aae90c9001b..a5f65de64432 100644 --- a/paddlenlp/rl/utils/config_utils.py +++ b/paddlenlp/rl/utils/config_utils.py @@ -26,6 +26,15 @@ @dataclass @llmmetaclass class TrainingArguments(TrainingArguments): + gather_in_micro_dp: bool = field( + default=False, + metadata={"help": "Whether to gather gradients in micro dp mode."}, + ) + use_export_only_rollout: bool = field( + default=False, + metadata={"help": "Whether to use export only rollout."}, + ) + global_batch_size: int = field( default=8, metadata={"help": "Global batch size for input prompt."}, diff --git a/paddlenlp/rl/utils/infer_utils.py b/paddlenlp/rl/utils/infer_utils.py index f295306d8555..c63b68409eb6 100644 --- a/paddlenlp/rl/utils/infer_utils.py +++ b/paddlenlp/rl/utils/infer_utils.py @@ -17,6 +17,7 @@ import copy import inspect from contextlib import contextmanager +import time import paddle import paddle.distributed as dist @@ -32,6 +33,8 @@ from ...transformers.model_utils import dtype_guard from ..trainer.trainer_utils import process_row from .offload_utils import offload_tensor_to_cpu, reload_tensor_to_gpu +from .timer_utils import TimerScope +from .comm_utils import RolloutStages try: 
from llm.predict.predictor import ( @@ -95,6 +98,11 @@ def predict(self, input_ids: paddle.Tensor = None, repeat_num=1, **kwargs): input_ids_list.append(row_ids) if self.config.dynamic_insert: + # input_tensor_list = [] + # for input_ids_item in input_ids_list: + # input_tensor = paddle.to_tensor(input_ids_item,dtype=input_ids.dtype) + # input_tensor_list.append(input_tensor) + # print(f"Fu input_ids_list: {[input_tensor.shape for input_tensor in input_tensor_list]} {[input_tensor._md5sum() for input_tensor in input_tensor_list]}") outputs = self.predict_dy_insert( input_ids=input_ids_list, return_tokens=True, @@ -103,6 +111,8 @@ def predict(self, input_ids: paddle.Tensor = None, repeat_num=1, **kwargs): repeat_num=repeat_num, **kwargs, )[-1] + # out_tensor = paddle.to_tensor(outputs, dtype=input_ids.dtype) + # print(f"Fu out_tensor: {out_tensor.shape} {out_tensor._md5sum()}") return paddle.to_tensor(outputs, dtype=input_ids.dtype) else: raise NotImplementedError("dynamic_insert is False is not supported.") @@ -258,16 +268,24 @@ def __init__(self, trainer: Trainer): def enable(self): trainer = self.trainer if trainer.model is not self.model: + # print(f"Fu export to evaluate model from {type(trainer.model)} {id(trainer.model)} to {type(self.model)} {id(self.model)}") reload_tensor_to_gpu((trainer.model, "train_model")) reload_tensor_to_gpu((self.model, "freeze_model")) + timer_scope_reshard = TimerScope(trainer.timers, RolloutStages.RESHARD) + timer_scope_reshard.start() + # begin_reshard_time = time.time() trainer.export_evaluate_model( trainer.model, self.model, with_offload="train_model" in trainer.args.offload_level, ) + # end_reshard_time = time.time() + # logger.info(f"Fu Reshard time cost: {end_reshard_time - begin_reshard_time:.4f}s") + timer_scope_reshard.stop() # NOTE(gongenlei): Add offload offload_tensor_to_cpu((trainer.model, "train_model")) else: + # print(f"Fu DO NOT export model because trainer.model{type(trainer.model)} {id(trainer.model)} == InferEvelModel.model {type(self.model)} {id(self.model)}") reload_tensor_to_gpu((self.model, "train_model")) def disable(self): diff --git a/paddlenlp/rl/utils/offload_utils.py b/paddlenlp/rl/utils/offload_utils.py index 4c37b1d38be5..701990beb970 100644 --- a/paddlenlp/rl/utils/offload_utils.py +++ b/paddlenlp/rl/utils/offload_utils.py @@ -179,7 +179,7 @@ def __exit__(self, *args): paddle.device.synchronize() -def reload_and_offload_scope(trainer, *args): +def reload_and_offload_scope(trainer, *args, **kwargs): offload_map = { trainer.actor_model: "train_model", trainer.reference_model: "freeze_model", @@ -201,8 +201,9 @@ def reload_and_offload_scope(trainer, *args): objs = [(arg, offload_map.get(arg, "")) for arg in args if offload_map.get(arg, "") in trainer.args.offload_level] if trainer.actor_model not in [i for i, _ in objs]: - if getattr(trainer.actor_trainer, "_inner_eval_model", None) is not None: + if getattr(trainer.actor_trainer, "_inner_eval_model", None) is not None and kwargs.get("export_only_rollout", False): # NOTE(gongenlei): for export_evaluate_model + # print(f"Fu offlaod model {type(trainer.actor_model)} {id(trainer.actor_model)}") objs.append((trainer.actor_model, offload_map.get(trainer.actor_model, ""))) if trainer.args.rl_algorithm == "ppo": if trainer.critic_model not in [i for i, _ in objs]: diff --git a/paddlenlp/rl/utils/reshard_utils.py b/paddlenlp/rl/utils/reshard_utils.py index 8e47b0f1ba9e..2e6f6d902f78 100644 --- a/paddlenlp/rl/utils/reshard_utils.py +++ b/paddlenlp/rl/utils/reshard_utils.py @@ -12,6 
+12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from copy import copy import numpy as np import paddle import paddle.distributed as dist @@ -30,13 +31,16 @@ class ReshardController: def __init__( self, - tensor_parallel_degree, + train_tensor_parallel_degree, + infer_tensor_parallel_degree, # the input here is already the rollout tensor parallel degree pipeline_parallel_degree=1, sharding_parallel_degree=1, sep_parallel_degree=1, seed=100, + **kwargs ): - self.tensor_parallel_degree = tensor_parallel_degree + self.train_tensor_parallel_degree = train_tensor_parallel_degree + self.infer_tensor_parallel_degree = infer_tensor_parallel_degree self.pipeline_parallel_degree = pipeline_parallel_degree self.sharding_parallel_degree = sharding_parallel_degree self.sep_parallel_degree = sep_parallel_degree @@ -49,31 +53,130 @@ def __init__( self.train_hcg.get_data_parallel_group(), self.train_hcg.get_sharding_parallel_group(), ) + self.gather_in_micro_dp = kwargs.get("gather_in_micro_dp", False) self.infer_tp_group, self.infer_dp_group, self.infer_sdp_group = self.init_rollout_env() self.set_train_env() self.is_train = True + def init_rollout_env(self): - world_size = dist.get_world_size() - infer_topo = CommunicateTopology( - hybrid_group_names=["data", "pipe", "sharding", "sep", "model"], - dims=[ - world_size - // self.tensor_parallel_degree - // self.pipeline_parallel_degree - // self.sharding_parallel_degree - // self.sep_parallel_degree, - self.pipeline_parallel_degree, - self.sharding_parallel_degree, - self.sep_parallel_degree, - self.tensor_parallel_degree, - ], - ) - infer_hcg = HybridCommunicateGroup(infer_topo) - infer_tp_group = infer_hcg.get_model_parallel_group() - infer_dp_group = infer_hcg.get_data_parallel_group() - infer_sdp_group = infer_hcg.get_sharding_parallel_group() + gather_in_micro_dp = self.gather_in_micro_dp + if not gather_in_micro_dp: + world_size = dist.get_world_size() + infer_topo = CommunicateTopology( + hybrid_group_names=["data", "pipe", "sharding", "sep", "model"], + dims=[ + world_size + // self.infer_tensor_parallel_degree + // self.pipeline_parallel_degree + // self.sharding_parallel_degree + // self.sep_parallel_degree, + self.pipeline_parallel_degree, + self.sharding_parallel_degree, + self.sep_parallel_degree, + self.infer_tensor_parallel_degree, + ], + ) + + infer_hcg = HybridCommunicateGroup(infer_topo) + infer_tp_group = infer_hcg.get_model_parallel_group() + infer_dp_group = infer_hcg.get_data_parallel_group() + infer_sdp_group = infer_hcg.get_sharding_parallel_group() + else: + world_size = dist.get_world_size() + micro_dp = self.train_tensor_parallel_degree // self.infer_tensor_parallel_degree + # This layout cannot be expressed with CommunicateTopology; the parallel groups have to be created manually here + # order: DP PP SDP SEP TP micro_DP + # DP: concat([train DP group, micro_DP group]) + # TP: interval= micro DP + # PP/SDP/SEP == train PP/SDP/SEP + + + topo_order = ["data", "pipe", "sharding", "sep", "model"] + # infer_topo = CommunicateTopology( + # hybrid_group_names=["model", "pipe", "data", "sharding", "sep"], + # dims=[ + # self.tensor_parallel_degree, + # self.pipeline_parallel_degree, + # world_size + # // self.tensor_parallel_degree + # // self.pipeline_parallel_degree + # // self.sharding_parallel_degree + # // self.sep_parallel_degree, + # self.sharding_parallel_degree, + # self.sep_parallel_degree, + # ], + # ) + + all_tp_dp_ranks = [] + # Collect every rank's tp group and train dp group; instead of aggregating over the dp, sdp and pp groups with three communications, a single all_gather followed by deduplication is enough + paddle.distributed.all_gather_object(all_tp_dp_ranks,
(self.train_tp_group.ranks,self.train_dp_group.ranks)) + all_tp_ranks = [t[0] for t in all_tp_dp_ranks] + all_train_dp_ranks = [t[1] for t in all_tp_dp_ranks] + all_tp_ranks = [list(t) for t in set(tuple(sub) for sub in all_tp_ranks)] # deduplicate + all_train_dp_ranks = [list(t) for t in set(tuple(sub) for sub in all_train_dp_ranks)] + + + rank = paddle.distributed.get_rank() + # Within each train TP group, partition the ranks into micro dp groups and rollout TP groups + # for micro_dp_idx in range(micro_dp): # this is the micro dp, rollout t + # for rollout_tp_idx in range(self.tensor_parallel_degree): + # idx = micro_dp_idx * micro_dp + rollout_tp_idxp + train_ranks = self.train_tp_group.ranks + micro_group_lst, infer_tp_group_lst = [], [] + for tp_ranks in all_tp_ranks: + # micro_dp_group + for rollout_tp_idx in range(self.infer_tensor_parallel_degree): + micro_rank_lst = [] + for micro_dp_idx in range(micro_dp): + idx = rollout_tp_idx * micro_dp + micro_dp_idx + micro_rank_lst.append(tp_ranks[idx]) + micro_group_lst.append(micro_rank_lst) + # infer_tp_group + for micro_dp_idx in range(micro_dp): + infer_tp_rank_lst = [] + for rollout_tp_idx in range(self.infer_tensor_parallel_degree): + idx = rollout_tp_idx * micro_dp + micro_dp_idx + infer_tp_rank_lst.append(tp_ranks[idx]) + infer_tp_group_lst.append(infer_tp_rank_lst) + + # Fuse the train DP groups with the micro DP groups + infer_dp_group_lst = [] + for micro_ranks in micro_group_lst: + infer_dp_rank_lst = [] + for micro_rank in micro_ranks: + for train_dp_ranks in all_train_dp_ranks: + if micro_rank in train_dp_ranks: + infer_dp_rank_lst.extend(train_dp_ranks) + break + infer_dp_rank_lst = list(set(infer_dp_rank_lst)) + infer_dp_group_lst.append(infer_dp_rank_lst) + infer_dp_group_lst = [list(t) for t in set(tuple(sub) for sub in infer_dp_group_lst)] # deduplicate + print(f"Fu infer dp group lst:{infer_dp_group_lst} micro_group_lst:{micro_group_lst}, infer tp group list:{infer_tp_group_lst}") + + + # micro_group_dict={0:[0,1],1:[2,3],2:[4,5],3:[6,7]} + for ranks in micro_group_lst: + gp = paddle.distributed.new_group(ranks=ranks) + print(f"Fu [micro dp group create] create a group {gp.id}:{gp.ranks}") + if rank in ranks: + self.micro_dp_group = gp + for ranks in infer_dp_group_lst: + print(f"Fu [before infer dp group create] begin to create a dp group:{ranks}") + gp = paddle.distributed.new_group(ranks=ranks) + print(f"Fu [infer dp group create] create a group {gp.id}:{gp.ranks}") + if rank in ranks: + infer_dp_group = gp + for ranks in infer_tp_group_lst: + gp = paddle.distributed.new_group(ranks=ranks) + print(f"Fu [infer tp group create] create a group {gp.id}:{gp.ranks}") + if rank in ranks: + infer_tp_group = gp + infer_sdp_group = self.train_sdp_group + print(f"Fu rank:{rank}, micro_dp_group:{self.micro_dp_group}, infer_dp_group:{infer_dp_group}, infer_tp_group:{infer_tp_group}") + + print(f"Fu infer tp group:{infer_tp_group}, dp group:{infer_dp_group}, sdp group:{infer_sdp_group}") return (infer_tp_group, infer_dp_group, infer_sdp_group)
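For reference, a minimal standalone sketch of the rank partitioning performed above, assuming a single train TP group of 4 ranks and a rollout TP degree of 2 (the function name and values are illustrative only, not part of the patch):

def partition_tp_group(tp_ranks, rollout_tp):
    # Each train TP group is split into `rollout_tp` micro-dp groups of `micro_dp` ranks,
    # using idx = rollout_tp_idx * micro_dp + micro_dp_idx as in init_rollout_env above.
    micro_dp = len(tp_ranks) // rollout_tp
    micro_dp_groups = [[tp_ranks[r * micro_dp + m] for m in range(micro_dp)] for r in range(rollout_tp)]
    rollout_tp_groups = [[tp_ranks[r * micro_dp + m] for r in range(rollout_tp)] for m in range(micro_dp)]
    return micro_dp_groups, rollout_tp_groups

# partition_tp_group([0, 1, 2, 3], rollout_tp=2)
# -> micro-dp groups [[0, 1], [2, 3]], rollout TP groups [[0, 2], [1, 3]]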
def _get_rng_state(self): @@ -101,7 +204,7 @@ def set_rollout_env(self, msg=""): topology._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size = lambda: self.infer_tp_group.nranks topology._HYBRID_PARALLEL_GROUP.get_model_parallel_rank = lambda: self.infer_tp_group.rank self.log(msg, False) - # self.is_train = False + self.is_train = False def set_train_env(self, msg=""): hcg = fleet.get_hybrid_communicate_group() @@ -116,7 +219,7 @@ paddle.set_rng_state(self.orig_rng_state) self.log(msg, True) - # self.is_train = True + self.is_train = True def log(self, msg,
is_train=False): msg = f"for {msg}" if len(msg) > 0 else "" @@ -154,31 +257,73 @@ def mp_reshard( src_tensor, tgt_tensor, meta_dict, - train_tp_group, - rollout_tp_group, + **kwargs, ): - if rollout_tp_group.nranks == train_tp_group.nranks: + gather_in_micro_dp = kwargs.get("gather_in_micro_dp", False) + + if not gather_in_micro_dp: + rollout_tp_group = kwargs.get("rollout_tp_group", None) + train_tp_group = kwargs.get("train_tp_group", None) + + if rollout_tp_group.nranks == train_tp_group.nranks: + return src_tensor + + if meta_dict["is_distributed"]: + res = [] + if train_tp_group.nranks > 1: + paddle.distributed.all_gather(res, src_tensor, group=train_tp_group, sync_op=True) + else: + res = [src_tensor] + if hasattr(tgt_tensor, "is_distributed") and tgt_tensor.is_distributed: + merge_fn = meta_dict["merge_tensor_fn"] + concat_tensor = merge_fn( + res, transpose=False, is_old_qkv=False, is_naive_2fuse=False, is_naive_3fuse=False, keep_on_gpu=True + ) + del res + split_fn = meta_dict["split_tensor_fn"] + split_part = split_fn( + concat_tensor, transpose=False, is_old_qkv=False, is_naive_2fuse=False, is_naive_3fuse=False + ) + del concat_tensor + return split_part + else: + merge_fn = meta_dict["merge_tensor_fn"] + concat_tensor = merge_fn( + res, transpose=False, is_old_qkv=False, is_naive_2fuse=False, is_naive_3fuse=False, keep_on_gpu=True + ) + return concat_tensor return src_tensor + else: + micro_dp_group = kwargs.get("micro_dp_group", None) + rollout_tp = kwargs.get("rollout_tp", None) + train_tp = kwargs.get("train_tp", None) - if meta_dict["is_distributed"]: - res = [] - if train_tp_group.nranks > 1: - paddle.distributed.all_gather(res, src_tensor, group=train_tp_group, sync_op=True) - else: - res = [src_tensor] - if hasattr(tgt_tensor, "is_distributed") and tgt_tensor.is_distributed: - assert hasattr(tgt_tensor, "split_axis"), f"{tgt_tensor.name} has no split_axis!" - concat_tensor = paddle.concat(res, meta_dict["split_axis"]) - del res - all_parts = paddle.split(concat_tensor, rollout_tp_group.nranks, tgt_tensor.split_axis) - del concat_tensor - return all_parts[rollout_tp_group.rank] - else: - return paddle.concat(res, meta_dict["split_axis"]) - return src_tensor + if rollout_tp==train_tp: + return src_tensor + + if meta_dict['is_distributed']: # e.g. input layernorm has is_distributed=False, i.e. it is not sharded (but doesn't sequence parallelism shard it?) + # Training weights are gathered within the micro dp parallel group. The sdp case is not handled here yet and still needs to be tested + res = [] + assert train_tp>rollout_tp, "train_tp must be greater than rollout_tp for higher rollout throughput" + + paddle.distributed.all_gather(res, src_tensor, group=micro_dp_group, sync_op=True) + merge_fn = meta_dict["merge_tensor_fn"] + concat_tensor = merge_fn( + res, transpose=False, is_old_qkv=False, is_naive_2fuse=False, is_naive_3fuse=False, keep_on_gpu=True + ) + return concat_tensor + return src_tensor + # if meta_dict.get("split_axis",None) is not None: + # # print(f"meta_dict:{meta_dict}") + # paddle.distributed.all_gather(res, src_tensor, group=micro_dp_group, sync_op=True) + # # if src_tensor.ndim==1: # ??? layernorm.weight raised a missing split_axis error here; why doesn't the branch above hit it?
+ # # meta_dict["split_axis"] = 0 + # return paddle.concat(res, meta_dict["split_axis"]) + # # print(f"Fu meta_dict has none split_axis, {meta_dict}") + # return src_tensor # for the layernorm TP case
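As a rough illustration of what the micro-dp gather above achieves for a plain (non-fused) column-parallel weight, here is a small numpy sketch with made-up shapes; it only mirrors the simple concat case, not the fused-qkv merge functions:

import numpy as np

# Full weight of shape [in_features, out_features]; train TP = 4, rollout TP = 2.
full = np.arange(3 * 8).reshape(3, 8)
train_shards = np.split(full, 4, axis=-1)  # one column shard per train TP rank
# The micro-dp group of rollout rank 0 contains train ranks 0 and 1, so gathering and
# concatenating their shards reproduces rollout rank 0's shard directly, without an
# all_gather over the whole train TP group.
rollout_shard_0 = np.concatenate(train_shards[0:2], axis=-1)
assert (rollout_shard_0 == np.split(full, 2, axis=-1)[0]).all()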
-def init_reshard_mappings(model, training_args, pp_rank, pp_group): +def init_reshard_mappings(model, training_args, pp_rank, pp_group, rollout_tp_group): global_meta_dict = {} if training_args.pipeline_parallel_degree > 1: model._layers._set_pipeline_name_mapping() @@ -202,9 +347,6 @@ def init_reshard_mappings(model, training_args, pp_rank, pp_group): local_meta_dict[k]["is_distributed"] = False if hasattr(pipeline_tensor, "is_distributed"): local_meta_dict[k]["is_distributed"] = pipeline_tensor.is_distributed - local_meta_dict[k]["split_axis"] = None - if hasattr(pipeline_tensor, "split_axis"): - local_meta_dict[k]["split_axis"] = pipeline_tensor.split_axis if training_args.pipeline_parallel_degree > 1: gathered_local_meta_dict = [] dist.all_gather_object(gathered_local_meta_dict, local_meta_dict, group=pp_group) @@ -212,12 +354,36 @@ gathered_local_meta_dict = [local_meta_dict] for meta_dict in gathered_local_meta_dict: global_meta_dict.update(meta_dict) + + if ( + training_args.tensor_parallel_degree != training_args.rollout_tensor_parallel_degree or training_args.pipeline_parallel_degree > 1 + ): + model_class = type(model) + tensor_parallel_config = copy(model.config) + tensor_parallel_config.tensor_parallel_degree = training_args.rollout_tensor_parallel_degree + tensor_parallel_config.tensor_parallel_rank = rollout_tp_group.rank + + merge_tensor_fn_dict = model_class._get_tensor_parallel_mappings(config=tensor_parallel_config, is_split=False) + for k,v in merge_tensor_fn_dict.items(): + key = k if k in global_meta_dict else f"{model.config.model_type}.{k}" + global_meta_dict[key]["merge_tensor_fn"] = v + + if training_args.rollout_tensor_parallel_degree > 1: + split_tensor_fn_dict = model_class._get_tensor_parallel_mappings( + config=tensor_parallel_config, is_split=True + ) + for k,v in split_tensor_fn_dict.items(): + key = k if k in global_meta_dict else f"{model.config.model_type}.{k}" + global_meta_dict[key]["split_tensor_fn"] = v return global_meta_dict @paddle.no_grad() +# def reshard_to_rollout( +# train_model, rollout_model, global_meta_dict, pp_rank, pp_group, rollout_tp_group, train_tp_group, gather_in_micro_dp=False +# ): def reshard_to_rollout( - train_model, rollout_model, global_meta_dict, pp_rank, pp_group, rollout_tp_group, train_tp_group + train_model, rollout_model, global_meta_dict, pp_rank, pp_group, **kwargs ): train_model_state_dict = train_model.state_dict() rollout_model_state_dict = rollout_model.state_dict() @@ -227,12 +393,19 @@ for k, _ in param_numel: v = rollout_model_state_dict[k] resharded_tensor = pp_reshard(v, train_model_state_dict, global_meta_dict[k], pp_rank, pp_group) + # resharded_tensor = mp_reshard( + # resharded_tensor, + # v, + # global_meta_dict[k], + # train_tp_group, + # rollout_tp_group, # note: when gather_in_micro_dp is True, rollout_tp_group is actually the rollout_dp_group; the name was left unchanged here + # gather_in_micro_dp, + # ) resharded_tensor = mp_reshard( resharded_tensor, v, global_meta_dict[k], **kwargs, ) assert resharded_tensor.dtype == v.dtype, f"dtype wrong {k} {resharded_tensor.dtype} {v.dtype}" assert resharded_tensor.shape == v.shape, f"shape wrong {k} {resharded_tensor.shape} {v.shape}" diff --git a/paddlenlp/rl/utils/timer_utils.py
b/paddlenlp/rl/utils/timer_utils.py index 11e7a4280a23..16428a8c4737 100644 --- a/paddlenlp/rl/utils/timer_utils.py +++ b/paddlenlp/rl/utils/timer_utils.py @@ -46,6 +46,7 @@ def start(self) -> None: Explicitly start the timer. """ if self.timers: + # print(f"Fu Begin {self.name} , {self.label}") self.timers(self.label).start() self._started = True diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index 30a3e7b3dc62..746f2130fb34 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -782,6 +782,7 @@ class TrainingArguments: "Following options are supported:\n" "- pp_first. the topo order is dp, pp, sharding, mp \n" "- sharding_first. the topo order is dp, sharding, pp, mp \n" + "- mp_first. the topo order is mp, pp, dp, sharding\n" "Default is None, for pp_first" ) }, @@ -1393,6 +1394,11 @@ def is_segment_parallel_supported(): order = ["dp", "sharding", "pp", "sep", "mp"] else: order = ["dp", "sharding", "pp", "mp"] + if self.hybrid_parallel_topo_order == "mp_first": + if is_segment_parallel_supported(): + order = ["sep", "mp", "pp", "dp", "sharding"] + else: + order = ["mp", "pp", "dp", "sharding"] if self.use_expert_parallel: order = order[1:-1] + ["dp", "mp"] @@ -1759,6 +1765,13 @@ def is_segment_parallel_supported(): self.pipeline_parallel_degree, self.tensor_parallel_degree, ] + elif self.hybrid_parallel_topo_order == "mp_first": + order = ["mp", "pp", "dp"] + degree = [ + self.tensor_parallel_degree, + self.pipeline_parallel_degree, + self.dataset_world_size, + ] if sep_degree > 1: order.insert(-1, "sep") degree.insert(-1, sep_degree) @@ -1774,6 +1787,8 @@ def is_segment_parallel_supported(): logger.warning( "Currently using sharding_first topo order, but pp_first is recommended when using experts parallel for performance." 
) + elif self.hybrid_parallel_topo_order == "mp_first": + order = ["mp", "pp", "dp", "sharding", "sep"] strategy = fleet.DistributedStrategy() strategy.hybrid_configs = { @@ -2046,7 +2061,7 @@ def _post_init_parallel_degree(self): if self.hybrid_parallel_topo_order is None: self.hybrid_parallel_topo_order = "pp_first" - assert self.hybrid_parallel_topo_order in ["pp_first", "sharding_first"] + assert self.hybrid_parallel_topo_order in ["pp_first", "sharding_first", "mp_first"] if self.use_hybrid_parallel and self.enable_auto_parallel: self.use_hybrid_parallel = False diff --git a/paddlenlp/trainer/unified_checkpoint/utils.py b/paddlenlp/trainer/unified_checkpoint/utils.py index dc2cd606fbb4..d46ddfb3164e 100644 --- a/paddlenlp/trainer/unified_checkpoint/utils.py +++ b/paddlenlp/trainer/unified_checkpoint/utils.py @@ -378,6 +378,8 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys): mp_moe = getattr(tensor, "mp_moe", False) if key in tp_actions and not mp_moe: # Get tensor size + if tensor.place.is_cuda_pinned_place(): + tensor = tensor._copy_to(paddle.CUDAPlace(int(paddle.get_device().split(":")[1])),False) tensor_bytes = tensor.numel().item() * dtype_byte_size(tensor.dtype) * tp_group.nranks if tensor_bytes >= 5 * 1024 * 1024 * 1024: # temporarily set 5GB as threshold tensor = merge_large_tensor_parallel(tensor, tp_group, tp_actions[key], j, is_dst) diff --git a/paddlenlp/transformers/conversion_utils.py b/paddlenlp/transformers/conversion_utils.py index 11b9acf8267f..667d25e8fdcd 100644 --- a/paddlenlp/transformers/conversion_utils.py +++ b/paddlenlp/transformers/conversion_utils.py @@ -294,7 +294,7 @@ def get_diff_keys(self, return_all_diff: bool = False) -> List[str]: return all_diff_keys -def naive_fuse_merge_tp(weight_list, is_column=True, fuse_tensor_parts=2): +def naive_fuse_merge_tp(weight_list, is_column=True, fuse_tensor_parts=2, **kwargs): """ [A1 B1],[A2 B2] => [A1, A2, B1, B2] @@ -330,7 +330,7 @@ def naive_fuse_merge_tp(weight_list, is_column=True, fuse_tensor_parts=2): else: tensor = paddle.concat([reorder[i] for i in index], axis=axis) - if tensor.place.is_gpu_place(): + if tensor.place.is_gpu_place() and not kwargs.get("keep_on_gpu", False): tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False) return tensor @@ -419,7 +419,7 @@ def slice_concat_by_axis(weight, fuse_tensor_parts, tensor_parallel_degree, tens return np.concatenate(splited[tensor_parallel_rank::tensor_parallel_degree], axis=axis) -def normal_fuse_merge_tp(weight_list, is_column=True): +def normal_fuse_merge_tp(weight_list, is_column=True, **kwargs): """ [A1],[A2] => [A1, A2] @@ -437,7 +437,7 @@ def normal_fuse_merge_tp(weight_list, is_column=True): return np.concatenate(weight_list, axis=-1) else: tensor = paddle.concat(weight_list, axis=-1) - if tensor.place.is_gpu_place(): + if tensor.place.is_gpu_place() and not kwargs.get("keep_on_gpu", False): tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False) return tensor else: @@ -445,7 +445,7 @@ def normal_fuse_merge_tp(weight_list, is_column=True): return np.concatenate(weight_list, axis=0) else: tensor = paddle.concat(weight_list, axis=0) - if tensor.place.is_gpu_place(): + if tensor.place.is_gpu_place() and not kwargs.get("keep_on_gpu", False): tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False) return tensor @@ -720,21 +720,22 @@ def fn( is_old_qkv=False, is_naive_2fuse=False, is_naive_3fuse=False, + **kwargs, ): if x is None: return None if is_naive_2fuse: - return naive_fuse_merge_tp(x, is_column=is_column, 
fuse_tensor_parts=2) + return naive_fuse_merge_tp(x, is_column=is_column, fuse_tensor_parts=2, **kwargs) elif is_naive_3fuse: - return naive_fuse_merge_tp(x, is_column=is_column, fuse_tensor_parts=3) + return naive_fuse_merge_tp(x, is_column=is_column, fuse_tensor_parts=3, **kwargs) else: - x = normal_fuse_merge_tp(x, is_column=is_column) + x = normal_fuse_merge_tp(x, is_column=is_column, **kwargs) if is_old_qkv: assert is_column, "QKV tensor should be column parallel linear." assert num_attention_heads is not None, "is_old_qkv need num_attention_heads" - x = tensor_parallel_qkv_to_naive_merged_qkv(x, num_attention_heads) + x = tensor_parallel_qkv_to_naive_merged_qkv(x, num_attention_heads, **kwargs) if transpose: x = np.transpose(x, [1, 0]) @@ -744,7 +745,7 @@ def fn( def get_tensor_parallel_split_func(tensor_parallel_degree, tensor_parallel_rank, num_attention_heads=None): - def fn(x, is_column=True, transpose=False, is_old_qkv=False, is_naive_2fuse=False, is_naive_3fuse=False): + def fn(x, is_column=True, transpose=False, is_old_qkv=False, is_naive_2fuse=False, is_naive_3fuse=False, **kwargs): if x is None: return None if transpose: @@ -755,17 +756,17 @@ def fn(x, is_column=True, transpose=False, is_old_qkv=False, is_naive_2fuse=Fals if is_old_qkv: assert is_column, "QKV tensor should be column parallel linear." assert num_attention_heads is not None, "is_old_qkv need num_attention_heads" - x = naive_merged_qkv_to_tensor_parallel_qkv(x, num_attention_heads) + x = naive_merged_qkv_to_tensor_parallel_qkv(x, num_attention_heads, **kwargs) if is_naive_2fuse: return naive_fuse_split_tp( - x, tensor_parallel_degree, tensor_parallel_rank, is_column=is_column, fuse_tensor_parts=2 + x, tensor_parallel_degree, tensor_parallel_rank, is_column=is_column, fuse_tensor_parts=2, **kwargs ) if is_naive_3fuse: return naive_fuse_split_tp( - x, tensor_parallel_degree, tensor_parallel_rank, is_column=is_column, fuse_tensor_parts=3 + x, tensor_parallel_degree, tensor_parallel_rank, is_column=is_column, fuse_tensor_parts=3, **kwargs ) - return normal_fuse_split_tp(x, tensor_parallel_degree, tensor_parallel_rank, is_column=is_column) + return normal_fuse_split_tp(x, tensor_parallel_degree, tensor_parallel_rank, is_column=is_column, **kwargs) return fn diff --git a/tests/transformers/test_refined_recompute.py b/tests/transformers/test_refined_recompute.py deleted file mode 100644 index 4217c82791bf..000000000000 --- a/tests/transformers/test_refined_recompute.py +++ /dev/null @@ -1,628 +0,0 @@ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import os
-
-environment_variables = {
-    "NVIDIA_TF32_OVERRIDE": "0",
-    "FLAGS_embedding_deterministic": "1",
-    "FLAGS_cudnn_deterministic": "1",
-}
-for k, v in environment_variables.items():
-    os.environ[k] = v
-import unittest
-from typing import Optional, Tuple
-
-import paddle
-import paddle.device
-import paddle.nn as nn
-import paddle.nn.functional as F
-from paddle.distributed.fleet.recompute import recompute as original_recompute
-
-from paddlenlp.trainer.training_args import TrainingArguments
-from paddlenlp.transformers.refined_recompute import no_recompute as rr_no_recompute
-from paddlenlp.transformers.refined_recompute import recompute as rr_recompute
-from paddlenlp.utils.import_utils import is_paddle_cuda_available
-
-ACT2FN = {
-    "relu": F.relu,
-    "gelu": F.gelu,
-    "tanh": F.tanh,
-    "sigmoid": F.sigmoid,
-}
-dtype = paddle.float16
-
-
-class PyLayerMatmul(paddle.autograd.PyLayer):
-    @staticmethod
-    def forward(ctx, a, b):
-        ctx.save_for_backward(a, b)
-        return a @ b
-
-    @staticmethod
-    def backward(ctx, dy):
-        a, b = ctx.saved_tensor()
-        if hasattr(a, "main_grad"):
-            a.main_grad.add_(paddle.ones_like(a.main_grad))
-        if hasattr(b, "main_grad"):
-            b.main_grad.add_(paddle.ones_like(b.main_grad))
-        grad_a = paddle.matmul(dy, b, transpose_y=True)
-        grad_b = paddle.matmul(a, dy, transpose_x=True)
-        return grad_a, grad_b
-
-
-pylayer_matmul = PyLayerMatmul.apply
-
-
-class BertConfig:
-    def __init__(
-        self,
-        vocab_size: int = 30522,
-        hidden_size: int = 768,
-        num_hidden_layers: int = 4,
-        num_attention_heads: int = 12,
-        intermediate_size: int = 3072,
-        hidden_act: str = "gelu",
-        hidden_dropout_prob: float = 0.0,
-        attention_probs_dropout_prob: float = 0.0,
-        max_position_embeddings: int = 1024,
-        type_vocab_size: int = 2,
-        initializer_range: float = 0.2,
-        pad_token_id: int = 0,
-        pool_act: str = "tanh",
-        layer_norm_eps: float = 1e-12,
-        output_attentions: bool = False,
-        output_hidden_states: bool = False,
-        num_labels=2,
-        recompute=False,
-        use_rr_recompute=False,
-        recompute_use_reentrant=False,
-        **kwargs
-    ):
-        self.pad_token_id = pad_token_id
-        self.vocab_size = vocab_size
-        self.hidden_size = hidden_size
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
-        self.intermediate_size = intermediate_size
-        self.hidden_act = hidden_act
-        self.hidden_dropout_prob = hidden_dropout_prob
-        self.attention_probs_dropout_prob = attention_probs_dropout_prob
-        self.max_position_embeddings = max_position_embeddings
-        self.type_vocab_size = type_vocab_size
-        self.initializer_range = initializer_range
-        self.pool_act = pool_act
-        self.layer_norm_eps = layer_norm_eps
-        self.output_attentions = output_attentions
-        self.output_hidden_states = output_hidden_states
-        self.num_labels = num_labels
-        self.recompute = recompute
-        self.use_rr_recompute = use_rr_recompute
-        self.recompute_use_reentrant = recompute_use_reentrant
-
-
-class BertEmbeddings(nn.Layer):
-    """Construct the embeddings from word, position and token_type embeddings."""
-
-    def __init__(self, config):
-        super().__init__()
-        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
-        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
-        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
-
-        self.LayerNorm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps)
-        self.dropout = nn.Dropout(config.hidden_dropout_prob)
-
-        self.register_buffer(
-            "position_ids", paddle.arange(config.max_position_embeddings, dtype="int64").reshape((1, -1))
-        )
-
-    def forward(
-        self,
-        input_ids: Optional[paddle.Tensor] = None,
-        token_type_ids: Optional[paddle.Tensor] = None,
-        position_ids: Optional[paddle.Tensor] = None,
-    ) -> paddle.Tensor:
-        input_shape = input_ids.shape
-        seq_length = input_ids.shape[1]
-
-        if position_ids is None:
-            position_ids = self.position_ids[:, :seq_length]
-
-        if token_type_ids is None:
-            token_type_ids = paddle.zeros(input_shape, dtype=paddle.int64)
-
-        inputs_embeds = self.word_embeddings(input_ids)
-        token_type_embeddings = self.token_type_embeddings(token_type_ids)
-        position_embeddings = self.position_embeddings(position_ids)
-        embeddings = inputs_embeds + token_type_embeddings + position_embeddings
-        embeddings = self.LayerNorm(embeddings)
-        embeddings = self.dropout(embeddings)
-        return embeddings
-
-
-class BertSelfAttention(nn.Layer):
-    def __init__(self, config):
-        super().__init__()
-        self.config = config
-        self.num_attention_heads = config.num_attention_heads
-        self.attention_head_size = config.hidden_size // config.num_attention_heads
-        self.all_head_size = self.num_attention_heads * self.attention_head_size
-
-        self.query = nn.Linear(config.hidden_size, self.all_head_size)
-        self.key = nn.Linear(config.hidden_size, self.all_head_size)
-        self.value = nn.Linear(config.hidden_size, self.all_head_size)
-
-        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
-
-    def forward(
-        self,
-        hidden_states: paddle.Tensor,
-        attention_mask: Optional[paddle.Tensor] = None,
-        output_attentions: Optional[bool] = False,
-    ) -> Tuple[paddle.Tensor]:
-
-        reshape_fn = lambda x: x.reshape([0, 0, -1, self.attention_head_size])
-        # compute q,k,v
-        query_layer = reshape_fn(self.query(hidden_states))
-        key_layer = reshape_fn(self.key(hidden_states))
-        value_layer = reshape_fn(self.value(hidden_states))
-
-        context_layer = rr_no_recompute(
-            F.scaled_dot_product_attention,
-            query=query_layer,
-            key=key_layer,
-            value=value_layer,
-            is_causal=True,
-            enable=self.config.use_rr_recompute and self.config.recompute,
-        )
-
-        new_context_layer_shape = context_layer.shape[:-2] + [
-            self.all_head_size,
-        ]
-        context_layer = context_layer.reshape(new_context_layer_shape)
-
-        outputs = (context_layer, None) if output_attentions else (context_layer,)
-
-        return outputs
-
-
-class BertSelfOutput(nn.Layer):
-    def __init__(self, config):
-        super().__init__()
-        self.config = config
-        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
-        self.LayerNorm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps)
-        self.dropout = nn.Dropout(config.hidden_dropout_prob)
-
-    def forward(self, hidden_states: paddle.Tensor, input_tensor: paddle.Tensor) -> paddle.Tensor:
-        hidden_states = rr_no_recompute(
-            self.dense, hidden_states, enable=self.config.use_rr_recompute and self.config.recompute
-        )
-        hidden_states = self.dropout(hidden_states)
-
-        hidden_states = rr_no_recompute(
-            self.LayerNorm, hidden_states + input_tensor, enable=self.config.use_rr_recompute and self.config.recompute
-        )
-        return hidden_states
-
-
-class BertAttention(nn.Layer):
-    def __init__(self, config):
-        super().__init__()
-        self.config = config
-        self.self = BertSelfAttention(config)
-        self.output = BertSelfOutput(config)
-
-    def forward(
-        self,
-        hidden_states: paddle.Tensor,
-        attention_mask: Optional[paddle.Tensor] = None,
-        output_attentions: Optional[bool] = False,
-    ) -> Tuple[paddle.Tensor]:
-        self_outputs = self.self(
-            hidden_states,
-            attention_mask,
-            output_attentions,
-        )
-        attention_output = self.output(self_outputs[0], hidden_states)
-        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
-        return outputs
-
-
-class BertIntermediate(nn.Layer):
-    def __init__(self, config):
-        super().__init__()
-        self.config = config
-        self.dense = nn.Linear(config.hidden_size, config.intermediate_size, bias_attr=False)
-        self.dense.weight.main_grad = paddle.zeros_like(self.dense.weight).cast("float32")
-        if isinstance(config.hidden_act, str):
-            self.intermediate_act_fn = ACT2FN[config.hidden_act]
-        else:
-            self.intermediate_act_fn = config.hidden_act
-
-    def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
-        def pylayer_dense(hidden_states):
-            return pylayer_matmul(hidden_states, self.dense.weight)
-
-        hidden_states = rr_no_recompute(
-            pylayer_dense, hidden_states, enable=self.config.use_rr_recompute and self.config.recompute
-        )
-        hidden_states = self.intermediate_act_fn(hidden_states)
-
-        return hidden_states
-
-
-class BertOutput(nn.Layer):
-    def __init__(self, config):
-        super().__init__()
-        self.config = config
-        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
-        self.LayerNorm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps)
-        self.dropout = nn.Dropout(config.hidden_dropout_prob)
-
-    def forward(self, hidden_states: paddle.Tensor, input_tensor: paddle.Tensor) -> paddle.Tensor:
-        def custom_dense(hidden_states, weight, bias=None):
-            return F.linear(hidden_states, weight, bias)
-
-        bias = self.dense.bias * 1.1
-        hidden_states = rr_no_recompute(
-            custom_dense,
-            hidden_states,
-            weight=self.dense.weight,
-            bias=bias,
-            enable=self.config.use_rr_recompute and self.config.recompute,
-            keys_ignore_to_save=["bias"],
-        )
-        hidden_states = self.dropout(hidden_states)
-        hidden_states = self.LayerNorm(hidden_states + input_tensor)
-        return hidden_states
-
-
-class BertLayer(nn.Layer):
-    def __init__(self, config):
-        super().__init__()
-        self.config = config
-        self.seq_len_dim = 1
-        self.attention = BertAttention(config)
-        self.intermediate = BertIntermediate(config)
-        self.output = BertOutput(config)
-
-    def forward(
-        self,
-        hidden_states: paddle.Tensor,
-        attention_mask: Optional[paddle.Tensor] = None,
-        output_attentions: Optional[bool] = False,
-    ) -> Tuple[paddle.Tensor]:
-        # self attn
-        self_attention_outputs = self.attention(
-            hidden_states,
-            attention_mask,
-            output_attentions=output_attentions,
-        )
-        attention_output = self_attention_outputs[0]
-
-        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
-
-        # ffn
-        intermediate_output = self.intermediate(attention_output)
-        layer_output = self.output(intermediate_output, attention_output)
-
-        outputs = (layer_output,) + outputs
-
-        return outputs
-
-
-class BertEncoder(nn.Layer):
-    def __init__(self, config):
-        super().__init__()
-        self.config = config
-        self.layer = nn.LayerList([BertLayer(config) for _ in range(config.num_hidden_layers)])
-
-    def forward(
-        self,
-        hidden_states: paddle.Tensor,
-        attention_mask: Optional[paddle.Tensor] = None,
-        output_attentions: Optional[bool] = False,
-        output_hidden_states: Optional[bool] = False,
-    ) -> Tuple[paddle.Tensor]:
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attentions = () if output_attentions else None
-
-        for layer_module in self.layer:
-            # add hidden_states
-            if output_hidden_states:
-                all_hidden_states = all_hidden_states + (hidden_states,)
-
-            if self.training and self.config.recompute:
-                recompute_function = rr_recompute if self.config.use_rr_recompute else original_recompute
-                layer_outputs = recompute_function(
-                    layer_module,
-                    hidden_states,
-                    attention_mask,
-                    output_attentions,
-                    use_reentrant=self.config.recompute_use_reentrant,
-                )
-            else:
-                layer_outputs = layer_module(
-                    hidden_states,
-                    attention_mask,
-                    output_attentions,
-                )
-            hidden_states = layer_outputs[0]
-
-            # add self attn
-            if output_attentions:
-                all_self_attentions = all_self_attentions + (layer_outputs[1],)
-
-        if output_hidden_states:
-            all_hidden_states = all_hidden_states + (hidden_states,)
-
-        return tuple(
-            v
-            for v in [
-                hidden_states,
-                all_hidden_states,
-                all_self_attentions,
-            ]
-            if v is not None
-        )
-
-
-class BertPreTrainedModel(nn.Layer):
-    def _init_weights(self, module):
-        """Initialize the weights"""
-        pass
-
-
-class BertModel(BertPreTrainedModel):
-    def __init__(self, config):
-        super().__init__()
-        self.config = config
-        self.embeddings = BertEmbeddings(config)
-        self.encoder = BertEncoder(config)
-
-    def forward(
-        self,
-        input_ids: Optional[paddle.Tensor] = None,
-        attention_mask: Optional[paddle.Tensor] = None,
-        token_type_ids: Optional[paddle.Tensor] = None,
-        position_ids: Optional[paddle.Tensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-    ) -> Tuple[paddle.Tensor]:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-
-        if token_type_ids is None:
-            token_type_ids = paddle.zeros(input_ids.shape, dtype=paddle.int64)
-
-        embedding_output = self.embeddings(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            token_type_ids=token_type_ids,
-        )
-        encoder_outputs = self.encoder(
-            embedding_output,
-            attention_mask=attention_mask,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-        )
-        return encoder_outputs
-
-
-class BertRefinedRecomputeTest(unittest.TestCase):
-    def no_pp_fwd_bwd(
-        self,
-        recompute=False,
-        use_rr_recompute=False,
-        recompute_use_reentrant=False,
-        num_hidden_layers=4,
-        shape=[2, 64],
-    ):
-        paddle.set_default_dtype(dtype)
-        paddle.seed(42)
-        config = BertConfig(
-            num_hidden_layers=num_hidden_layers,
-            recompute=recompute,
-            use_rr_recompute=use_rr_recompute,
-            recompute_use_reentrant=recompute_use_reentrant,
-        )
-        model = BertModel(config)
-        model.train()
-        input_ids = paddle.randint(10, config.vocab_size, shape=shape)
-        gpu_mem_used_before = paddle.device.cuda.memory_allocated()
-        outputs = model(input_ids=input_ids)[0]
-        gpu_mem_used_after = paddle.device.cuda.memory_allocated()
-        outputs.sum().backward()
-
-        # div = 1024**3 # GB
-        div = 1  # KB
-        return (
-            model,
-            round((gpu_mem_used_after - gpu_mem_used_before) / div, 2),
-            round(paddle.device.cuda.max_memory_allocated() / div, 2),
-        )
-
-    @unittest.skipIf(not is_paddle_cuda_available(), "refined-recompute only support on gpu")
-    def test_refined_recompute(self):
-        raw_dtype = paddle.get_default_dtype()
-
-        model1, mem_usage_forward1, max_mem_usage_forward1 = self.no_pp_fwd_bwd(
-            recompute=True, use_rr_recompute=False
-        )  # with recompute
-        model2, mem_usage_forward2, max_mem_usage_forward2 = self.no_pp_fwd_bwd(
-            recompute=True, use_rr_recompute=True
-        )  # with rr recompute
-        model3, mem_usage_forward3, max_mem_usage_forward3 = self.no_pp_fwd_bwd(
-            recompute=False, use_rr_recompute=False
-        )  # without recompute
-
-        name_list = [n for n, _ in model1.named_parameters()]
-
-        for param1, param2, name in zip(model1.parameters(), model3.parameters(), name_list):
-            # test main grad
-            if "intermediate.dense.weight" in name:
-                self.assertTrue(param1.main_grad.sum().item() > 0)
-                self.assertTrue(param2.main_grad.sum().item() > 0)
-            self.assertTrue(paddle.equal_all(param1.grad.cast("float32"), param2.grad.cast("float32")))
-
-        for param1, param2, name in zip(model2.parameters(), model3.parameters(), name_list):
-            # test main grad
-            if "intermediate.dense.weight" in name:
-                self.assertTrue(param1.main_grad.sum().item() > 0)
-                self.assertTrue(param2.main_grad.sum().item() > 0)
-            self.assertTrue(paddle.equal_all(param1.grad.cast("float32"), param2.grad.cast("float32")))
-
-        # self.assertTrue(mem_usage_forward1 < mem_usage_forward2 < mem_usage_forward3)
-        # self.assertTrue(max_mem_usage_forward1 < max_mem_usage_forward2 < max_mem_usage_forward3)
-
-        del model1, model2, model3
-        paddle.device.cuda.empty_cache()
-        paddle.set_default_dtype(raw_dtype)
-
-    def pp_fwd_bwd(
-        self,
-        recompute=False,
-        use_rr_recompute=False,
-        recompute_use_reentrant=False,
-        num_iter=4,
-        shape=[2, 64],
-    ):
-        paddle.set_default_dtype(dtype)
-        paddle.seed(42)
-        config = BertConfig(
-            num_hidden_layers=1,
-            recompute=recompute,
-            use_rr_recompute=use_rr_recompute,
-            recompute_use_reentrant=recompute_use_reentrant,
-        )
-        layer = BertLayer(config)
-        layer.train()
-
-        x = paddle.randn([*shape, config.hidden_size])
-        x.stop_gradient = False
-        x_copy = x
-
-        if layer.training and config.recompute:
-            recompute_function = rr_recompute if config.use_rr_recompute else original_recompute
-            for _ in range(num_iter):
-                x = recompute_function(layer, x, use_reentrant=config.recompute_use_reentrant)[0]
-        else:
-            for _ in range(num_iter):
-                x = layer(x)[0]
-
-        x.sum().backward()
-
-        return x_copy.grad, layer
-
-    @unittest.skipIf(not is_paddle_cuda_available(), "refined-recompute-pp only support on gpu")
-    def test_refined_recompute_pp(self):
-        paddle.set_device("gpu")
-        raw_dtype = paddle.get_default_dtype()
-        grad1, layer1 = self.pp_fwd_bwd(recompute=True, use_rr_recompute=False)
-        grad2, layer2 = self.pp_fwd_bwd(recompute=True, use_rr_recompute=True)
-        grad3, layer3 = self.pp_fwd_bwd(recompute=False, use_rr_recompute=False)
-
-        name_list = [n for n, _ in layer1.named_parameters()]
-
-        for param1, param2, name in zip(layer1.parameters(), layer3.parameters(), name_list):
-            # test main grad
-            if "intermediate.dense.weight" in name:
-                self.assertTrue(param1.main_grad.sum().item() > 0)
-                self.assertTrue(param2.main_grad.sum().item() > 0)
-            self.assertTrue(paddle.equal_all(param1.grad.cast("float32"), param2.grad.cast("float32")))
-
-        self.assertTrue(paddle.equal_all(grad1.cast("float32"), grad3.cast("float32")))
-        for param1, param2, name in zip(layer2.parameters(), layer3.parameters(), name_list):
-            # test main grad
-            if "intermediate.dense.weight" in name:
-                self.assertTrue(param1.main_grad.sum().item() > 0)
-                self.assertTrue(param2.main_grad.sum().item() > 0)
-            self.assertTrue(paddle.equal_all(param1.grad.cast("float32"), param2.grad.cast("float32")))
-
-        self.assertTrue(paddle.equal_all(grad2.cast("float32"), grad3.cast("float32")))
-
-        del grad1, grad2, grad3
-        del layer1, layer2, layer3
-        paddle.device.cuda.empty_cache()
-        paddle.set_default_dtype(raw_dtype)
-
-
-class TestRefinedRecomputeModel(unittest.TestCase):
-    def setUp(self):
-        self.args = TrainingArguments(
-            output_dir="./",
-            do_train=True,
-            max_steps=100,
-            tensor_parallel_degree=1,
-            pipeline_parallel_degree=1,
-            refined_recompute="attention_column_ln:1,attention_row_ln:2,flash_attn:-1,mlp_column_ln:2,mlp_row_ln:-1",
-        )
-
-    @unittest.skipIf(not is_paddle_cuda_available(), "refined-recompute-pp only support on gpu")
-    def test_llama_refined_recompute(self):
-        paddle.set_device("gpu")
-        from paddlenlp.transformers.llama import LlamaConfig, LlamaModel
-
-        llama_model = "__internal_testing__/tiny-random-llama"
-        config = LlamaConfig.from_pretrained(llama_model)
-        config.recompute = True
-        config.recompute_granularity = "full"
-        config.recompute_use_reentrant = False
-        config.sequence_parallel = False
-        config.use_flash_attention = True
-        config.refined_recompute = self.args.refined_recompute
-        model = LlamaModel.from_config(config=config, dtype="bfloat16")
-        input_ids = paddle.randint(0, 100, shape=[1, 1024], dtype="int64")
-        output = model(input_ids)
-        output[0].mean().backward()
-
-    @unittest.skipIf(not is_paddle_cuda_available(), "refined-recompute-pp only support on gpu")
-    def test_qwen_refined_recompute(self):
-        paddle.set_device("gpu")
-        from paddlenlp.transformers.qwen import QWenConfig, QWenModel
-
-        llama_model = "__internal_testing__/tiny-random-qwen"
-        config = QWenConfig.from_pretrained(llama_model)
-        config.recompute = True
-        config.recompute_granularity = "full"
-        config.recompute_use_reentrant = False
-        config.sequence_parallel = False
-        config.use_flash_attention = True
-        config.refined_recompute = self.args.refined_recompute
-        config.seq_length = 1024
-        model = QWenModel.from_config(config=config, dtype="bfloat16")
-        input_ids = paddle.randint(0, 100, shape=[1, 1024], dtype="int64")
-        output = model(input_ids)
-        output[0].mean().backward()
-
-    @unittest.skipIf(not is_paddle_cuda_available(), "refined-recompute-pp only support on gpu")
-    def test_qwen2_refined_recompute(self):
-        paddle.set_device("gpu")
-        from paddlenlp.transformers.qwen2 import Qwen2Config, Qwen2Model
-
-        llama_model = "__internal_testing__/tiny-random-qwen2"
-        config = Qwen2Config.from_pretrained(llama_model)
-        config.recompute = True
-        config.recompute_granularity = "full"
-        config.recompute_use_reentrant = False
-        config.sequence_parallel = False
-        config.use_flash_attention = True
-        config.refined_recompute = self.args.refined_recompute
-        model = Qwen2Model.from_config(config=config, dtype="bfloat16")
-        input_ids = paddle.randint(0, 100, shape=[1, 1024], dtype="int64")
-        output = model(input_ids)
-        output[0].mean().backward()
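
Note for reviewers: a minimal usage sketch of the `**kwargs` pass-through added to `get_tensor_parallel_split_func` in the conversion_utils.py hunk above. The call signatures come from that hunk; the import path and the expected shard shape are assumptions, not part of this patch.

    # Hypothetical sketch (assumes paddlenlp.transformers.conversion_utils exposes the helper):
    # split a fused QKV weight for rank 0 of a 2-way tensor-parallel group. Any extra keyword
    # arguments passed to the returned fn are now forwarded to naive_fuse_split_tp via **kwargs.
    import numpy as np

    from paddlenlp.transformers.conversion_utils import get_tensor_parallel_split_func

    hidden = 64
    fused_qkv = np.random.rand(hidden, 3 * hidden).astype("float32")  # columns laid out as [q | k | v]

    split_fn = get_tensor_parallel_split_func(tensor_parallel_degree=2, tensor_parallel_rank=0)
    shard = split_fn(fused_qkv, is_column=True, is_naive_3fuse=True)

    # is_naive_3fuse routes to naive_fuse_split_tp(..., fuse_tensor_parts=3, **kwargs), so each
    # rank is expected to keep its column shard of q, k and v, i.e. (hidden, 3 * hidden // 2).
    print(shard.shape)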