Skip to content

Commit 6e4de5b

Browse files
committed
add load_with_extra function for modules to load checkpoints with extended whitelist
1 parent 9cd1a66 commit 6e4de5b

File tree

1 file changed

+37
-3
lines changed

1 file changed

+37
-3
lines changed

modules/safe.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,18 @@ def encode(*args):
2323

2424

2525
class RestrictedUnpickler(pickle.Unpickler):
26+
extra_handler = None
27+
2628
def persistent_load(self, saved_id):
2729
assert saved_id[0] == 'storage'
2830
return TypedStorage()
2931

3032
def find_class(self, module, name):
33+
if self.extra_handler is not None:
34+
res = self.extra_handler(module, name)
35+
if res is not None:
36+
return res
37+
3138
if module == 'collections' and name == 'OrderedDict':
3239
return getattr(collections, name)
3340
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
@@ -52,7 +59,7 @@ def find_class(self, module, name):
5259
return set
5360

5461
# Forbid everything else.
55-
raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden")
62+
raise Exception(f"global '{module}/{name}' is forbidden")
5663

5764

5865
allowed_zip_names = ["archive/data.pkl", "archive/version"]
@@ -69,7 +76,7 @@ def check_zip_filenames(filename, names):
6976
raise Exception(f"bad file inside {filename}: {name}")
7077

7178

72-
def check_pt(filename):
79+
def check_pt(filename, extra_handler):
7380
try:
7481

7582
# new pytorch format is a zip file
@@ -78,23 +85,50 @@ def check_pt(filename):
7885

7986
with z.open('archive/data.pkl') as file:
8087
unpickler = RestrictedUnpickler(file)
88+
unpickler.extra_handler = extra_handler
8189
unpickler.load()
8290

8391
except zipfile.BadZipfile:
8492

8593
# if it's not a zip file, it's an olf pytorch format, with five objects written to pickle
8694
with open(filename, "rb") as file:
8795
unpickler = RestrictedUnpickler(file)
96+
unpickler.extra_handler = extra_handler
8897
for i in range(5):
8998
unpickler.load()
9099

91100

92101
def load(filename, *args, **kwargs):
102+
return load_with_extra(filename, *args, **kwargs)
103+
104+
105+
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
106+
"""
107+
this functon is intended to be used by extensions that want to load models with
108+
some extra classes in them that the usual unpickler would find suspicious.
109+
110+
Use the extra_handler argument to specify a function that takes module and field name as text,
111+
and returns that field's value:
112+
113+
```python
114+
def extra(module, name):
115+
if module == 'collections' and name == 'OrderedDict':
116+
return collections.OrderedDict
117+
118+
return None
119+
120+
safe.load_with_extra('model.pt', extra_handler=extra)
121+
```
122+
123+
The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
124+
definitely unsafe.
125+
"""
126+
93127
from modules import shared
94128

95129
try:
96130
if not shared.cmd_opts.disable_safe_unpickle:
97-
check_pt(filename)
131+
check_pt(filename, extra_handler)
98132

99133
except pickle.UnpicklingError:
100134
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)

0 commit comments

Comments
 (0)