|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import pickle |
| 4 | +import re |
| 5 | +import warnings |
3 | 6 | from typing import Any |
4 | 7 |
|
5 | 8 | import torch |
|
13 | 16 | ) |
14 | 17 | from torch_frame.data.multi_tensor import _MultiTensor |
15 | 18 | from torch_frame.data.stats import StatType |
16 | | -from torch_frame.typing import WITH_PT24, TensorData |
| 19 | +from torch_frame.typing import TensorData |
17 | 20 |
|
18 | 21 |
|
19 | 22 | def serialize_feat_dict( |
@@ -96,9 +99,30 @@ def load( |
96 | 99 | tuple: A tuple of loaded :class:`TensorFrame` object and |
97 | 100 | optional :obj:`col_stats`. |
98 | 101 | """ |
99 | | - tf_dict, col_stats = torch.load(path, weights_only=WITH_PT24) |
| 102 | + if torch_frame.typing.WITH_PT24: |
| 103 | + try: |
| 104 | + tf_dict, col_stats = torch.load(path, weights_only=True) |
| 105 | + except pickle.UnpicklingError as e: |
| 106 | + error_msg = str(e) |
| 107 | + if "add_safe_globals" in error_msg: |
| 108 | + warn_msg = ("Weights only load failed. Please file an issue " |
| 109 | + "to make `torch.load(weights_only=True)` " |
| 110 | + "compatible in your case.") |
| 111 | + match = re.search(r'add_safe_globals\(.*?\)', error_msg) |
| 112 | + if match is not None: |
| 113 | + warnings.warn(f"{warn_msg} Please use " |
| 114 | + f"`torch.serialization.{match.group()}` to " |
| 115 | + f"allowlist this global.") |
| 116 | + else: |
| 117 | + warnings.warn(warn_msg) |
| 118 | + |
| 119 | + tf_dict, col_stats = torch.load(path, weights_only=False) |
| 120 | + else: |
| 121 | + raise e |
| 122 | + else: |
| 123 | + tf_dict, col_stats = torch.load(path, weights_only=False) |
| 124 | + |
100 | 125 | tf_dict['feat_dict'] = deserialize_feat_dict( |
101 | 126 | tf_dict.pop('feat_serialized_dict')) |
102 | | - tensor_frame = TensorFrame(**tf_dict) |
103 | | - tensor_frame.to(device) |
| 127 | + tensor_frame = TensorFrame(**tf_dict).to(device) |
104 | 128 | return tensor_frame, col_stats |
0 commit comments