@@ -62,14 +62,12 @@ def find_class(self, module, name):
62
62
raise Exception (f"global '{ module } /{ name } ' is forbidden" )
63
63
64
64
65
- allowed_zip_names = [ "archive/ data.pkl" , "archive/version" ]
66
- allowed_zip_names_re = re .compile (r"^archive/ data/\d+$" )
67
-
65
+ # Regular expression that accepts 'dirname/version', 'dirname/ data.pkl', and 'dirname/data/<number>'
66
+ allowed_zip_names_re = re .compile (r"^([^/]+)/(( data/\d+)|version|(data\.pkl)) $" )
67
+ data_pkl_re = re . compile ( r"^([^/]+)/data\.pkl$" )
68
68
69
69
def check_zip_filenames (filename , names ):
70
70
for name in names :
71
- if name in allowed_zip_names :
72
- continue
73
71
if allowed_zip_names_re .match (name ):
74
72
continue
75
73
@@ -82,8 +80,14 @@ def check_pt(filename, extra_handler):
82
80
# new pytorch format is a zip file
83
81
with zipfile .ZipFile (filename ) as z :
84
82
check_zip_filenames (filename , z .namelist ())
85
-
86
- with z .open ('archive/data.pkl' ) as file :
83
+
84
+ # find filename of data.pkl in zip file: '<directory name>/data.pkl'
85
+ data_pkl_filenames = [f for f in z .namelist () if data_pkl_re .match (f )]
86
+ if len (data_pkl_filenames ) == 0 :
87
+ raise Exception (f"data.pkl not found in { filename } " )
88
+ if len (data_pkl_filenames ) > 1 :
89
+ raise Exception (f"Multiple data.pkl found in { filename } " )
90
+ with z .open (data_pkl_filenames [0 ]) as file :
87
91
unpickler = RestrictedUnpickler (file )
88
92
unpickler .extra_handler = extra_handler
89
93
unpickler .load ()
0 commit comments