Skip to content
This repository was archived by the owner on Jan 15, 2024. It is now read-only.

Commit ce0e0a2

Browse files
Adding tarfile member sanitization to extractall() (#1593)
1 parent 5d4bc9e commit ce0e0a2

File tree

2 files changed

+100
-5
lines changed

2 files changed

+100
-5
lines changed

scripts/datasets/general_nlp_benchmark/prepare_text_classification.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,26 @@ def main(args):
5353
sha1_hash = _URL_FILE_STATS[file_url]
5454
download_path = download(file_url, args.cache_path, sha1_hash=sha1_hash)
5555
with tarfile.open(download_path) as f:
56-
f.extractall(task_dir_path)
56+
def is_within_directory(directory, target):
57+
58+
abs_directory = os.path.abspath(directory)
59+
abs_target = os.path.abspath(target)
60+
61+
prefix = os.path.commonprefix([abs_directory, abs_target])
62+
63+
return prefix == abs_directory
64+
65+
def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
66+
67+
for member in tar.getmembers():
68+
member_path = os.path.join(path, member.name)
69+
if not is_within_directory(path, member_path):
70+
raise Exception("Attempted Path Traversal in Tar File")
71+
72+
tar.extractall(path, members, numeric_owner=numeric_owner)
73+
74+
75+
safe_extract(f, task_dir_path)
5776
if task == 'imdb':
5877
shutil.move(os.path.join(task_dir_path, 'imdb', 'train.parquet'),
5978
os.path.join(task_dir_path, 'train.parquet'))

scripts/datasets/music_generation/prepare_music_midi.py

Lines changed: 80 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,16 +74,92 @@ def main(args):
7474
save_dir))
7575
if args.dataset == 'lmd_full':
7676
with tarfile.open(target_download_location) as f:
77-
f.extractall(save_dir)
77+
def is_within_directory(directory, target):
78+
79+
abs_directory = os.path.abspath(directory)
80+
abs_target = os.path.abspath(target)
81+
82+
prefix = os.path.commonprefix([abs_directory, abs_target])
83+
84+
return prefix == abs_directory
85+
86+
def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
87+
88+
for member in tar.getmembers():
89+
member_path = os.path.join(path, member.name)
90+
if not is_within_directory(path, member_path):
91+
raise Exception("Attempted Path Traversal in Tar File")
92+
93+
tar.extractall(path, members, numeric_owner=numeric_owner)
94+
95+
96+
safe_extract(f, save_dir)
7897
elif args.dataset == 'lmd_matched':
7998
with tarfile.open(target_download_location) as f:
80-
f.extractall(save_dir)
99+
def is_within_directory(directory, target):
100+
101+
abs_directory = os.path.abspath(directory)
102+
abs_target = os.path.abspath(target)
103+
104+
prefix = os.path.commonprefix([abs_directory, abs_target])
105+
106+
return prefix == abs_directory
107+
108+
def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
109+
110+
for member in tar.getmembers():
111+
member_path = os.path.join(path, member.name)
112+
if not is_within_directory(path, member_path):
113+
raise Exception("Attempted Path Traversal in Tar File")
114+
115+
tar.extractall(path, members, numeric_owner=numeric_owner)
116+
117+
118+
safe_extract(f, save_dir)
81119
elif args.dataset == 'lmd_aligned':
82120
with tarfile.open(target_download_location) as f:
83-
f.extractall(save_dir)
121+
def is_within_directory(directory, target):
122+
123+
abs_directory = os.path.abspath(directory)
124+
abs_target = os.path.abspath(target)
125+
126+
prefix = os.path.commonprefix([abs_directory, abs_target])
127+
128+
return prefix == abs_directory
129+
130+
def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
131+
132+
for member in tar.getmembers():
133+
member_path = os.path.join(path, member.name)
134+
if not is_within_directory(path, member_path):
135+
raise Exception("Attempted Path Traversal in Tar File")
136+
137+
tar.extractall(path, members, numeric_owner=numeric_owner)
138+
139+
140+
safe_extract(f, save_dir)
84141
elif args.dataset == 'clean_midi':
85142
with tarfile.open(target_download_location) as f:
86-
f.extractall(save_dir)
143+
def is_within_directory(directory, target):
144+
145+
abs_directory = os.path.abspath(directory)
146+
abs_target = os.path.abspath(target)
147+
148+
prefix = os.path.commonprefix([abs_directory, abs_target])
149+
150+
return prefix == abs_directory
151+
152+
def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
153+
154+
for member in tar.getmembers():
155+
member_path = os.path.join(path, member.name)
156+
if not is_within_directory(path, member_path):
157+
raise Exception("Attempted Path Traversal in Tar File")
158+
159+
tar.extractall(path, members, numeric_owner=numeric_owner)
160+
161+
162+
safe_extract(f, save_dir)
87163
elif args.dataset == 'maestro_v1':
88164
with zipfile.ZipFile(target_download_location, 'r') as fobj:
89165
fobj.extractall(save_dir)

0 commit comments

Comments
 (0)