Skip to content

Commit a3b73c4

Browse files
authored
Fail torch.load(weights=True) gracefully (#448)
1 parent 546f1a2 commit a3b73c4

File tree

2 files changed

+49
-5
lines changed

2 files changed

+49
-5
lines changed

test/utils/test_io.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
import shutil
44
import tempfile
55

6+
import pytest
7+
68
import torch_frame
7-
from torch_frame import load, save
9+
from torch_frame import TensorFrame, load, save
810
from torch_frame.config.text_embedder import TextEmbedderConfig
911
from torch_frame.config.text_tokenizer import TextTokenizerConfig
1012
from torch_frame.datasets import FakeDataset
@@ -114,3 +116,21 @@ def test_save_load_tensor_frame():
114116
tf, col_stats = load(path)
115117
assert dataset.col_stats == col_stats
116118
assert dataset.tensor_frame == tf
119+
120+
121+
class UntrustedClass:
    """Empty placeholder type used as a ``col_stats`` value.

    The test below stores an instance of it so that the saved file contains
    an object the weights-only load is expected to reject.
    """
    pass


@pytest.mark.skipif(
    not torch_frame.typing.WITH_PT24,
    reason='Requires PyTorch 2.4',
)
def test_load_weights_only_gracefully(tmpdir):
    """Check that ``load`` warns rather than fails on an untrusted object.

    A file whose ``col_stats`` holds an :class:`UntrustedClass` instance is
    written with ``save``; loading it back must emit a ``UserWarning`` whose
    message matches ``'Weights only load failed'``.
    """
    # Build the path once so that save and load target the same file.
    path = tmpdir.join('tf.pt')
    save(
        tensor_frame=TensorFrame({}, {}),
        col_stats={'a': UntrustedClass()},
        path=path,
    )
    with pytest.warns(UserWarning, match='Weights only load failed'):
        load(path)

torch_frame/utils/io.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
from __future__ import annotations
22

3+
import pickle
4+
import re
5+
import warnings
36
from typing import Any
47

58
import torch
@@ -13,7 +16,7 @@
1316
)
1417
from torch_frame.data.multi_tensor import _MultiTensor
1518
from torch_frame.data.stats import StatType
16-
from torch_frame.typing import WITH_PT24, TensorData
19+
from torch_frame.typing import TensorData
1720

1821

1922
def serialize_feat_dict(
@@ -96,9 +99,30 @@ def load(
9699
tuple: A tuple of loaded :class:`TensorFrame` object and
97100
optional :obj:`col_stats`.
98101
"""
99-
tf_dict, col_stats = torch.load(path, weights_only=WITH_PT24)
102+
if torch_frame.typing.WITH_PT24:
103+
try:
104+
tf_dict, col_stats = torch.load(path, weights_only=True)
105+
except pickle.UnpicklingError as e:
106+
error_msg = str(e)
107+
if "add_safe_globals" in error_msg:
108+
warn_msg = ("Weights only load failed. Please file an issue "
109+
"to make `torch.load(weights_only=True)` "
110+
"compatible in your case.")
111+
match = re.search(r'add_safe_globals\(.*?\)', error_msg)
112+
if match is not None:
113+
warnings.warn(f"{warn_msg} Please use "
114+
f"`torch.serialization.{match.group()}` to "
115+
f"allowlist this global.")
116+
else:
117+
warnings.warn(warn_msg)
118+
119+
tf_dict, col_stats = torch.load(path, weights_only=False)
120+
else:
121+
raise e
122+
else:
123+
tf_dict, col_stats = torch.load(path, weights_only=False)
124+
100125
tf_dict['feat_dict'] = deserialize_feat_dict(
101126
tf_dict.pop('feat_serialized_dict'))
102-
tensor_frame = TensorFrame(**tf_dict)
103-
tensor_frame.to(device)
127+
tensor_frame = TensorFrame(**tf_dict).to(device)
104128
return tensor_frame, col_stats

0 commit comments

Comments
 (0)