68 changes: 50 additions & 18 deletions paddlenlp/trainer/utils/ckpt_converter.py
@@ -20,16 +20,48 @@

import paddle
from paddle.distributed.fleet.utils.log_util import logger
from paddle.distributed.flex_checkpoint.dcp.load_state_dict import (
_load_state_dict,
get_rank_to_read_files,
)
from paddle.distributed.flex_checkpoint.dcp.metadata import (
LocalTensorIndex,
LocalTensorMetadata,
Metadata,
)
from paddle.distributed.flex_checkpoint.dcp.utils import flatten_state_dict

try:
from paddle.distributed.flex_checkpoint.dcp.load_state_dict import (
_load_state_dict,
get_rank_to_read_files,
)
except ModuleNotFoundError:
try:
from paddle.distributed.checkpoint.load_state_dict import (
_load_state_dict,
get_rank_to_read_files,
)
except ModuleNotFoundError:
_load_state_dict = None
get_rank_to_read_files = None


try:
from paddle.distributed.flex_checkpoint.dcp.metadata import (
LocalTensorIndex,
LocalTensorMetadata,
Metadata,
)
except ModuleNotFoundError:
try:
from paddle.distributed.checkpoint.metadata import (
LocalTensorIndex,
LocalTensorMetadata,
Metadata,
)
except ModuleNotFoundError:
LocalTensorIndex = None
LocalTensorMetadata = None
Metadata = None

try:
from paddle.distributed.flex_checkpoint.dcp.utils import flatten_state_dict
except ModuleNotFoundError:
try:
from paddle.distributed.checkpoint.utils import flatten_state_dict
except ModuleNotFoundError:
flatten_state_dict = None

MODEL_WEIGHT_SUFFIX = ".pdparams"
OPTIMIZER_WEIGHT_SUFFIX = ".pdopt"
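
The guarded imports above prefer the new paddle.distributed.flex_checkpoint.dcp modules, fall back to the legacy paddle.distributed.checkpoint modules, and degrade to None when neither is importable, so callers must check before first use. A minimal sketch of that pattern plus a call-site guard, assuming the same module-level names as in the diff (the _require_flex_checkpoint helper and its error message are illustrative, not part of this PR):

    # Compact version of the fallback used above, shown for one symbol only.
    try:
        from paddle.distributed.flex_checkpoint.dcp.load_state_dict import _load_state_dict
    except ModuleNotFoundError:
        try:
            from paddle.distributed.checkpoint.load_state_dict import _load_state_dict
        except ModuleNotFoundError:
            _load_state_dict = None  # neither import path is available


    def _require_flex_checkpoint():
        # Fail fast if the optional checkpoint APIs were not importable.
        if _load_state_dict is None:
            raise RuntimeError(
                "This Paddle build provides neither paddle.distributed.flex_checkpoint "
                "nor the legacy paddle.distributed.checkpoint APIs."
            )

    # Call once before using any of the optional symbols, e.g.:
    # _require_flex_checkpoint()
    # _load_state_dict(...)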
@@ -206,7 +238,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
global_offset = [0] * self.tp_degree
for item in shard_info:
tp_rank = item[0]["tp_rank"]
state_name_with_tp_rank = state_name + "_tp" + f"{tp_rank:02d}"
state_name_with_tp_rank = state_name + "_tp" + "{:02d}".format(tp_rank)
local_tensor_meta_data = LocalTensorMetadata((global_offset[tp_rank],), item[1], item[2])
local_tensor_index = LocalTensorIndex(state_name_with_tp_rank, (global_offset[tp_rank],))
global_offset[tp_rank] += item[1][0]
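
The other change repeated throughout this diff swaps f-string formatting for the equivalent str.format call. Both render the rank zero-padded to two digits; the .format spelling additionally runs on interpreters that predate f-strings (Python < 3.6), which is presumably the motivation. A quick equivalence check with an illustrative rank value:

    # "{:02d}".format(x) and f"{x:02d}" produce identical output.
    tp_rank = 3
    assert "{:02d}".format(tp_rank) == "03"
    assert "_tp" + "{:02d}".format(tp_rank) == "_tp03"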
@@ -225,7 +257,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
renamed_state_dict = {}
(tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name)
for state_name, state_value in state_dict.items():
state_name_with_tp_rank = state_name + "_tp" + f"{tp_rank:02d}"
state_name_with_tp_rank = state_name + "_tp" + "{:02d}".format(tp_rank)
renamed_state_dict[state_name_with_tp_rank] = state_value

source_state_dict_for_merge_sharding[file_name] = renamed_state_dict
@@ -235,7 +267,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
sharding_metas_keys = []
for i in range(self.tp_degree):
for j in range(self.pp_degree):
sharding_metas_keys.append(f"tp{i:02d}_pp{j:02d}")
sharding_metas_keys.append("tp{:02d}_pp{:02d}".format(i, j))
for key in sharding_metas_keys:
param_meta = self.model_meta["sharding_metas"][key]["param_meta"]
for param_name, param_shape_and_dtype in param_meta.items():
@@ -253,7 +285,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
all_param_meta = {}
for i in range(self.tp_degree):
for j in range(self.pp_degree):
key = f"tp{i:02d}_pp{j:02d}"
key = "tp{:02d}_pp{:02d}".format(i, j)
param_meta = self.model_meta["sharding_metas"][key]["param_meta"]
for param_name, param_shape_and_dtype in param_meta.items():
all_param_meta[param_name] = param_shape_and_dtype
@@ -269,7 +301,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
with paddle.base.dygraph.guard(place=paddle.CPUPlace()):
for key in cur_rank_need_load_model_state_keys:
for tp_rank in range(self.tp_degree):
tp_rank_suffix = f"_tp{tp_rank:02d}"
tp_rank_suffix = "_tp{:02d}".format(tp_rank)
optimizer_state_dict[key + ".moment1" + tp_rank_suffix] = paddle.zeros(
(param_flattened_shapes[key],), "float32"
)
@@ -353,7 +385,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
else:
concat_optimier_state_dict[opt_state_name_removed_tp_rank] = tp_tensors[0]

fake_file_name = f"{self.cur_rank:02d}" + ".distcp"
fake_file_name = "{:02d}".format(self.cur_rank) + ".distcp"
local_tensor_meta_data = {}
local_tensor_index = {}
for k, v in concat_optimier_state_dict.items():
@@ -472,7 +504,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
reshaped_v = v.reshape(shape)
target_state_dict[k] = reshaped_v

fake_file_name = f"{self.cur_rank:02d}" + ".distcp"
fake_file_name = "{:02d}".format(self.cur_rank) + ".distcp"
local_tensor_meta_data = {}
local_tensor_index = {}
for k, v in target_state_dict.items():
@@ -911,7 +943,7 @@ def rename_using_model_meta(self, file_name):
self.model_meta = json.load(file)

(tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name)
dist_strategy_key = "tp" + f"{tp_rank:02d}" + "_" + "pp" + f"{pp_rank:02d}"
dist_strategy_key = "tp" + "{:02d}".format(tp_rank) + "_" + "pp" + "{:02d}".format(pp_rank)
# Map model weight names to their corresponding names of master_weights in the optimizer state.
if file_name.endswith(OPTIMIZER_WEIGHT_SUFFIX):
structure_name_mapping = self.model_meta["sharding_metas"][dist_strategy_key]["structure_name_mapping"]