106 changes: 74 additions & 32 deletions paddleformers/trainer/trainer.py
@@ -166,6 +166,7 @@
get_last_checkpoint,
get_scheduler,
has_length,
init_optimizer,
set_seed,
should_skip_data,
speed_metrics,
@@ -939,7 +940,7 @@ def train(
self._memory_tracker.start()

if not self.args.enable_auto_parallel:
if not self.args.should_load_sharding_stage1_model:
if not self.args.should_load_sharding_stage1_model and not self.args.using_flex_checkpoint:
self._load_from_checkpoint(resume_from_checkpoint)

if self.args.should_load_sharding_stage1_model:
@@ -959,14 +960,31 @@
if delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
self._load_optimizer_and_scheduler(resume_from_checkpoint)
else:
elif not self.args.using_flex_checkpoint:
model = self._wrap_model(self.model_wrapped)
# for the rest of this function `model` is the outside model, whether it was wrapped or not
if model is not self.model:
self.model_wrapped = model
if delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
self._load_optimizer_and_scheduler(resume_from_checkpoint)
else:
assert self.args.using_flex_checkpoint, "using_flex_checkpoint should be True"
model = self._wrap_model(self.model_wrapped)
if model is not self.model:
self.model_wrapped = model

if delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)

if resume_from_checkpoint is not None:
model_sharded_state_dict = self.model.sharded_state_dict()
self.optimizer.sharded_state_dict(model_sharded_state_dict)
init_optimizer(self.optimizer)
optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict}
dist.load_state_dict(sharded_state_dict, resume_from_checkpoint)
self._load_scheduler(resume_from_checkpoint)
else:
model = self.model_wrapped
if delay_optimizer_creation:
@@ -2735,6 +2753,10 @@ def _save_checkpoint(self, model, metrics=None):
else:
self.save_model(output_dir)

model_sharded_state_dict = self.model.sharded_state_dict()
if self.args.using_flex_checkpoint:
os.makedirs(output_dir, exist_ok=True)

# Determine the new best metric / best model checkpoint
if metrics is not None and self.args.metric_for_best_model is not None:
metric_to_check = self.args.metric_for_best_model
@@ -2796,26 +2818,34 @@
self.optimizer,
output_dir,
signal_dir,
self.args.optim_shard_num,
)
else:
if self.dp_group.rank > 0: # this should only work for MoE saving
self._save_ckpt_func(
self._filter_moe_no_sync_optimizer_params(),
os.path.join(output_dir, optimizer_name),
saved_signal_path,
)

else:
state_dict = self.optimizer.state_dict()
save_path = os.path.join(output_dir, optimizer_name)
if self.args.use_async_save:
assert not strtobool(os.getenv("FLAG_LLM_PDC", "False")), "Dont support FLAG_LLM_PDC"
self._async_optimizer_saver.run(
state_dict, save_path, saved_signal_path=saved_signal_path
if not self.args.using_flex_checkpoint:
if self.dp_group.rank > 0: # this should only work for MoE saving
self._save_ckpt_func(
self._filter_moe_no_sync_optimizer_params(),
os.path.join(output_dir, optimizer_name),
saved_signal_path,
)

else:
self._save_ckpt_func(state_dict, save_path, saved_signal_path)
state_dict = self.optimizer.state_dict()
save_path = os.path.join(output_dir, optimizer_name)
if self.args.use_async_save:
assert not strtobool(
os.getenv("FLAG_LLM_PDC", "False")
), "Dont support FLAG_LLM_PDC"
self._async_optimizer_saver.run(
state_dict, save_path, saved_signal_path=saved_signal_path
)
else:
self._save_ckpt_func(state_dict, save_path, saved_signal_path)
else:
optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
dist.save_state_dict(
{**model_sharded_state_dict, **optimizer_sharded_state_dict},
output_dir,
)
else:
if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
@@ -2835,9 +2865,8 @@
self.optimizer,
output_dir,
signal_dir,
self.args.optim_shard_num,
)
else:
elif not self.args.using_flex_checkpoint:
if self.args.data_parallel_rank > 0 and self.args.use_expert_parallel:
self._save_ckpt_func(
self._filter_moe_no_sync_optimizer_params(),
@@ -2851,6 +2880,12 @@
saved_signal_path,
)

else:
optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
dist.save_state_dict(
{**model_sharded_state_dict, **optimizer_sharded_state_dict},
output_dir,
)
# FIXME: maybe only save one copy
paddle.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))

@@ -3122,6 +3157,24 @@ def _save(
with open(path, "w") as f:
json.dump(model_meta, f)

def _load_scheduler(self, checkpoint):
if checkpoint is None:
self.runtime_timer.stop()
return

if not self.args.ignore_load_lr_and_optim:
if distributed_isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
self.lr_scheduler.set_state_dict(
paddle.load(distributed_file(os.path.join(checkpoint, SCHEDULER_NAME)))
)
else:
raise ValueError(f"scheduler-file not found, scheduler:{os.path.join(checkpoint, SCHEDULER_NAME)}")

if self.do_grad_scaling and distributed_isfile(os.path.join(checkpoint, SCALER_NAME)):
self.scaler.load_state_dict(
paddle.load(distributed_file(os.path.join(checkpoint, SCALER_NAME)), return_numpy=True)
)

def _load_optimizer_and_scheduler(self, checkpoint):
"""If optimizer and scheduler states exist, load them."""
self.runtime_timer.start("checkpoint loading time")
@@ -3197,18 +3250,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
gc.collect()
empty_device_cache()

if not self.args.ignore_load_lr_and_optim:
if distributed_isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
self.lr_scheduler.set_state_dict(
paddle.load(distributed_file(os.path.join(checkpoint, SCHEDULER_NAME)))
)
else:
raise ValueError(f"scheduler-file not found, scheduler:{os.path.join(checkpoint, SCHEDULER_NAME)}")

if self.do_grad_scaling and distributed_isfile(os.path.join(checkpoint, SCALER_NAME)):
self.scaler.load_state_dict(
paddle.load(distributed_file(os.path.join(checkpoint, SCALER_NAME)), return_numpy=True)
)
self._load_scheduler(checkpoint)

if self.args.offload_optim:
logger.info("Offloading optimizer state...")
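Taken together, the trainer changes above define the FlexCheckpoint resume path: build the model's sharded state dict, let the optimizer register it, materialize the optimizer states with init_optimizer, merge both sharded views, and hand the result to dist.load_state_dict before restoring the scheduler. A condensed sketch of that flow, assuming dist is the paddle.distributed alias already used by trainer.py and that the trainer object exposes the same attributes as in this diff:

import paddle.distributed as dist

from paddleformers.trainer.trainer_utils import init_optimizer  # helper added in this PR


def resume_with_flex_checkpoint(trainer, resume_from_checkpoint):
    # Mirrors the using_flex_checkpoint branch added to Trainer.train().
    model_sharded_state_dict = trainer.model.sharded_state_dict()
    # Let the optimizer see the model shards, then create its accumulators so
    # the checkpoint has concrete tensors to load into.
    trainer.optimizer.sharded_state_dict(model_sharded_state_dict)
    init_optimizer(trainer.optimizer)
    optimizer_sharded_state_dict = trainer.optimizer.sharded_state_dict(model_sharded_state_dict)
    sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict}
    dist.load_state_dict(sharded_state_dict, resume_from_checkpoint)
    trainer._load_scheduler(resume_from_checkpoint)
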
73 changes: 73 additions & 0 deletions paddleformers/trainer/trainer_utils.py
@@ -53,6 +53,20 @@
from ..utils.tools import get_env_device
from .utils.helper import distributed_file

try:
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
DygraphShardingOptimizerV2,
)
except ImportError:
DygraphShardingOptimizerV2 = None

try:
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
DygraphShardingOptimizer,
)
except ImportError:
DygraphShardingOptimizer = None

__all__ = [
"TrainOutput",
"PredictionOutput",
@@ -1283,3 +1297,62 @@ def _insert_sync(self, sync_var, src, mp_group, sync_mode):
# Move it back to pin memory
if original_device == "pin_memory":
sync_var = paddle.to_tensor(sync_var, place=paddle.CUDAPinnedPlace())


def init_optimizer(optimizer):
"""
Initialize the optimizer's states according to its type.

For DygraphShardingOptimizer (V1), initializes accumulators for local parameters.
For DygraphShardingOptimizerV2, manually initializes master weights and state dict for sharded parameters.
For other cases, initializes accumulators for all parameters.

Args:
optimizer: The optimizer instance to be initialized.
"""
if DygraphShardingOptimizer is not None and isinstance(optimizer._inner_opt, DygraphShardingOptimizer):
local_params = optimizer._rank2params[optimizer._sharding_rank]
optimizer._create_accumulators(paddle.base.framework.default_main_program().global_block(), local_params)
return

elif DygraphShardingOptimizerV2 is not None and isinstance(optimizer._inner_opt, DygraphShardingOptimizerV2):

def init_param_optimizer_states(param_iter):
master_weights = {}
state_dict = {}
moments = ("moment1_0", "moment2_0")
betas = ("beta1_pow_acc_0", "beta2_pow_acc_0")
for static_name, shape, no_need_master_weights in param_iter:
if not no_need_master_weights:
master_weights[static_name] = paddle.zeros(shape, dtype="float32")
prefix = f"{static_name}_fp32_master_0_"
else:
prefix = f"{static_name}_"

for moment in moments:
key = f"{prefix}{moment}"
state_dict[key] = paddle.zeros(shape, dtype="float32")
for beta in betas:
key = f"{prefix}{beta}"
state_dict[key] = paddle.zeros((1,), dtype="float32")
return master_weights, state_dict

def buffer_params():
for buffer in optimizer._comm_buffer_list:
for param_name, grad_view in buffer._sharding_param_grad_view.items():
param_begin = grad_view._param_begin
param_end = grad_view._param_end
shape = (param_end - param_begin,)
no_need_master_weights = grad_view._param.dtype == paddle.float32
if shape[0] > 0:
yield param_name, shape, no_need_master_weights

master_weights, state_dict = init_param_optimizer_states(buffer_params())
state_dict["master_weights"] = master_weights
state_dict["LR_Scheduler"] = {"last_epoch": 1, "last_lr": 5e-06}

optimizer.set_state_dict(state_dict)
return
optimizer._create_accumulators(
paddle.base.framework.default_main_program().global_block(), optimizer._parameter_list
)
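
For a non-float32 parameter shard, the keys produced by init_param_optimizer_states follow the _fp32_master_0_ prefix scheme above: one float32 master weight plus the four Adam accumulators. A hypothetical illustration of the resulting layout for a single 1024-element shard named linear_0.w_0 (name and size are illustrative only):

import paddle

# Hypothetical state produced for one bf16/fp16 parameter shard.
example_state_dict = {
    "master_weights": {"linear_0.w_0": paddle.zeros((1024,), dtype="float32")},
    "linear_0.w_0_fp32_master_0_moment1_0": paddle.zeros((1024,), dtype="float32"),
    "linear_0.w_0_fp32_master_0_moment2_0": paddle.zeros((1024,), dtype="float32"),
    "linear_0.w_0_fp32_master_0_beta1_pow_acc_0": paddle.zeros((1,), dtype="float32"),
    "linear_0.w_0_fp32_master_0_beta2_pow_acc_0": paddle.zeros((1,), dtype="float32"),
    "LR_Scheduler": {"last_epoch": 1, "last_lr": 5e-06},
}
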
14 changes: 14 additions & 0 deletions paddleformers/trainer/training_args.py
@@ -401,6 +401,10 @@ class TrainingArguments:
Whether to release gradients during training. Default is `False`.
ckpt_quant_stage (`str`, *optional*):
Whether activate checkpoint quantization. O0: deactivate, O1: Int8 compression, O2: Int4 compression. (default: O0).
using_flex_checkpoint(`bool`, *optional*):
Whether to use FlexCheckpoint for save and load. Default is False.
aoa_config (`Optional[dict[str, list[str]]]`, *optional*):
The AoA configuration of FlexCheckpoint, used to describe the mapping between model weights and the checkpoint content. Default is None.
"""

output_dir: str = field(
@@ -1080,6 +1084,16 @@ class TrainingArguments:
default=False,
metadata={"help": "是否开启单路sharding时global norm通信拆分全局通信组为pp通信和mp通信分别做"},
)
using_flex_checkpoint: Optional[bool] = field(
default=False,
metadata={"help": "Whether use FlexCheckpoint."},
)
aoa_config: Optional[dict[str, list[str]]] = field(
default=None,
metadata={
"help": "The AoA configuration of FlexCheckpoint, used to describe the mapping between model weights and the checkpoint content. Default is None."
},
)
convert_from_hf: Optional[bool] = field(
default=False,
metadata={"help": "Load model from HuggingFace safetensors."},
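A minimal sketch of enabling the two new fields programmatically; the import path follows this repository's layout and every other argument keeps its default:

from paddleformers.trainer.training_args import TrainingArguments

# Save/load checkpoints through FlexCheckpoint (dist.save_state_dict /
# dist.load_state_dict); aoa_config stays None unless a custom mapping between
# model weights and checkpoint content is needed.
args = TrainingArguments(
    output_dir="./checkpoints",
    using_flex_checkpoint=True,
    aoa_config=None,
)
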
11 changes: 11 additions & 0 deletions paddleformers/transformers/llama/modeling.py
@@ -30,6 +30,9 @@
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
from paddle.distributed.fleet.recompute.recompute import recompute
from paddle.distributed.flex_checkpoint.dcp.sharded_weight import (
build_sharded_state_dict,
)

from ..refined_recompute import (
RRColumnParallelLinear,
@@ -1988,6 +1991,14 @@ def forward(self, hidden_states, tensor_parallel_output=None):
)
return logits

def sharded_state_dict(
self,
structured_name_prefix: str = "",
):
axis = 0 if self.transpose_y else 1
state_dict = self.state_dict(structured_name_prefix="")
return build_sharded_state_dict(state_dict, {"weight": axis}, structured_name_prefix)


class LlamaForCausalLM(LlamaPretrainedModel):
enable_to_static_method = True
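Other tensor-parallel layers can expose their sharding metadata through the same helper. A sketch of the call pattern, using a hypothetical layer whose weight is assumed to be split along axis 0 (only the helper and its argument order are taken from the method above):

import paddle
from paddle.distributed.flex_checkpoint.dcp.sharded_weight import (
    build_sharded_state_dict,
)


class ShardedLinear(paddle.nn.Linear):
    """Hypothetical layer whose weight is split along axis 0 across TP ranks."""

    def sharded_state_dict(self, structured_name_prefix: str = ""):
        state_dict = self.state_dict(structured_name_prefix="")
        # Mark which axis of "weight" is sharded so FlexCheckpoint can reshard it.
        return build_sharded_state_dict(state_dict, {"weight": 0}, structured_name_prefix)
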
13 changes: 13 additions & 0 deletions paddleformers/transformers/model_utils.py
@@ -3199,6 +3199,19 @@ def set_state_dict(self, state_dict, *args, **kwargs):
ret = super().set_state_dict(state_dict, *args, **kwargs)
return ret

def sharded_state_dict(self, *args, **kwargs):
sharded_state_dict = super().sharded_state_dict(*args, **kwargs)
if self._single_to_pp_mapping is None:
self._set_pipeline_name_mapping()
assert len(self._single_to_pp_mapping) > 0, "The pipeline stage must have parameters!"

for k in list(sharded_state_dict.keys()):
v = sharded_state_dict.pop(k)
v.tensor_key = self._pp_to_single_mapping[k]
sharded_state_dict[self._pp_to_single_mapping[k]] = v

return sharded_state_dict


def load_sharded_checkpoint_as_one(folder, variant=None, return_numpy=False):
"""
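The override above only renames entries: each key produced under pipeline-parallel naming is re-keyed to its single-model name and the entry's tensor_key is updated to match. The renaming step in isolation, with a plain stand-in for the sharded-weight objects and a hypothetical name mapping:

class _ShardedEntry:
    # Stand-in for the sharded-weight values; only tensor_key matters here.
    def __init__(self, tensor_key):
        self.tensor_key = tensor_key


# Hypothetical pipeline-parallel name -> single-model name mapping.
pp_to_single = {"_layers.0.embed_tokens.weight": "llama.embed_tokens.weight"}

sharded = {k: _ShardedEntry(k) for k in pp_to_single}
for k in list(sharded.keys()):
    v = sharded.pop(k)
    v.tensor_key = pp_to_single[k]
    sharded[pp_to_single[k]] = v

assert list(sharded) == ["llama.embed_tokens.weight"]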