Skip to content

Commit 11cd0f7

Browse files
Less zip false positives (#5640)
* use magic number for zip * test * alternative version of zipfile.is_zipfile * Update src/datasets/utils/extract.py Co-authored-by: Albert Villanova del Moral <[email protected]> --------- Co-authored-by: Albert Villanova del Moral <[email protected]>
1 parent d862821 commit 11cd0f7

File tree

2 files changed

+81
-32
lines changed

2 files changed

+81
-32
lines changed

src/datasets/utils/extract.py

Lines changed: 63 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33
import lzma
44
import os
55
import shutil
6+
import struct
67
import tarfile
78
import warnings
89
import zipfile
910
from abc import ABC, abstractmethod
1011
from pathlib import Path
11-
from typing import Optional, Union
12+
from typing import Dict, List, Optional, Type, Union
1213

1314
from .. import config
1415
from .filelock import FileLock
@@ -61,7 +62,7 @@ def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None
6162

6263

6364
class MagicNumberBaseExtractor(BaseExtractor, ABC):
64-
magic_number = b""
65+
magic_numbers: List[bytes] = []
6566

6667
@staticmethod
6768
def read_magic_number(path: Union[Path, str], magic_number_length: int):
@@ -71,11 +72,12 @@ def read_magic_number(path: Union[Path, str], magic_number_length: int):
7172
@classmethod
7273
def is_extractable(cls, path: Union[Path, str], magic_number: bytes = b"") -> bool:
7374
if not magic_number:
75+
magic_number_length = max(len(cls_magic_number) for cls_magic_number in cls.magic_numbers)
7476
try:
75-
magic_number = cls.read_magic_number(path, len(cls.magic_number))
77+
magic_number = cls.read_magic_number(path, magic_number_length)
7678
except OSError:
7779
return False
78-
return magic_number.startswith(cls.magic_number)
80+
return any(magic_number.startswith(cls_magic_number) for cls_magic_number in cls.magic_numbers)
7981

8082

8183
class TarExtractor(BaseExtractor):
@@ -128,7 +130,7 @@ def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None
128130

129131

130132
class GzipExtractor(MagicNumberBaseExtractor):
131-
magic_number = b"\x1F\x8B"
133+
magic_numbers = [b"\x1F\x8B"]
132134

133135
@staticmethod
134136
def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None:
@@ -137,10 +139,49 @@ def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None
137139
shutil.copyfileobj(gzip_file, extracted_file)
138140

139141

140-
class ZipExtractor(BaseExtractor):
142+
class ZipExtractor(MagicNumberBaseExtractor):
143+
magic_numbers = [
144+
b"PK\x03\x04",
145+
b"PK\x05\x06", # empty archive
146+
b"PK\x07\x08", # spanned archive
147+
]
148+
141149
@classmethod
142-
def is_extractable(cls, path: Union[Path, str], **kwargs) -> bool:
143-
return zipfile.is_zipfile(path)
150+
def is_extractable(cls, path: Union[Path, str], magic_number: bytes = b"") -> bool:
151+
if super().is_extractable(path, magic_number=magic_number):
152+
return True
153+
try:
154+
# Alternative version of zipfile.is_zipfile that has less false positives, but misses executable zip archives.
155+
# From: https://github.com/python/cpython/pull/5053
156+
from zipfile import (
157+
_CD_SIGNATURE,
158+
_ECD_DISK_NUMBER,
159+
_ECD_DISK_START,
160+
_ECD_ENTRIES_TOTAL,
161+
_ECD_OFFSET,
162+
_ECD_SIZE,
163+
_EndRecData,
164+
sizeCentralDir,
165+
stringCentralDir,
166+
structCentralDir,
167+
)
168+
169+
with open(path, "rb") as fp:
170+
endrec = _EndRecData(fp)
171+
if endrec:
172+
if endrec[_ECD_ENTRIES_TOTAL] == 0 and endrec[_ECD_SIZE] == 0 and endrec[_ECD_OFFSET] == 0:
173+
return True # Empty zipfiles are still zipfiles
174+
elif endrec[_ECD_DISK_NUMBER] == endrec[_ECD_DISK_START]:
175+
fp.seek(endrec[_ECD_OFFSET]) # Central directory is on the same disk
176+
if fp.tell() == endrec[_ECD_OFFSET] and endrec[_ECD_SIZE] >= sizeCentralDir:
177+
data = fp.read(sizeCentralDir) # CD is where we expect it to be
178+
if len(data) == sizeCentralDir:
179+
centdir = struct.unpack(structCentralDir, data) # CD is the right size
180+
if centdir[_CD_SIGNATURE] == stringCentralDir:
181+
return True # First central directory entry has correct magic number
182+
return False
183+
except Exception: # catch all errors in case future python versions change the zipfile internals
184+
return False
144185

145186
@staticmethod
146187
def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None:
@@ -151,7 +192,7 @@ def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None
151192

152193

153194
class XzExtractor(MagicNumberBaseExtractor):
154-
magic_number = b"\xFD\x37\x7A\x58\x5A\x00"
195+
magic_numbers = [b"\xFD\x37\x7A\x58\x5A\x00"]
155196

156197
@staticmethod
157198
def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None:
@@ -160,16 +201,8 @@ def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None
160201
shutil.copyfileobj(compressed_file, extracted_file)
161202

162203

163-
class RarExtractor(BaseExtractor):
164-
RAR_ID = b"Rar!\x1a\x07\x00"
165-
RAR5_ID = b"Rar!\x1a\x07\x01\x00"
166-
167-
@classmethod
168-
def is_extractable(cls, path: Union[Path, str], **kwargs) -> bool:
169-
"""https://github.com/markokr/rarfile/blob/master/rarfile.py"""
170-
with open(path, "rb") as f:
171-
magic_number = f.read(len(cls.RAR5_ID))
172-
return magic_number == cls.RAR5_ID or magic_number.startswith(cls.RAR_ID)
204+
class RarExtractor(MagicNumberBaseExtractor):
205+
magic_numbers = [b"Rar!\x1a\x07\x00", b"Rar!\x1a\x07\x01\x00"] # RAR_ID # RAR5_ID
173206

174207
@staticmethod
175208
def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None:
@@ -184,7 +217,7 @@ def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None
184217

185218

186219
class ZstdExtractor(MagicNumberBaseExtractor):
187-
magic_number = b"\x28\xb5\x2F\xFD"
220+
magic_numbers = [b"\x28\xb5\x2F\xFD"]
188221

189222
@staticmethod
190223
def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None:
@@ -198,7 +231,7 @@ def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None
198231

199232

200233
class Bzip2Extractor(MagicNumberBaseExtractor):
201-
magic_number = b"\x42\x5A\x68"
234+
magic_numbers = [b"\x42\x5A\x68"]
202235

203236
@staticmethod
204237
def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None:
@@ -208,7 +241,7 @@ def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None
208241

209242

210243
class SevenZipExtractor(MagicNumberBaseExtractor):
211-
magic_number = b"\x37\x7A\xBC\xAF\x27\x1C"
244+
magic_numbers = [b"\x37\x7A\xBC\xAF\x27\x1C"]
212245

213246
@staticmethod
214247
def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None:
@@ -222,7 +255,7 @@ def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None
222255

223256

224257
class Lz4Extractor(MagicNumberBaseExtractor):
225-
magic_number = b"\x04\x22\x4D\x18"
258+
magic_numbers = [b"\x04\x22\x4D\x18"]
226259

227260
@staticmethod
228261
def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None:
@@ -237,7 +270,7 @@ def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None
237270

238271
class Extractor:
239272
# Put zip file to the last, b/c it is possible wrongly detected as zip (I guess it means: as tar or gzip)
240-
extractors = {
273+
extractors: Dict[str, Type[BaseExtractor]] = {
241274
"tar": TarExtractor,
242275
"gzip": GzipExtractor,
243276
"zip": ZipExtractor,
@@ -251,14 +284,12 @@ class Extractor:
251284

252285
@classmethod
253286
def _get_magic_number_max_length(cls):
254-
magic_number_max_length = 0
255-
for extractor in cls.extractors.values():
256-
if hasattr(extractor, "magic_number"):
257-
magic_number_length = len(extractor.magic_number)
258-
magic_number_max_length = (
259-
magic_number_length if magic_number_length > magic_number_max_length else magic_number_max_length
260-
)
261-
return magic_number_max_length
287+
return max(
288+
len(extractor_magic_number)
289+
for extractor in cls.extractors.values()
290+
if issubclass(extractor, MagicNumberBaseExtractor)
291+
for extractor_magic_number in extractor.magic_numbers
292+
)
262293

263294
@staticmethod
264295
def _read_magic_number(path: Union[Path, str], magic_number_length: int):

tests/test_extract.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import zipfile
23

34
import pytest
45

@@ -183,3 +184,20 @@ def test_tar_extract_insecure_files(
183184
for record in caplog.records:
184185
assert record.levelname == "ERROR"
185186
assert error_log in record.msg
187+
188+
189+
def test_is_zipfile_false_positive(tmpdir):
190+
# We should have less false positives than zipfile.is_zipfile
191+
# We do that by checking only the magic number
192+
not_a_zip_file = tmpdir / "not_a_zip_file"
193+
# From: https://github.com/python/cpython/pull/5053
194+
data = (
195+
b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00"
196+
b"\x00\x02\x08\x06\x00\x00\x00\x99\x81\xb6'\x00\x00\x00\x15I"
197+
b"DATx\x01\x01\n\x00\xf5\xff\x00PK\x05\x06\x00PK\x06\x06\x07"
198+
b"\xac\x01N\xc6|a\r\x00\x00\x00\x00IEND\xaeB`\x82"
199+
)
200+
with not_a_zip_file.open("wb") as f:
201+
f.write(data)
202+
assert zipfile.is_zipfile(str(not_a_zip_file)) # is a false positive for `zipfile`
203+
assert not ZipExtractor.is_extractable(not_a_zip_file) # but we're right

0 commit comments

Comments
 (0)