223 changes: 197 additions & 26 deletions paddlenlp/trainer/trainer.py
@@ -34,6 +34,7 @@
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from copy import deepcopy
import numpy as np
import paddle
import paddle.amp.auto_cast as autocast
@@ -54,6 +55,9 @@
except:
core = None
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
DygraphShardingOptimizerV2,
)
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.hybrid_parallel_optimizer import (
HybridParallelOptimizer,
)
@@ -102,6 +106,8 @@
except:
pass

from paddle.distributed.flex_checkpoint.dcp.sharded_weight import ShardedWeight

from ..transformers.context_parallel_utils import split_inputs_sequence_dim_load_balance
from ..transformers.model_utils import (
PretrainedModel,
@@ -226,6 +232,11 @@ def in_auto_parallel_align_mode():
return False


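# Sub-directory names used by the flex checkpoint layout for model weights,
# optimizer states, and master weights.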
MODEL_STATE_DIC = "model_state"
OPTIMIZER_STATE_DIC = "optimizer_state"
MASTER_WEIGHT_DIC = "master_weight"


__all__ = ["Trainer"]


@@ -842,6 +853,140 @@ def create_zcc_manager(self, unwrapped_model, resume_from_checkpoint=None):

logger.info("Create zero cost checkpoint manager done.")

def _load_flex_checkpoint(self, resume_from_checkpoint):
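# Restore model weights, optimizer states, and master weights from a flex
# checkpoint directory laid out as model_state/, optimizer_state/ and master_weight/.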
model_sharded_state_dict = self.model.sharded_state_dict()
master_weights_path = os.path.join(resume_from_checkpoint, MASTER_WEIGHT_DIC)
opt_states_path = os.path.join(resume_from_checkpoint, OPTIMIZER_STATE_DIC)
model_states_path = os.path.join(resume_from_checkpoint, MODEL_STATE_DIC)
if not self.args.ignore_load_lr_and_optim:
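# Collect the sharding metadata stored alongside each of the three state sub-directories.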
state_dict_metadata = {}
metadata_paths = [
os.path.join(model_states_path, "0.metadata"),
os.path.join(opt_states_path, "0.metadata"),
os.path.join(master_weights_path, "0.metadata"),
]

for metadata_file in metadata_paths:
if not os.path.exists(metadata_file):
raise FileNotFoundError(f"Metadata file not found: {metadata_file}")
metadata = paddle.load(metadata_file)
if hasattr(metadata, "state_dict_metadata"):
state_dict_metadata.update(metadata.state_dict_metadata)
else:
raise AttributeError(
f"Loaded metadata from {metadata_file} does not have 'state_dict_metadata' attribute"
)

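# Build the optimizer state tensors described by the checkpoint metadata.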
init_optimizer(self.optimizer, model_sharded_state_dict, state_dict_metadata)

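# Take a sharded view of the optimizer state and release its device storage;
# the checkpointed values are staged through pinned host memory below.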
optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
for k, v in optimizer_sharded_state_dict.items():
v.local_tensor._clear_to_zero_allocation()

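# Free the parameter storage as well: comm buffers under sharding V2,
# otherwise the model weights themselves.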
if isinstance(self.optimizer.inner_opt, DygraphShardingOptimizerV2):
color_to_comm_buffer_list = self.optimizer._color_to_comm_buffer_list
for color, _comm_buffer_list in color_to_comm_buffer_list.items():
for comm_buffer in _comm_buffer_list:
comm_buffer._clear_param_storage()
else:
state_dict = self.model.state_dict()
for k, v in state_dict.items():
v._clear_to_zero_allocation()

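# Split the sharded optimizer state: keys ending in ".w_0" are master weights,
# everything else is regular optimizer state.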
opt_states = {}
master_weights = {}
for k, v in optimizer_sharded_state_dict.items():
if k.endswith(".w_0"):
master_weights[k] = v
else:
opt_states[k] = v

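# Re-create each optimizer-state entry with a zero-filled local tensor so that
# dist.load_state_dict has writable storage to fill.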
for k, v in opt_states.items():
new_v = ShardedWeight(
key=v.key,
local_tensor=paddle.zeros_like(v.local_tensor),
local_shape=deepcopy(v.local_shape),
global_shape=deepcopy(v.global_shape),
global_offset=deepcopy(v.global_offset),
is_flattened=v.is_flattened,
flattened_range=deepcopy(v.flattened_range),
)
opt_states[k] = new_v

dist.load_state_dict(
opt_states,
opt_states_path,
aoa_config=self.args.aoa_config,
)

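# Stage the loaded optimizer states in pinned host memory and drop the device copies.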
optimizer_state_pin = {}

for k, v in opt_states.items():
tmp = v.local_tensor
optimizer_state_pin[k] = tmp.pin_memory()
tmp._clear_to_zero_allocation()
del tmp

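# Same treatment for master weights: allocate writable storage, load, then stage in pinned memory.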
for k, v in master_weights.items():
new_v = ShardedWeight(
key=v.key,
local_tensor=paddle.zeros_like(v.local_tensor),
local_shape=deepcopy(v.local_shape),
global_shape=deepcopy(v.global_shape),
global_offset=deepcopy(v.global_offset),
is_flattened=v.is_flattened,
flattened_range=deepcopy(v.flattened_range),
)
master_weights[k] = new_v

dist.load_state_dict(
master_weights,
master_weights_path,
aoa_config=self.args.aoa_config,
)

master_weights_pin = {}

for k, v in master_weights.items():
tmp = v.local_tensor
master_weights_pin[k] = tmp.pin_memory()
tmp._clear_to_zero_allocation()
del tmp

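# Rebuild the optimizer's sharded view and fill its local tensors from the staged copies.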
optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)

optimizer_sharded_state_dict_pin = {**master_weights_pin, **optimizer_state_pin}

for k, v in optimizer_sharded_state_dict.items():
source_tensor = optimizer_sharded_state_dict_pin[k]
target_tensor = paddle.zeros_like(v.local_tensor)
if source_tensor.place != target_tensor.place:
source_tensor = source_tensor.to(target_tensor.place)
paddle.assign(source_tensor, target_tensor)
target_tensor_pin = target_tensor.cpu()
del target_tensor
target_tensor_pin._share_buffer_to(v.local_tensor)
del source_tensor

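# Restore the parameter storage released earlier: comm buffers under sharding V2,
# otherwise fresh buffers shared back into the model weights.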
if isinstance(self.optimizer.inner_opt, DygraphShardingOptimizerV2):
color_to_comm_buffer_list = self.optimizer._color_to_comm_buffer_list
for color, _comm_buffer_list in color_to_comm_buffer_list.items():
for comm_buffer in _comm_buffer_list:
comm_buffer._reset_param_storage()
else:
state_dict = self.model.state_dict()
for k, v in state_dict.items():
new_v = paddle.zeros_like(v)
new_v._share_buffer_to(v)

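# Reload the LR scheduler state and, finally, the model weights themselves.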
self._load_scheduler(resume_from_checkpoint)

dist.load_state_dict(
model_sharded_state_dict,
model_states_path,
aoa_config=self.args.aoa_config,
)

def train(
self,
resume_from_checkpoint: Optional[Union[str, bool]] = None,
@@ -975,28 +1120,8 @@ def train(
self.model_wrapped = model
if delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)

if resume_from_checkpoint is not None:
if not self.args.ignore_load_lr_and_optim:
model_sharded_state_dict = self.model.sharded_state_dict()
accessible_files = os.listdir(resume_from_checkpoint)
metadata_files = [file for file in accessible_files if file.endswith(".metadata")]
assert len(metadata_files) == 1, "Only support one metadata file now."
metadata = paddle.load(os.path.join(resume_from_checkpoint, metadata_files[0]))
state_dict_metadata = metadata.state_dict_metadata
init_optimizer(self.optimizer, model_sharded_state_dict, state_dict_metadata)
optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict}
dist.load_state_dict(
sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config
)
self._load_scheduler(resume_from_checkpoint)
else:
model_sharded_state_dict = self.model.sharded_state_dict()
sharded_state_dict = model_sharded_state_dict
dist.load_state_dict(
sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config
)
self._load_flex_checkpoint(resume_from_checkpoint)
else:
model = self._wrap_model(self.model_wrapped)
# for the rest of this function `model` is the outside model, whether it was wrapped or not
@@ -2794,7 +2919,12 @@ def _save_checkpoint(self, model, metrics=None):

if self.args.save_checkpoint_format == "flex_checkpoint":
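# Flex checkpoint: save the sharded model weights into their own sub-directory.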
model_sharded_state_dict = self.model.sharded_state_dict()
os.makedirs(output_dir, exist_ok=True)
model_state_dict_path = os.path.join(output_dir, MODEL_STATE_DIC)
os.makedirs(model_state_dict_path, exist_ok=True)
dist.save_state_dict(
model_sharded_state_dict,
model_state_dict_path,
)

# Determine the new best metric / best model checkpoint
if metrics is not None and self.args.metric_for_best_model is not None:
Expand Down Expand Up @@ -2858,10 +2988,26 @@ def _save_checkpoint(self, model, metrics=None):
)
else:
if self.args.save_checkpoint_format == "flex_checkpoint":
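# Flex checkpoint: split the optimizer's sharded state into optimizer states and
# master weights and save each into its own sub-directory.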
optimizer_state_dict_path = os.path.join(output_dir, OPTIMIZER_STATE_DIC)
optimizer_states = {}
master_weights = {}

model_sharded_state_dict = self.model.sharded_state_dict()
optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
for k, v in optimizer_sharded_state_dict.items():
if k.endswith(".w_0"):
master_weights[k] = v
else:
optimizer_states[k] = v

dist.save_state_dict(
{**model_sharded_state_dict, **optimizer_sharded_state_dict},
output_dir,
optimizer_states,
optimizer_state_dict_path,
)
master_weights_path = os.path.join(output_dir, MASTER_WEIGHT_DIC)
dist.save_state_dict(
master_weights,
master_weights_path,
)
if self.args.should_save:
if self.tokenizer is not None and self.args.save_tokenizer:
@@ -2919,10 +3065,35 @@ def _save_checkpoint(self, model, metrics=None):
)
elif self.args.save_checkpoint_format == "flex_checkpoint":
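# Same flex-checkpoint layout here: model weights, optimizer states and master
# weights each get their own sub-directory.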
optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
model_sharded_state_dict = self.model.sharded_state_dict()
optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
model_state_dict_path = os.path.join(output_dir, MODEL_STATE_DIC)
os.makedirs(model_state_dict_path, exist_ok=True)
dist.save_state_dict(
{**model_sharded_state_dict, **optimizer_sharded_state_dict},
output_dir,
model_sharded_state_dict,
model_state_dict_path,
)
if not self.args.ignore_save_lr_and_optim:
optimizer_state_dict_path = os.path.join(output_dir, OPTIMIZER_STATE_DIC)
optimizer_states = {}
master_weights = {}
for k, v in optimizer_sharded_state_dict.items():
if k.endswith(".w_0"):
master_weights[k] = v
else:
optimizer_states[k] = v

dist.save_state_dict(
optimizer_states,
optimizer_state_dict_path,
)

master_weights_path = os.path.join(output_dir, MASTER_WEIGHT_DIC)
dist.save_state_dict(
master_weights,
master_weights_path,
)

if self.args.should_save:
if self.tokenizer is not None and self.args.save_tokenizer:
self.tokenizer.save_pretrained(output_dir)