diff --git a/paddlenlp/trainer/auto_trainer.py b/paddlenlp/trainer/auto_trainer.py
index acfea513a1e2..dfed602e0956 100644
--- a/paddlenlp/trainer/auto_trainer.py
+++ b/paddlenlp/trainer/auto_trainer.py
@@ -24,6 +24,7 @@
 import paddle.distributed.auto_parallel.intermediate.parallelize as parallelize
 import paddle.nn as nn
 from paddle.distributed import fleet
+from paddle.distributed.auto_parallel.pipelining.schedules import get_pp_schedule
 from paddle.profiler.utils import switch_job_schedule_profiler
 from tqdm.auto import tqdm
 
@@ -48,9 +49,7 @@
     ShardingOption,
     TrainOutput,
     _exec_mode_guard,
-    check_auto_parallel_pipeline_support,
     get_last_checkpoint,
-    get_pp_schedule,
     has_length,
     speed_metrics,
 )
@@ -81,7 +80,6 @@ def loss_func(loss, outputs):
                 kwargs.update({"criterion": loss_func})
         self.auto_dist_config = kwargs.pop("auto_dist_config", None)
         model = kwargs.get("model", None)
-        self.model_type = kwargs.pop("model_type", None)
         assert model is not None
         if kwargs.get("args", None) is not None and kwargs["args"].use_intermediate_api:
             if not parallelize.has_parallelized_model:
@@ -103,16 +101,19 @@ def loss_func(loss, outputs):
 
         self.global_mesh = fleet.auto.get_mesh()
         self.comm_group_in_pp = fleet.get_hybrid_communicate_group().get_pipe_parallel_group()
-        if self.args.pipeline_parallel_degree > 1 and check_auto_parallel_pipeline_support(self.model_type):
+        if self.args.pipeline_parallel_degree > 1:
             self.pp_schedule = get_pp_schedule(
                 model,
-                self.model_type,
-                self.args.n_microbatches,
+                self.args.gradient_accumulation_steps,
                 self.criterion,
                 self.args.pipeline_schedule_mode,
                 self.args.pipeline_parallel_degree,
                 self.comm_group_in_pp,
             )
+            self.args.per_device_train_batch_size = (
+                self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps
+            )
+            self.args.gradient_accumulation_steps = 1
         self._in_pir_mode = paddle.base.framework.get_flags("FLAGS_enable_pir_api")["FLAGS_enable_pir_api"]
 
     @classmethod
@@ -762,6 +763,7 @@ def compute_pipeline_loss(self, model, inputs, return_outputs=False):
 
         final_loss = None
         if len(losses) != 0:
+            losses = [loss[0] for loss in losses]
             final_loss = paddle.stack(losses).mean()
 
         return final_loss
@@ -770,16 +772,13 @@ def dynamic_auto_parallel_pipeline_training(
         self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]
     ) -> paddle.Tensor:
         assert self.args.pipeline_parallel_degree > 1, "pipeline_parallel_degree must be greater than 1."
-        assert check_auto_parallel_pipeline_support(
-            self.model_type
-        ), "dynamic auto_parallel pipeline only supports special models"
         with self.autocast_smart_context_manager():
             loss = self.compute_pipeline_loss(model, inputs)
 
         return loss
 
     def dynamic_training(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:
-        if self.args.pipeline_parallel_degree > 1 and check_auto_parallel_pipeline_support(self.model_type):
+        if self.args.pipeline_parallel_degree > 1:
             return self.dynamic_auto_parallel_pipeline_training(model, inputs)
         with self.autocast_smart_context_manager():
             loss = self.compute_loss(model, inputs)
diff --git a/paddlenlp/trainer/trainer_utils.py b/paddlenlp/trainer/trainer_utils.py
index d8d88d1cd4ad..edc36fd8a2f6 100644
--- a/paddlenlp/trainer/trainer_utils.py
+++ b/paddlenlp/trainer/trainer_utils.py
@@ -44,7 +44,6 @@
 from paddlenlp.ops import Topology
 
 from ..trainer.argparser import strtobool
-from ..transformers import get_gpt_pp_schedule, get_llama_pp_schedule
 from ..transformers.tokenizer_utils_base import BatchEncoding
 from ..utils.env import PREFIX_CHECKPOINT_DIR, _re_checkpoint  # noqa for compatibility
 from ..utils.fault_tolerance import PDC_DOWNLOAD_ERROR
@@ -1257,19 +1256,6 @@ def download_recovery_ckpt_from_pdc(recovery_checkpoint_path, timeout):
         )
 
 
-def check_auto_parallel_pipeline_support(model_type=None):
-    support_types = ["llama_pp", "gpt_pp"]
-    return model_type in support_types
-
-
-def get_pp_schedule(model, model_type, n_microbatches, loss_fn, mode, pp_degree, group):
-    assert check_auto_parallel_pipeline_support(model_type)
-    if model_type == "llama_pp":
-        return get_llama_pp_schedule(model, n_microbatches, loss_fn, mode, pp_degree, group)
-    elif model_type == "gpt_pp":
-        return get_gpt_pp_schedule(model, n_microbatches, loss_fn, mode, pp_degree, group)
-
-
 def parse_nccl_config_file(config_dir):
     json_file = Path(config_dir)
     if json_file.exists():
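
Note on the batch-size rewrite in __init__ (illustrative, not part of the patch): since the pipeline schedule now receives gradient_accumulation_steps as its micro-batch count, the change appears to fold the trainer-side accumulation loop into the schedule, feeding it one larger batch per optimizer step and disabling accumulation. A minimal arithmetic sketch of why this keeps the per-step sample count unchanged; the config values below are made up for illustration and this is not PaddleNLP code:

# Illustrative sketch only (hypothetical values): multiplying per_device_train_batch_size
# by gradient_accumulation_steps and then setting gradient_accumulation_steps to 1
# preserves the number of samples processed per optimizer step once the pipeline
# schedule performs the micro-batching itself.
per_device_train_batch_size = 2   # hypothetical config value
gradient_accumulation_steps = 4   # hypothetical config value

# Old behaviour: 4 accumulation steps of batch 2 -> 8 samples per optimizer step.
old_samples_per_step = per_device_train_batch_size * gradient_accumulation_steps

# New behaviour (mirrors the diff): one batch of 8 per step, split by the schedule
# into n_microbatches = 4 micro-batches; the trainer-side accumulation loop is disabled.
n_microbatches = gradient_accumulation_steps
per_device_train_batch_size *= gradient_accumulation_steps
gradient_accumulation_steps = 1
new_samples_per_step = per_device_train_batch_size * gradient_accumulation_steps

assert old_samples_per_step == new_samples_per_step == 8
assert per_device_train_batch_size // n_microbatches == 2  # micro-batch size unchanged

Each pipeline stage presumably still sees micro-batches of the original per-device size; only the place where the split happens moves from the trainer loop into the schedule built by get_pp_schedule.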