@@ -23,11 +23,18 @@ def encode(*args):
23
23
24
24
25
25
class RestrictedUnpickler (pickle .Unpickler ):
26
+ extra_handler = None
27
+
26
28
def persistent_load (self , saved_id ):
27
29
assert saved_id [0 ] == 'storage'
28
30
return TypedStorage ()
29
31
30
32
def find_class (self , module , name ):
33
+ if self .extra_handler is not None :
34
+ res = self .extra_handler (module , name )
35
+ if res is not None :
36
+ return res
37
+
31
38
if module == 'collections' and name == 'OrderedDict' :
32
39
return getattr (collections , name )
33
40
if module == 'torch._utils' and name in ['_rebuild_tensor_v2' , '_rebuild_parameter' ]:
@@ -52,7 +59,7 @@ def find_class(self, module, name):
52
59
return set
53
60
54
61
# Forbid everything else.
55
- raise pickle . UnpicklingError (f"global '{ module } /{ name } ' is forbidden" )
62
+ raise Exception (f"global '{ module } /{ name } ' is forbidden" )
56
63
57
64
58
65
allowed_zip_names = ["archive/data.pkl" , "archive/version" ]
@@ -69,7 +76,7 @@ def check_zip_filenames(filename, names):
69
76
raise Exception (f"bad file inside { filename } : { name } " )
70
77
71
78
72
- def check_pt (filename ):
79
+ def check_pt (filename , extra_handler ):
73
80
try :
74
81
75
82
# new pytorch format is a zip file
@@ -78,23 +85,50 @@ def check_pt(filename):
78
85
79
86
with z .open ('archive/data.pkl' ) as file :
80
87
unpickler = RestrictedUnpickler (file )
88
+ unpickler .extra_handler = extra_handler
81
89
unpickler .load ()
82
90
83
91
except zipfile .BadZipfile :
84
92
85
93
# if it's not a zip file, it's an olf pytorch format, with five objects written to pickle
86
94
with open (filename , "rb" ) as file :
87
95
unpickler = RestrictedUnpickler (file )
96
+ unpickler .extra_handler = extra_handler
88
97
for i in range (5 ):
89
98
unpickler .load ()
90
99
91
100
92
101
def load (filename , * args , ** kwargs ):
102
+ return load_with_extra (filename , * args , ** kwargs )
103
+
104
+
105
+ def load_with_extra (filename , extra_handler = None , * args , ** kwargs ):
106
+ """
107
+ this functon is intended to be used by extensions that want to load models with
108
+ some extra classes in them that the usual unpickler would find suspicious.
109
+
110
+ Use the extra_handler argument to specify a function that takes module and field name as text,
111
+ and returns that field's value:
112
+
113
+ ```python
114
+ def extra(module, name):
115
+ if module == 'collections' and name == 'OrderedDict':
116
+ return collections.OrderedDict
117
+
118
+ return None
119
+
120
+ safe.load_with_extra('model.pt', extra_handler=extra)
121
+ ```
122
+
123
+ The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
124
+ definitely unsafe.
125
+ """
126
+
93
127
from modules import shared
94
128
95
129
try :
96
130
if not shared .cmd_opts .disable_safe_unpickle :
97
- check_pt (filename )
131
+ check_pt (filename , extra_handler )
98
132
99
133
except pickle .UnpicklingError :
100
134
print (f"Error verifying pickled file from { filename } :" , file = sys .stderr )
0 commit comments