7 changes: 1 addition & 6 deletions paddlenlp/trainer/trainer.py
```diff
@@ -979,12 +979,7 @@ def train(
             if resume_from_checkpoint is not None:
                 if not self.args.ignore_load_lr_and_optim:
                     model_sharded_state_dict = self.model.sharded_state_dict()
-                    accessible_files = os.listdir(resume_from_checkpoint)
-                    metadata_files = [file for file in accessible_files if file.endswith(".metadata")]
-                    assert len(metadata_files) == 1, "Only support one metadata file now."
-                    metadata = paddle.load(os.path.join(resume_from_checkpoint, metadata_files[0]))
-                    state_dict_metadata = metadata.state_dict_metadata
-                    init_optimizer(self.optimizer, model_sharded_state_dict, state_dict_metadata)
+                    init_optimizer(self.optimizer)
                     optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
                     sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict}
                     dist.load_state_dict(
```
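For reference, a minimal sketch of how the simplified resume path fits together after this change: the trainer no longer reads the checkpoint `.metadata` file before initializing optimizer states. The helper name `load_optimizer_from_flex_checkpoint`, the `trainer` stand-in object, and the arguments passed to `dist.load_state_dict` are illustrative only (the real call is truncated in the diff above); the call sequence mirrors the lines shown.

```python
import paddle.distributed as dist

from paddlenlp.trainer.trainer_utils import init_optimizer


def load_optimizer_from_flex_checkpoint(trainer, resume_from_checkpoint):
    # Build the model's sharded state dict first; the optimizer's sharded
    # state dict is derived from it.
    model_sharded_state_dict = trainer.model.sharded_state_dict()

    # Eagerly create the optimizer states so the load has tensors to fill in;
    # no checkpoint metadata is needed anymore.
    init_optimizer(trainer.optimizer)

    optimizer_sharded_state_dict = trainer.optimizer.sharded_state_dict(model_sharded_state_dict)
    sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict}

    # Illustrative arguments; see the trainer for the actual dist.load_state_dict call.
    dist.load_state_dict(sharded_state_dict, resume_from_checkpoint)
```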
64 changes: 16 additions & 48 deletions paddlenlp/trainer/trainer_utils.py
```diff
@@ -1363,85 +1363,53 @@ def set_comm_config(configs, attr, dict_obj):
     return strategy


-def init_optimizer(optimizer, model_sharded_state_dict, state_dict_metadata):
+def init_optimizer(optimizer):
     """
     Initialize the optimizer's states according to its type.

+    For DygraphShardingOptimizer (V1), initializes accumulators for local parameters.
+    For DygraphShardingOptimizerV2, manually initializes master weights and state dict for sharded parameters.
+    For other cases, initializes accumulators for all parameters.
+
     Args:
         optimizer: The optimizer instance to be initialized.
     """
-    optimizer_state_names = [".moment1_0", ".moment2_0", ".beta1_pow_acc_0", ".beta2_pow_acc_0", ".w_0"]
     inner_opt = getattr(optimizer, "_inner_opt", None)
-    static_to_struct_mapping = {}
-    model_sharded_state_dict = dict(sorted(model_sharded_state_dict.items()))
-    for k, v in model_sharded_state_dict.items():
-        if v.local_tensor.name not in static_to_struct_mapping:
-            static_to_struct_mapping[v.local_tensor.name] = k

     if isinstance(inner_opt, DygraphShardingOptimizer):
         local_params = optimizer._rank2params[optimizer._sharding_rank]
-        param_list = []
-        for param in local_params:
-            param_name = param.name
-            struct_name = static_to_struct_mapping[param_name]
-            if not any(struct_name + state_name in state_dict_metadata for state_name in optimizer_state_names):
-                continue
-            param_list.append(param)
-        optimizer._create_accumulators(paddle.base.framework.default_main_program().global_block(), param_list)
+        optimizer._create_accumulators(paddle.base.framework.default_main_program().global_block(), local_params)
         return

     elif isinstance(inner_opt, DygraphShardingOptimizerV2):

         def init_param_optimizer_states(param_iter):
             master_weights = {}
             state_dict = {}
-            moments = ("moment1_0", "moment2_0")
-            betas = ("beta1_pow_acc_0", "beta2_pow_acc_0")
-            for static_name, shape, no_need_master_weights in param_iter:
-                if not no_need_master_weights:
-                    master_weights[static_name] = paddle.zeros(shape, dtype="float32")
-                    prefix = f"{static_name}_fp32_master_0_"
-                else:
-                    prefix = f"{static_name}_"
-
-                for moment in moments:
-                    key = f"{prefix}{moment}"
+            for static_name, shape in param_iter:
+                master_weights[static_name] = paddle.zeros(shape, dtype="float32")
+                for moment in ("moment1_0", "moment2_0"):
+                    key = f"{static_name}_fp32_master_0_{moment}"
                     state_dict[key] = paddle.zeros(shape, dtype="float32")
-                for beta in betas:
-                    key = f"{prefix}{beta}"
+                for beta in ("beta1_pow_acc_0", "beta2_pow_acc_0"):
+                    key = f"{static_name}_fp32_master_0_{beta}"
                     state_dict[key] = paddle.zeros((1,), dtype="float32")
             return master_weights, state_dict

         def buffer_params():
             for buffer in optimizer._comm_buffer_list:
                 for param_name, grad_view in buffer._sharding_param_grad_view.items():
-                    struct_name = static_to_struct_mapping[param_name]
-                    if not any(
-                        struct_name + state_name in state_dict_metadata for state_name in optimizer_state_names
-                    ):
-                        continue
-                    numel = grad_view._param.numel().item()
                     param_begin = grad_view._param_begin
                     param_end = grad_view._param_end
                     shape = (param_end - param_begin,)
-                    no_need_master_weights = grad_view._param.dtype == paddle.float32
-
-                    index = grad_view._index
-                    padding_begin = index + numel
-                    shape = (min(padding_begin, param_end) - param_begin,)
                     if shape[0] > 0:
-                        yield param_name, shape, no_need_master_weights
+                        yield param_name, shape

         master_weights, state_dict = init_param_optimizer_states(buffer_params())
         state_dict["master_weights"] = master_weights
         state_dict["LR_Scheduler"] = {"last_epoch": 1, "last_lr": 5e-06}
         optimizer.set_state_dict(state_dict)
         return
-    param_list = []
-    for param in optimizer._parameter_list:
-        param_name = param.name
-        struct_name = static_to_struct_mapping[param_name]
-        if not any(struct_name + state_name in state_dict_metadata for state_name in optimizer_state_names):
-            continue
-        param_list.append(param)
-    optimizer._create_accumulators(paddle.base.framework.default_main_program().global_block(), param_list)
+    optimizer._create_accumulators(
+        paddle.base.framework.default_main_program().global_block(), optimizer._parameter_list
+    )
```
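A hedged sketch of what the new `init_param_optimizer_states` builds for the `DygraphShardingOptimizerV2` branch: every local shard of length `param_end - param_begin` gets a zero FP32 master weight plus zero Adam moments keyed by the parameter's static name, and the beta accumulators are scalars. The helper name `build_local_adam_state` and the toy static names and shapes below are made up for illustration; the key-naming scheme itself comes from the diff.

```python
import paddle


def build_local_adam_state(param_iter):
    """Mirror the zero-init and key naming used for DygraphShardingOptimizerV2.

    `param_iter` yields (static_name, shard_shape) pairs, e.g. the output of a
    buffer_params()-style generator over the optimizer's comm buffers.
    """
    master_weights = {}
    state_dict = {}
    for static_name, shape in param_iter:
        master_weights[static_name] = paddle.zeros(shape, dtype="float32")
        for moment in ("moment1_0", "moment2_0"):
            state_dict[f"{static_name}_fp32_master_0_{moment}"] = paddle.zeros(shape, dtype="float32")
        for beta in ("beta1_pow_acc_0", "beta2_pow_acc_0"):
            state_dict[f"{static_name}_fp32_master_0_{beta}"] = paddle.zeros((1,), dtype="float32")
    state_dict["master_weights"] = master_weights
    state_dict["LR_Scheduler"] = {"last_epoch": 1, "last_lr": 5e-06}
    return state_dict


# Toy usage with hypothetical static names and shard shapes:
state = build_local_adam_state([("linear_0.w_0", (128,)), ("linear_0.b_0", (4,))])
print(sorted(state.keys()))
```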
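Both the `DygraphShardingOptimizer` (V1) branch and the non-sharded fallback now call Paddle's internal `Optimizer._create_accumulators` directly, only with different parameter lists (the rank-local parameters vs. `optimizer._parameter_list`). A small standalone sketch of that call, assuming a toy model; the `Linear`/`Adam` setup is illustrative and `_create_accumulators` is a private Paddle API used here exactly as in the diff.

```python
import paddle

# Hypothetical toy model and optimizer.
linear = paddle.nn.Linear(4, 4)
adam = paddle.optimizer.Adam(parameters=linear.parameters())

# Eagerly allocate the moment/beta accumulators for the given parameters,
# so a later optimizer state load has buffers to fill.
block = paddle.base.framework.default_main_program().global_block()
adam._create_accumulators(block, adam._parameter_list)
```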