Skip to content

Commit 3e2484a

Browse files
authored
Try import ckpt convert (#2476)
1 parent f23d540 commit 3e2484a

File tree

1 file changed

+42
-10
lines changed

1 file changed

+42
-10
lines changed

paddleformers/trainer/utils/ckpt_converter.py

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,48 @@
2020

2121
import paddle
2222
from paddle.distributed.fleet.utils.log_util import logger
23-
from paddle.distributed.flex_checkpoint.dcp.load_state_dict import (
24-
_load_state_dict,
25-
get_rank_to_read_files,
26-
)
27-
from paddle.distributed.flex_checkpoint.dcp.metadata import (
28-
LocalTensorIndex,
29-
LocalTensorMetadata,
30-
Metadata,
31-
)
32-
from paddle.distributed.flex_checkpoint.dcp.utils import flatten_state_dict
23+
24+
try:
25+
from paddle.distributed.flex_checkpoint.dcp.load_state_dict import (
26+
_load_state_dict,
27+
get_rank_to_read_files,
28+
)
29+
except ModuleNotFoundError:
30+
try:
31+
from paddle.distributed.checkpoint.load_state_dict import (
32+
_load_state_dict,
33+
get_rank_to_read_files,
34+
)
35+
except ModuleNotFoundError:
36+
_load_state_dict = None
37+
get_rank_to_read_files = None
38+
39+
40+
try:
41+
from paddle.distributed.flex_checkpoint.dcp.metadata import (
42+
LocalTensorIndex,
43+
LocalTensorMetadata,
44+
Metadata,
45+
)
46+
except ModuleNotFoundError:
47+
try:
48+
from paddle.distributed.checkpoint.metadata import (
49+
LocalTensorIndex,
50+
LocalTensorMetadata,
51+
Metadata,
52+
)
53+
except ModuleNotFoundError:
54+
LocalTensorIndex = None
55+
LocalTensorMetadata = None
56+
Metadata = None
57+
58+
try:
59+
from paddle.distributed.flex_checkpoint.dcp.utils import flatten_state_dict
60+
except ModuleNotFoundError:
61+
try:
62+
from paddle.distributed.checkpoint.utils import flatten_state_dict
63+
except ModuleNotFoundError:
64+
flatten_state_dict = None
3365

3466
MODEL_WEIGHT_SUFFIX = ".pdparams"
3567
OPTIMIZER_WEIGHT_SUFFIX = ".pdopt"

0 commit comments

Comments
 (0)