[Auto-parallel] Improve usability of auto_dy pipeline parallel #10920

Open · wants to merge 2 commits into develop
paddlenlp/trainer/auto_trainer.py (9 additions, 10 deletions)

@@ -24,6 +24,7 @@
 import paddle.distributed.auto_parallel.intermediate.parallelize as parallelize
 import paddle.nn as nn
 from paddle.distributed import fleet
+from paddle.distributed.auto_parallel.pipelining.schedules import get_pp_schedule
 from paddle.profiler.utils import switch_job_schedule_profiler
 from tqdm.auto import tqdm

@@ -48,9 +49,7 @@
     ShardingOption,
     TrainOutput,
     _exec_mode_guard,
-    check_auto_parallel_pipeline_support,
     get_last_checkpoint,
-    get_pp_schedule,
     has_length,
     speed_metrics,
 )
@@ -81,7 +80,6 @@ def loss_func(loss, outputs):
         kwargs.update({"criterion": loss_func})
         self.auto_dist_config = kwargs.pop("auto_dist_config", None)
         model = kwargs.get("model", None)
-        self.model_type = kwargs.pop("model_type", None)
         assert model is not None
         if kwargs.get("args", None) is not None and kwargs["args"].use_intermediate_api:
             if not parallelize.has_parallelized_model:
@@ -103,16 +101,19 @@ def loss_func(loss, outputs):

         self.global_mesh = fleet.auto.get_mesh()
         self.comm_group_in_pp = fleet.get_hybrid_communicate_group().get_pipe_parallel_group()
-        if self.args.pipeline_parallel_degree > 1 and check_auto_parallel_pipeline_support(self.model_type):
+        if self.args.pipeline_parallel_degree > 1:
             self.pp_schedule = get_pp_schedule(
                 model,
-                self.model_type,
-                self.args.n_microbatches,
+                self.args.gradient_accumulation_steps,
                 self.criterion,
                 self.args.pipeline_schedule_mode,
                 self.args.pipeline_parallel_degree,
                 self.comm_group_in_pp,
             )
+            self.args.per_device_train_batch_size = (
+                self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps
+            )
+            self.args.gradient_accumulation_steps = 1
         self._in_pir_mode = paddle.base.framework.get_flags("FLAGS_enable_pir_api")["FLAGS_enable_pir_api"]

     @classmethod
@@ -762,6 +763,7 @@ def compute_pipeline_loss(self, model, inputs, return_outputs=False):

         final_loss = None
         if len(losses) != 0:
+            losses = [loss[0] for loss in losses]
             final_loss = paddle.stack(losses).mean()

         return final_loss
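
The added line in compute_pipeline_loss above indexes loss[0], which suggests each pipeline step now reports its micro-batch loss wrapped in a container. A minimal sketch of the resulting reduction with illustrative values (the wrapping is an assumption inferred from the diff, not confirmed elsewhere in the PR):

import paddle

# Pretend the schedule produced one wrapped loss per micro-batch.
losses = [[paddle.to_tensor(0.9)], [paddle.to_tensor(1.1)], [paddle.to_tensor(1.0)]]

final_loss = None
if len(losses) != 0:
    losses = [loss[0] for loss in losses]      # unwrap each micro-batch entry
    final_loss = paddle.stack(losses).mean()   # average across micro-batches

print(float(final_loss))  # 1.0
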
@@ -770,16 +772,13 @@ def dynamic_auto_parallel_pipeline_training(
         self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]
     ) -> paddle.Tensor:
         assert self.args.pipeline_parallel_degree > 1, "pipeline_parallel_degree must be greater than 1."
-        assert check_auto_parallel_pipeline_support(
-            self.model_type
-        ), "dynamic auto_parallel pipeline only supports special models"
         with self.autocast_smart_context_manager():
             loss = self.compute_pipeline_loss(model, inputs)

         return loss

     def dynamic_training(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:
-        if self.args.pipeline_parallel_degree > 1 and check_auto_parallel_pipeline_support(self.model_type):
+        if self.args.pipeline_parallel_degree > 1:
             return self.dynamic_auto_parallel_pipeline_training(model, inputs)
         with self.autocast_smart_context_manager():
             loss = self.compute_loss(model, inputs)
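
Taken together, the auto_trainer.py hunks drop the model_type gate and build the schedule for any model through Paddle's generic get_pp_schedule, while gradient accumulation is folded into the per-device batch so the schedule drives the micro-batching. Below is a minimal sketch of the resulting setup; build_pp_schedule and its arguments are illustrative names, the get_pp_schedule signature is assumed to be exactly what the diff shows, and the code only runs inside an already-initialized hybrid-parallel (fleet) job.

from paddle.distributed import fleet
from paddle.distributed.auto_parallel.pipelining.schedules import get_pp_schedule


def build_pp_schedule(model, criterion, args):
    # Pipeline communication group, obtained as in the diff above.
    comm_group_in_pp = fleet.get_hybrid_communicate_group().get_pipe_parallel_group()

    # gradient_accumulation_steps is passed where the old n_microbatches argument
    # sat, so it now determines the number of pipeline micro-batches.
    schedule = get_pp_schedule(
        model,
        args.gradient_accumulation_steps,
        criterion,
        args.pipeline_schedule_mode,
        args.pipeline_parallel_degree,
        comm_group_in_pp,
    )

    # Fold accumulation into the per-device batch and disable the outer
    # accumulation loop; the schedule consumes the enlarged batch instead.
    args.per_device_train_batch_size *= args.gradient_accumulation_steps
    args.gradient_accumulation_steps = 1
    return schedule
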
paddlenlp/trainer/trainer_utils.py (0 additions, 14 deletions)

@@ -44,7 +44,6 @@
 from paddlenlp.ops import Topology

 from ..trainer.argparser import strtobool
-from ..transformers import get_gpt_pp_schedule, get_llama_pp_schedule
 from ..transformers.tokenizer_utils_base import BatchEncoding
 from ..utils.env import PREFIX_CHECKPOINT_DIR, _re_checkpoint  # noqa for compatibility
 from ..utils.fault_tolerance import PDC_DOWNLOAD_ERROR
@@ -1257,19 +1256,6 @@ def download_recovery_ckpt_from_pdc(recovery_checkpoint_path, timeout):
     )


-def check_auto_parallel_pipeline_support(model_type=None):
-    support_types = ["llama_pp", "gpt_pp"]
-    return model_type in support_types
-
-
-def get_pp_schedule(model, model_type, n_microbatches, loss_fn, mode, pp_degree, group):
-    assert check_auto_parallel_pipeline_support(model_type)
-    if model_type == "llama_pp":
-        return get_llama_pp_schedule(model, n_microbatches, loss_fn, mode, pp_degree, group)
-    elif model_type == "gpt_pp":
-        return get_gpt_pp_schedule(model, n_microbatches, loss_fn, mode, pp_degree, group)
-
-
 def parse_nccl_config_file(config_dir):
     json_file = Path(config_dir)
     if json_file.exists():