|
20 | 20 |
|
21 | 21 | import paddle
|
22 | 22 | from paddle.distributed.fleet.utils.log_util import logger
|
23 |
| -from paddle.distributed.flex_checkpoint.dcp.load_state_dict import ( |
24 |
| - _load_state_dict, |
25 |
| - get_rank_to_read_files, |
26 |
| -) |
27 |
| -from paddle.distributed.flex_checkpoint.dcp.metadata import ( |
28 |
| - LocalTensorIndex, |
29 |
| - LocalTensorMetadata, |
30 |
| - Metadata, |
31 |
| -) |
32 |
| -from paddle.distributed.flex_checkpoint.dcp.utils import flatten_state_dict |
| 23 | + |
| 24 | +try: |
| 25 | + from paddle.distributed.flex_checkpoint.dcp.load_state_dict import ( |
| 26 | + _load_state_dict, |
| 27 | + get_rank_to_read_files, |
| 28 | + ) |
| 29 | +except ModuleNotFoundError: |
| 30 | + try: |
| 31 | + from paddle.distributed.checkpoint.load_state_dict import ( |
| 32 | + _load_state_dict, |
| 33 | + get_rank_to_read_files, |
| 34 | + ) |
| 35 | + except ModuleNotFoundError: |
| 36 | + _load_state_dict = None |
| 37 | + get_rank_to_read_files = None |
| 38 | + |
| 39 | + |
| 40 | +try: |
| 41 | + from paddle.distributed.flex_checkpoint.dcp.metadata import ( |
| 42 | + LocalTensorIndex, |
| 43 | + LocalTensorMetadata, |
| 44 | + Metadata, |
| 45 | + ) |
| 46 | +except ModuleNotFoundError: |
| 47 | + try: |
| 48 | + from paddle.distributed.checkpoint.metadata import ( |
| 49 | + LocalTensorIndex, |
| 50 | + LocalTensorMetadata, |
| 51 | + Metadata, |
| 52 | + ) |
| 53 | + except ModuleNotFoundError: |
| 54 | + LocalTensorIndex = None |
| 55 | + LocalTensorMetadata = None |
| 56 | + Metadata = None |
| 57 | + |
| 58 | +try: |
| 59 | + from paddle.distributed.flex_checkpoint.dcp.utils import flatten_state_dict |
| 60 | +except ModuleNotFoundError: |
| 61 | + try: |
| 62 | + from paddle.distributed.checkpoint.utils import flatten_state_dict |
| 63 | + except ModuleNotFoundError: |
| 64 | + flatten_state_dict = None |
33 | 65 |
|
34 | 66 | MODEL_WEIGHT_SUFFIX = ".pdparams"
|
35 | 67 | OPTIMIZER_WEIGHT_SUFFIX = ".pdopt"
|
|
0 commit comments