diff --git a/paddlenlp/trainer/utils/ckpt_converter.py b/paddlenlp/trainer/utils/ckpt_converter.py
index 198cd8df35c9..c4e376e49f65 100644
--- a/paddlenlp/trainer/utils/ckpt_converter.py
+++ b/paddlenlp/trainer/utils/ckpt_converter.py
@@ -20,16 +20,48 @@
 
 import paddle
 from paddle.distributed.fleet.utils.log_util import logger
-from paddle.distributed.flex_checkpoint.dcp.load_state_dict import (
-    _load_state_dict,
-    get_rank_to_read_files,
-)
-from paddle.distributed.flex_checkpoint.dcp.metadata import (
-    LocalTensorIndex,
-    LocalTensorMetadata,
-    Metadata,
-)
-from paddle.distributed.flex_checkpoint.dcp.utils import flatten_state_dict
+
+try:
+    from paddle.distributed.flex_checkpoint.dcp.load_state_dict import (
+        _load_state_dict,
+        get_rank_to_read_files,
+    )
+except ModuleNotFoundError:
+    try:
+        from paddle.distributed.checkpoint.load_state_dict import (
+            _load_state_dict,
+            get_rank_to_read_files,
+        )
+    except ModuleNotFoundError:
+        _load_state_dict = None
+        get_rank_to_read_files = None
+
+
+try:
+    from paddle.distributed.flex_checkpoint.dcp.metadata import (
+        LocalTensorIndex,
+        LocalTensorMetadata,
+        Metadata,
+    )
+except ModuleNotFoundError:
+    try:
+        from paddle.distributed.checkpoint.metadata import (
+            LocalTensorIndex,
+            LocalTensorMetadata,
+            Metadata,
+        )
+    except ModuleNotFoundError:
+        LocalTensorIndex = None
+        LocalTensorMetadata = None
+        Metadata = None
+
+try:
+    from paddle.distributed.flex_checkpoint.dcp.utils import flatten_state_dict
+except ModuleNotFoundError:
+    try:
+        from paddle.distributed.checkpoint.utils import flatten_state_dict
+    except ModuleNotFoundError:
+        flatten_state_dict = None
 
 MODEL_WEIGHT_SUFFIX = ".pdparams"
 OPTIMIZER_WEIGHT_SUFFIX = ".pdopt"
@@ -206,7 +238,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
                 global_offset = [0] * self.tp_degree
                 for item in shard_info:
                     tp_rank = item[0]["tp_rank"]
-                    state_name_with_tp_rank = state_name + "_tp" + f"{tp_rank:02d}"
+                    state_name_with_tp_rank = state_name + "_tp" + "{:02d}".format(tp_rank)
                     local_tensor_meta_data = LocalTensorMetadata((global_offset[tp_rank],), item[1], item[2])
                     local_tensor_index = LocalTensorIndex(state_name_with_tp_rank, (global_offset[tp_rank],))
                     global_offset[tp_rank] += item[1][0]
@@ -225,7 +257,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
                 renamed_state_dict = {}
                 (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name)
                 for state_name, state_value in state_dict.items():
-                    state_name_with_tp_rank = state_name + "_tp" + f"{tp_rank:02d}"
+                    state_name_with_tp_rank = state_name + "_tp" + "{:02d}".format(tp_rank)
                     renamed_state_dict[state_name_with_tp_rank] = state_value
 
                 source_state_dict_for_merge_sharding[file_name] = renamed_state_dict
@@ -235,7 +267,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
             sharding_metas_keys = []
             for i in range(self.tp_degree):
                 for j in range(self.pp_degree):
-                    sharding_metas_keys.append(f"tp{i:02d}_pp{j:02d}")
+                    sharding_metas_keys.append("tp{:02d}_pp{:02d}".format(i, j))
             for key in sharding_metas_keys:
                 param_meta = self.model_meta["sharding_metas"][key]["param_meta"]
                 for param_name, param_shape_and_dtype in param_meta.items():
@@ -253,7 +285,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
             all_param_meta = {}
             for i in range(self.tp_degree):
                 for j in range(self.pp_degree):
-                    key = f"tp{i:02d}_pp{j:02d}"
+                    key = "tp{:02d}_pp{:02d}".format(i, j)
                     param_meta = self.model_meta["sharding_metas"][key]["param_meta"]
                     for param_name, param_shape_and_dtype in param_meta.items():
                         all_param_meta[param_name] = param_shape_and_dtype
@@ -269,7 +301,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
             with paddle.base.dygraph.guard(place=paddle.CPUPlace()):
                 for key in cur_rank_need_load_model_state_keys:
                     for tp_rank in range(self.tp_degree):
-                        tp_rank_suffix = f"_tp{tp_rank:02d}"
+                        tp_rank_suffix = "_tp{:02d}".format(tp_rank)
                         optimizer_state_dict[key + ".moment1" + tp_rank_suffix] = paddle.zeros(
                             (param_flattened_shapes[key],), "float32"
                         )
@@ -353,7 +385,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
                 else:
                     concat_optimier_state_dict[opt_state_name_removed_tp_rank] = tp_tensors[0]
 
-            fake_file_name = f"{self.cur_rank:02d}" + ".distcp"
+            fake_file_name = "{:02d}".format(self.cur_rank) + ".distcp"
             local_tensor_meta_data = {}
             local_tensor_index = {}
             for k, v in concat_optimier_state_dict.items():
@@ -472,7 +504,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
                 reshaped_v = v.reshape(shape)
                 target_state_dict[k] = reshaped_v
 
-            fake_file_name = f"{self.cur_rank:02d}" + ".distcp"
+            fake_file_name = "{:02d}".format(self.cur_rank) + ".distcp"
             local_tensor_meta_data = {}
             local_tensor_index = {}
             for k, v in target_state_dict.items():
@@ -911,7 +943,7 @@ def rename_using_model_meta(self, file_name):
                 self.model_meta = json.load(file)
 
         (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name)
-        dist_strategy_key = "tp" + f"{tp_rank:02d}" + "_" + "pp" + f"{pp_rank:02d}"
+        dist_strategy_key = "tp" + "{:02d}".format(tp_rank) + "_" + "pp" + "{:02d}".format(pp_rank)
         # Map model weight names to their corresponding names of master_weights in the optimizer state.
         if file_name.endswith(OPTIMIZER_WEIGHT_SUFFIX):
             structure_name_mapping = self.model_meta["sharding_metas"][dist_strategy_key]["structure_name_mapping"]
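# --- Illustrative sketch, not part of the patch ---
# The import changes above try the newer paddle.distributed.flex_checkpoint.dcp
# modules first, fall back to the legacy paddle.distributed.checkpoint locations,
# and set the names to None when neither is importable, so importing the module
# itself does not fail. A minimal, self-contained sketch of the same pattern for
# one symbol follows; the _require_flatten_state_dict guard is hypothetical and
# only shows how the None sentinel would be checked before use.
try:
    from paddle.distributed.flex_checkpoint.dcp.utils import flatten_state_dict
except ModuleNotFoundError:
    try:
        from paddle.distributed.checkpoint.utils import flatten_state_dict
    except ModuleNotFoundError:
        flatten_state_dict = None  # sentinel; checked at call time, not import time


def _require_flatten_state_dict():
    # Hypothetical helper: raise a clear error when the symbol is actually needed
    # rather than when the module is imported.
    if flatten_state_dict is None:
        raise RuntimeError(
            "flatten_state_dict is unavailable: neither "
            "paddle.distributed.flex_checkpoint.dcp.utils nor "
            "paddle.distributed.checkpoint.utils could be imported."
        )
    return flatten_state_dict


# The string-formatting changes are behavior-preserving: "{:02d}".format(3) and
# f"{3:02d}" both yield "03", so keys such as "tp03_pp00" are unchanged.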