2 changes: 2 additions & 0 deletions paddleformers/trainer/__init__.py
@@ -75,6 +75,8 @@
"TrainerState",
"DEFAULT_PROGRESS_CALLBACK",
"TrainerCallback",
"StepFlexToken",
"FP8QuantWeightCallback",
],
"trainer_utils": [
"get_last_checkpoint",
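With StepFlexToken and FP8QuantWeightCallback added to the export list, both callbacks become importable from the trainer package alongside the existing ones. A one-line sketch of the intended import, assuming the lazy-import table maps them to trainer_callback as it does for TrainerCallback:

# Sketch: the new export entries are expected to make this import work.
from paddleformers.trainer import FP8QuantWeightCallback, StepFlexToken
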
43 changes: 23 additions & 20 deletions paddleformers/trainer/trainer.py
@@ -92,6 +92,10 @@
RowParallelQuantizationLinear,
)

try:
from ..quantization.quantization_linear import QuantizationLinear
except ImportError:
QuantizationLinear = None
try:
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
register_sequence_parallel_allreduce_hooks,
@@ -150,6 +154,7 @@
TrainerState,
)
from .trainer_utils import ( # set_hyrbid_parallel_seed,
PREFIX_CHECKPOINT_DIR,
EvalLoopOutput,
EvalPrediction,
IntervalStrategy,
@@ -201,6 +206,15 @@
DEFAULT_CALLBACKS = [DefaultFlowCallback]
DEFAULT_PROGRESS_CALLBACK = ProgressCallback

# Name of the files used for checkpointing
TRAINING_ARGS_NAME = "training_args.bin"
TRAINER_STATE_NAME = "trainer_state.json"

OPTIMIZER_NAME = "optimizer.pdopt"
SCHEDULER_NAME = "scheduler.pdparams"
SCALER_NAME = "scaler.pdparams"


if is_datasets_available():
import datasets

@@ -535,10 +549,7 @@ def _wrap_amp_model(self, args, model):
level=self.args.fp16_opt_level,
dtype=self.amp_dtype,
excluded_layers=[
QuantizationLinear,
ColumnParallelQuantizationLinear,
RowParallelQuantizationLinear,
QuantizationLoRABaseLinear,
QuantizationLinear
]
+ self._decorate_exclude_layers(model),
)
@@ -1773,6 +1784,7 @@ def get_train_dataloader(self):
batch_size=self.args.per_device_train_batch_size,
collate_fn=self.data_collator,
num_workers=self.args.dataloader_num_workers,
prefetch_factor=32,
**additional_configs,
)
else:
@@ -1785,6 +1797,7 @@ def get_train_dataloader(self):
batch_sampler=train_sampler,
collate_fn=self.data_collator,
num_workers=self.args.dataloader_num_workers,
prefetch_factor=32,
**additional_configs,
)

@@ -2210,10 +2223,7 @@ def _wrap_model(self, model, training=True):
level=self.args.fp16_opt_level,
dtype=self.amp_dtype,
excluded_layers=[
QuantizationLinear,
ColumnParallelQuantizationLinear,
RowParallelQuantizationLinear,
QuantizationLoRABaseLinear,
QuantizationLinear
]
+ self._decorate_exclude_layers(model),
)
@@ -2328,11 +2338,7 @@ def get_expected_keys(inputs, keys):
assert (
ShardingOption.SHARD_GRAD_OP in self.args.sharding or ShardingOption.SHARD_OP in self.args.sharding
), "Only support tensor parallel + sharding stage1/stage2 hybrid parallel now."
# NOTE: TensorParallel will be called in distributed_model when sharding stage1, so no need to call here
if ShardingOption.SHARD_GRAD_OP in self.args.sharding:
model = paddle.distributed.fleet.meta_parallel.TensorParallel(
model, hcg, strategy=fleet.fleet._user_defined_strategy
)
model = paddle.distributed.fleet.meta_parallel.TensorParallel(model, hcg, strategy=None)

if ShardingOption.SHARD_OP in self.args.sharding:
if self.args.amp_master_grad:
@@ -2775,7 +2781,7 @@ def _save_checkpoint(self, model, metrics=None):

# only save model state dict, ignore optimizer and scheduler
if not self.args.ignore_save_lr_and_optim:
optimizer_name = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)
saved_signal_path = os.path.join(output_dir, f"saved_signal_{dist.get_rank()}")

if self.args.unified_checkpoint and self.args.offload_optim:
@@ -2863,9 +2869,6 @@ def _save_checkpoint(self, model, metrics=None):
):
paddle.save(global_rank, os.path.join(signal_dir, f".master_weight.done.{global_rank}"))

if self.args.unified_checkpoint and (self.args.offload_optim or self.args.tensorwise_offload_optimizer):
self._offload_optimizer()

self.runtime_timer.stop()

# Maybe delete some older checkpoints.
@@ -3125,7 +3128,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
opt_state_dict = None
if self.args.should_load_sharding_stage1_model:
opt_state_dict = self.sharding_io.load_optimizer_state_with_reshard(
checkpoint, PADDLE_OPTIMIZER_NAME, self.model_wrapped
checkpoint, OPTIMIZER_NAME, self.model_wrapped
)
else:
use_unified_checkpoint = False
@@ -3137,7 +3140,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):

if not use_unified_checkpoint:
if self.args.data_parallel_rank == 0 or self.args.use_expert_parallel:
optimizer_name = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)
path = os.path.join(checkpoint, optimizer_name)
if os.path.isfile(path):
opt_state_dict = paddle.load(path)
@@ -3180,7 +3183,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
# Load in optimizer and scheduler states
self.optimizer.set_state_dict(opt_state_dict)
else:
optimizer_name = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)
raise ValueError(f"optimizer-state-dict not found, opt: {os.path.join(checkpoint, optimizer_name)}.")
gc.collect()
empty_device_cache()
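The guarded import added above leaves QuantizationLinear set to None whenever the quantization module cannot be imported, so the value later placed in excluded_layers may be None. A minimal sketch of that pattern with a hypothetical helper (build_excluded_layers is not part of this PR) that filters unavailable classes out of the exclusion list:

# Sketch only: mirrors the optional-import guard used in trainer.py.
try:
    from paddleformers.quantization.quantization_linear import QuantizationLinear
except ImportError:  # optional dependency; may be absent in minimal installs
    QuantizationLinear = None


def build_excluded_layers(extra_layers=()):
    """Hypothetical helper: drop classes that failed to import (None) before
    passing the list to paddle.amp.decorate(excluded_layers=...)."""
    candidates = [QuantizationLinear, *extra_layers]
    return [layer for layer in candidates if layer is not None]
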
66 changes: 66 additions & 0 deletions paddleformers/trainer/trainer_callback.py
@@ -20,12 +20,14 @@
"""
import dataclasses
import json
import os
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

import numpy as np
from tqdm.auto import tqdm

from paddleformers.transformers.moe_utils import offload, reload
from ..utils.log import logger
from .trainer_utils import IntervalStrategy, has_length
from .training_args import TrainingArguments
@@ -39,6 +41,8 @@
"ProgressCallback",
"PrinterCallback",
"EarlyStoppingCallback",
"StepFlexToken",
"FP8QuantWeightCallback",
]


@@ -608,3 +612,65 @@ def on_evaluate(self, args, state, control, metrics, **kwargs):
self.check_metric_value(args, state, control, metric_value)
if self.early_stopping_patience_counter >= self.early_stopping_patience:
control.should_training_stop = True


class StepFlexToken(TrainerCallback):
def on_step_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
model = kwargs.pop("model")
if hasattr(model, "step_flex_token"):
model.step_flex_token(state.global_step)


g_shard_bypass_dygraph_optimizer = int(os.environ.get("FLAGS_shard_bypass_dygraph_optimizer", 0))


def enable_in_dict_config(config, key):
"""enable_in_dict_config"""
return key in config and config[key]


skip_count = 0


class FP8QuantWeightCallback(TrainerCallback):
"""
FP8QuantWeightCallback
"""

def on_step_begin(self, args, state, control, **kwargs):
"""
Quantize the expert parameters to FP8 before each training step begins.
"""
model = kwargs["model"]
optimizer = kwargs["optimizer"]
global skip_count

if not g_shard_bypass_dygraph_optimizer or skip_count == 0:
model.fp8_quant_weight(True)
optimizer.clear_param_storage("moe_expert")
optimizer.clear_param_storage("rms_linear")
optimizer.clear_param_storage("memory_attn")
optimizer.clear_param_storage("attn_out_project")
optimizer.clear_param_storage("shared_expert")

self.moe_weights_name = []
for param in optimizer._inner_opt._parameter_list:
color = getattr(param, "color", -1)
if isinstance(color, dict) and color["color"] == "moe_expert":
self.moe_weights_name.append(param.name)

for name in self.moe_weights_name:
offload(optimizer._master_weights[name])

skip_count += 1

def on_optimizer_begin(self, args, state, control, **kwargs):
optimizer = kwargs["optimizer"]
for name in self.moe_weights_name:
reload(optimizer._master_weights[name])
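Both new callbacks are plain TrainerCallback subclasses, so they would be attached when the trainer is constructed. A minimal usage sketch, assuming the usual Trainer constructor; model, training_args, and train_dataset are placeholders, and the model is expected to expose step_flex_token and fp8_quant_weight:

# Sketch: register the new callbacks alongside the defaults.
from paddleformers.trainer import FP8QuantWeightCallback, StepFlexToken, Trainer

trainer = Trainer(
    model=model,                  # placeholder: an MoE model with step_flex_token/fp8_quant_weight
    args=training_args,           # placeholder: TrainingArguments
    train_dataset=train_dataset,  # placeholder: training dataset
    callbacks=[StepFlexToken(), FP8QuantWeightCallback()],
)
trainer.train()
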
54 changes: 22 additions & 32 deletions paddleformers/trainer/training_args.py
@@ -32,6 +32,7 @@
from paddle.distributed import fleet

from ..utils.env import PREFIX_CHECKPOINT_DIR
from ..utils.fault_tolerance import is_ft_env
from ..utils.log import logger
from ..utils.pdc_sdk import FLASH_DEVICE
from .trainer_utils import (
@@ -1397,12 +1398,7 @@ def is_segment_parallel_supported():
else:
order = ["dp", "sharding", "pp", "mp"]
if self.use_expert_parallel:
if self.moe_sharding_parallel_degree >= 1 and self.expert_parallel_degree > 1:
order.insert(-1, "ep")
sd_idx = order.index("sharding")
# if pp_first, the order = ["dp", "pp", "moe_sharding", "sharding", "sep", "ep", "mp"]
# if sharding_first, the order is ["dp", "moe_sharding", "sharding", "pp", "sep", "ep", "mp"]
order.insert(sd_idx, "moe_sharding")
order = order[1:-1] + ["dp", "mp"]
Contributor:
Why is this being deleted? Could removing it affect the original logic?

Author:
Without this change, it fails with the following error:

  File "/PaddleFormers/paddleformers/trainer/training_args.py", line 1561, in __post_init__
    self.add_moe_comm_group()
  File "/PaddleFormers/paddleformers/trainer/training_args.py", line 2071, in add_moe_comm_group
    sharding_parallel_groups = topo.get_comm_list("sharding")
  File "/py3.10/lib/python3.10/site-packages/paddle/distributed/fleet/base/topology.py", line 227, in get_comm_list
    assert axis_name in self._parallel_names
AssertionError


if is_segment_parallel_supported():
hybrid_configs = {
@@ -1545,6 +1541,12 @@ def is_segment_parallel_supported():
assert (
"split_param" in sharding_parallel_config
), "split_param should be set when enable_stage1_allgather_overlap."
use_casual_mask = os.getenv("USE_CASUAL_MASK", "False") == "True"
Contributor:
Same question as above.

assert use_casual_mask, "enable_stage1_allgather_overlap requires USE_CASUAL_MASK=True."
assert self.logging_steps > 1, (
"The logging_steps should be greater than 1 for enable_stage1_allgather_overlap, "
f"but got logging_steps={self.logging_steps}."
)

if "split_param" in sharding_parallel_config:
if ShardingOption.SHARD_OP not in self.sharding:
@@ -1556,6 +1558,9 @@
fleet.init(is_collective=True, strategy=strategy)
logger.info(strategy)

if self.expert_parallel_degree > 1:
self.add_moe_comm_group()
Author:
Removing this raises the following error:

 File "/lib/python3.10/site-packages/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py", line 79, in build_layer
    return self.layer_func(*self.inputs, **{**self.kwargs, **extra_kwargs})
  File "/PaddleFormers/paddleformers/transformers/deepseek_v2/modeling.py", line 2275, in __init__
    DeepseekV2MoE(
  File "/PaddleFormers/paddleformers/transformers/deepseek_v2/modeling.py", line 1018, in __init__
    super().__init__(
  File "/PaddleFormers/paddleformers/transformers/moe_layer.py", line 225, in __init__
    self.moe_group = dist.fleet.get_hybrid_communicate_group().expert_parallel_group
AttributeError: 'HybridCommunicateGroup' object has no attribute 'expert_parallel_group'. Did you mean: 'get_data_parallel_group'?


elif self.enable_auto_parallel:
self.tensor_parallel_degree = max(self.tensor_parallel_degree, 1)
self.sep_parallel_degree = max(self.sep_parallel_degree, 1)
@@ -1902,32 +1907,12 @@ def is_segment_parallel_supported():
self.refined_recompute = refined_recompute_dict

# process fault tolerance settings
pdc_zcc_init_step = os.getenv("PDC_FC_INIT_STEP")
if pdc_zcc_init_step is not None and int(pdc_zcc_init_step) > 0:
self.resume_from_checkpoint = os.path.join(FLASH_DEVICE, f"{PREFIX_CHECKPOINT_DIR}-{pdc_zcc_init_step}")
logger.warning(
f"PDC_FC_INIT_STEP {pdc_zcc_init_step} has been specified, automatically resume from FLASH_DEVICE: {self.resume_from_checkpoint}"
)
if self.flash_device_save_steps > 0:
assert (
self.enable_zero_cost_checkpoint
), "flash_device_save_steps should only be set in zero cost checkpoint save mode with flash device mounted."

if self.enable_zero_cost_checkpoint:
assert (
"enable_fuse_optimizer_states" in sharding_parallel_config
), "zero cost checkpoint must be used when enable_fuse_optimizer_states is enabled in sharding parallel config"

assert (
self.flash_device_save_steps % self.zcc_ema_interval == 0
), f"flash_device_save_steps[{self.flash_device_save_steps}] must be divisible by zcc_ema_interval[{self.zcc_ema_interval}]"
assert (
self.save_steps % self.zcc_ema_interval == 0
), f"save_steps[{self.save_steps}] must be divisible by zcc_ema_interval[{self.zcc_ema_interval}]"
if self.zcc_save_ema_coef is not None:
assert (
self.zcc_workers_num == 1
), "EMA function in zero cost checkpoint mode does not support zcc_workers_num > 1 for now."
if not is_ft_env():
if self.pdc_download_ckpt:
logger.warning(
"pdc_download_ckpt can only be set as true inside FT environment. Automatically disable it now."
)
self.pdc_download_ckpt = False
Contributor:
Same as above. This change comes from PaddleNLP; our code here is based on an older version and should stay consistent with the newer version in paddleformers. I suggest not deleting this for now and verifying first.


def _post_init_parallel_degree(self):
self.use_hybrid_parallel = False
@@ -1994,6 +1979,11 @@ def _post_init_parallel_degree(self):
logger.warning("sharding_parallel_degree=1 means no sharding, please set sharding to empty!")
self.sharding = []

if sharding_parallel_degree > 1:
assert (
sharding_parallel_degree % expert_parallel_degree == 0
), f"sharding_parallel_degree should be divided by expert_parallel_degree, current sharding_parallel_degree: {sharding_parallel_degree}, expert_parallel_degree: {expert_parallel_degree}."

self.data_parallel_degree = world_size // (
sharding_parallel_degree
* tensor_parallel_degree
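The assertion added above ties the sharding degree to the expert-parallel degree. A standalone sketch of the same divisibility check, with illustrative degree values that are not taken from this PR:

def validate_parallel_degrees(sharding_parallel_degree: int, expert_parallel_degree: int) -> None:
    """Sketch of the added check: expert-parallel groups must tile the sharding group evenly."""
    if sharding_parallel_degree > 1:
        assert sharding_parallel_degree % expert_parallel_degree == 0, (
            f"sharding_parallel_degree ({sharding_parallel_degree}) must be divisible by "
            f"expert_parallel_degree ({expert_parallel_degree})."
        )


validate_parallel_degrees(8, 4)   # OK: 8 % 4 == 0
# validate_parallel_degrees(8, 3) would raise AssertionError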