Skip to content

Commit e461477

Browse files
committed
Fixed safe.py for pytorch 1.13 ckpt files
1 parent 4b3c5bc commit e461477

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

modules/safe.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,12 @@ def find_class(self, module, name):
6262
raise Exception(f"global '{module}/{name}' is forbidden")
6363

6464

65-
allowed_zip_names = ["archive/data.pkl", "archive/version"]
66-
allowed_zip_names_re = re.compile(r"^archive/data/\d+$")
67-
65+
# Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'
66+
allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
67+
data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")
6868

6969
def check_zip_filenames(filename, names):
7070
for name in names:
71-
if name in allowed_zip_names:
72-
continue
7371
if allowed_zip_names_re.match(name):
7472
continue
7573

@@ -82,8 +80,14 @@ def check_pt(filename, extra_handler):
8280
# new pytorch format is a zip file
8381
with zipfile.ZipFile(filename) as z:
8482
check_zip_filenames(filename, z.namelist())
85-
86-
with z.open('archive/data.pkl') as file:
83+
84+
# find filename of data.pkl in zip file: '<directory name>/data.pkl'
85+
data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
86+
if len(data_pkl_filenames) == 0:
87+
raise Exception(f"data.pkl not found in {filename}")
88+
if len(data_pkl_filenames) > 1:
89+
raise Exception(f"Multiple data.pkl found in {filename}")
90+
with z.open(data_pkl_filenames[0]) as file:
8791
unpickler = RestrictedUnpickler(file)
8892
unpickler.extra_handler = extra_handler
8993
unpickler.load()

0 commit comments

Comments
 (0)