From 1891d0353f2c2a6a1715b56f268ea37fc0c5756f Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Wed, 27 Aug 2025 14:01:50 +0800 Subject: [PATCH] try import flex --- paddlenlp/trainer/utils/ckpt_converter.py | 52 ++++++++++++++++++----- 1 file changed, 42 insertions(+), 10 deletions(-) diff --git a/paddlenlp/trainer/utils/ckpt_converter.py b/paddlenlp/trainer/utils/ckpt_converter.py index 23f085e18f44..c4e376e49f65 100644 --- a/paddlenlp/trainer/utils/ckpt_converter.py +++ b/paddlenlp/trainer/utils/ckpt_converter.py @@ -19,18 +19,50 @@ from typing import List, Union import paddle -from paddle.distributed.checkpoint.load_state_dict import ( - _load_state_dict, - get_rank_to_read_files, -) -from paddle.distributed.checkpoint.metadata import ( - LocalTensorIndex, - LocalTensorMetadata, - Metadata, -) -from paddle.distributed.checkpoint.utils import flatten_state_dict from paddle.distributed.fleet.utils.log_util import logger +try: + from paddle.distributed.flex_checkpoint.dcp.load_state_dict import ( + _load_state_dict, + get_rank_to_read_files, + ) +except ModuleNotFoundError: + try: + from paddle.distributed.checkpoint.load_state_dict import ( + _load_state_dict, + get_rank_to_read_files, + ) + except ModuleNotFoundError: + _load_state_dict = None + get_rank_to_read_files = None + + +try: + from paddle.distributed.flex_checkpoint.dcp.metadata import ( + LocalTensorIndex, + LocalTensorMetadata, + Metadata, + ) +except ModuleNotFoundError: + try: + from paddle.distributed.checkpoint.metadata import ( + LocalTensorIndex, + LocalTensorMetadata, + Metadata, + ) + except ModuleNotFoundError: + LocalTensorIndex = None + LocalTensorMetadata = None + Metadata = None + +try: + from paddle.distributed.flex_checkpoint.dcp.utils import flatten_state_dict +except ModuleNotFoundError: + try: + from paddle.distributed.checkpoint.utils import flatten_state_dict + except ModuleNotFoundError: + flatten_state_dict = None + MODEL_WEIGHT_SUFFIX = ".pdparams" OPTIMIZER_WEIGHT_SUFFIX = ".pdopt" SCHEDULER_NAME = "scheduler.pdparams"