diff --git a/src/datasets/utils/extract.py b/src/datasets/utils/extract.py
index 21245931a22..e9cc9ceab59 100644
--- a/src/datasets/utils/extract.py
+++ b/src/datasets/utils/extract.py
@@ -3,12 +3,13 @@
 import lzma
 import os
 import shutil
+import struct
 import tarfile
 import warnings
 import zipfile
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Optional, Union
+from typing import Dict, List, Optional, Type, Union
 
 from .. import config
 from .filelock import FileLock
@@ -61,7 +62,7 @@ def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None
 
 
 class MagicNumberBaseExtractor(BaseExtractor, ABC):
-    magic_number = b""
+    magic_numbers: List[bytes] = []
 
     @staticmethod
     def read_magic_number(path: Union[Path, str], magic_number_length: int):
@@ -71,11 +72,12 @@ def read_magic_number(path: Union[Path, str], magic_number_length: int):
     @classmethod
     def is_extractable(cls, path: Union[Path, str], magic_number: bytes = b"") -> bool:
         if not magic_number:
+            magic_number_length = max(len(cls_magic_number) for cls_magic_number in cls.magic_numbers)
             try:
-                magic_number = cls.read_magic_number(path, len(cls.magic_number))
+                magic_number = cls.read_magic_number(path, magic_number_length)
             except OSError:
                 return False
-        return magic_number.startswith(cls.magic_number)
+        return any(magic_number.startswith(cls_magic_number) for cls_magic_number in cls.magic_numbers)
 
 
 class TarExtractor(BaseExtractor):
@@ -128,7 +130,7 @@ def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None
 
 
 class GzipExtractor(MagicNumberBaseExtractor):
-    magic_number = b"\x1F\x8B"
+    magic_numbers = [b"\x1F\x8B"]
 
     @staticmethod
     def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None:
@@ -137,10 +139,49 @@ def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None
                 shutil.copyfileobj(gzip_file, extracted_file)
 
 
-class ZipExtractor(BaseExtractor):
+class ZipExtractor(MagicNumberBaseExtractor):
+    magic_numbers = [
+        b"PK\x03\x04",
+        b"PK\x05\x06",  # empty archive
+        b"PK\x07\x08",  # spanned archive
+    ]
+
     @classmethod
-    def is_extractable(cls, path: Union[Path, str], **kwargs) -> bool:
-        return zipfile.is_zipfile(path)
+    def is_extractable(cls, path: Union[Path, str], magic_number: bytes = b"") -> bool:
+        if super().is_extractable(path, magic_number=magic_number):
+            return True
+        try:
+            # Alternative version of zipfile.is_zipfile that has less false positives, but misses executable zip archives.
+            # From: https://github.com/python/cpython/pull/5053
+            from zipfile import (
+                _CD_SIGNATURE,
+                _ECD_DISK_NUMBER,
+                _ECD_DISK_START,
+                _ECD_ENTRIES_TOTAL,
+                _ECD_OFFSET,
+                _ECD_SIZE,
+                _EndRecData,
+                sizeCentralDir,
+                stringCentralDir,
+                structCentralDir,
+            )
+
+            with open(path, "rb") as fp:
+                endrec = _EndRecData(fp)
+                if endrec:
+                    if endrec[_ECD_ENTRIES_TOTAL] == 0 and endrec[_ECD_SIZE] == 0 and endrec[_ECD_OFFSET] == 0:
+                        return True  # Empty zipfiles are still zipfiles
+                    elif endrec[_ECD_DISK_NUMBER] == endrec[_ECD_DISK_START]:
+                        fp.seek(endrec[_ECD_OFFSET])  # Central directory is on the same disk
+                        if fp.tell() == endrec[_ECD_OFFSET] and endrec[_ECD_SIZE] >= sizeCentralDir:
+                            data = fp.read(sizeCentralDir)  # CD is where we expect it to be
+                            if len(data) == sizeCentralDir:
+                                centdir = struct.unpack(structCentralDir, data)  # CD is the right size
+                                if centdir[_CD_SIGNATURE] == stringCentralDir:
+                                    return True  # First central directory entry has correct magic number
+                return False
+        except Exception:  # catch all errors in case future python versions change the zipfile internals
+            return False
 
     @staticmethod
     def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None:
@@ -151,7 +192,7 @@ def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None
 
 
 class XzExtractor(MagicNumberBaseExtractor):
-    magic_number = b"\xFD\x37\x7A\x58\x5A\x00"
+    magic_numbers = [b"\xFD\x37\x7A\x58\x5A\x00"]
 
     @staticmethod
     def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None:
@@ -160,16 +201,8 @@ def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None
                 shutil.copyfileobj(compressed_file, extracted_file)
 
 
-class RarExtractor(BaseExtractor):
-    RAR_ID = b"Rar!\x1a\x07\x00"
-    RAR5_ID = b"Rar!\x1a\x07\x01\x00"
-
-    @classmethod
-    def is_extractable(cls, path: Union[Path, str], **kwargs) -> bool:
-        """https://github.com/markokr/rarfile/blob/master/rarfile.py"""
-        with open(path, "rb") as f:
-            magic_number = f.read(len(cls.RAR5_ID))
-        return magic_number == cls.RAR5_ID or magic_number.startswith(cls.RAR_ID)
+class RarExtractor(MagicNumberBaseExtractor):
+    magic_numbers = [b"Rar!\x1a\x07\x00", b"Rar!\x1a\x07\x01\x00"]  # RAR_ID # RAR5_ID
 
     @staticmethod
     def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None:
@@ -184,7 +217,7 @@ def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None
 
 
 class ZstdExtractor(MagicNumberBaseExtractor):
-    magic_number = b"\x28\xb5\x2F\xFD"
+    magic_numbers = [b"\x28\xb5\x2F\xFD"]
 
     @staticmethod
     def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None:
@@ -198,7 +231,7 @@ def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None
 
 
 class Bzip2Extractor(MagicNumberBaseExtractor):
-    magic_number = b"\x42\x5A\x68"
+    magic_numbers = [b"\x42\x5A\x68"]
 
     @staticmethod
     def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None:
@@ -208,7 +241,7 @@ def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None
 
 
 class SevenZipExtractor(MagicNumberBaseExtractor):
-    magic_number = b"\x37\x7A\xBC\xAF\x27\x1C"
+    magic_numbers = [b"\x37\x7A\xBC\xAF\x27\x1C"]
 
     @staticmethod
     def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None:
@@ -222,7 +255,7 @@ def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None
 
 
 class Lz4Extractor(MagicNumberBaseExtractor):
-    magic_number = b"\x04\x22\x4D\x18"
+    magic_numbers = [b"\x04\x22\x4D\x18"]
 
     @staticmethod
     def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None:
@@ -237,7 +270,7 @@ def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None
 
 class Extractor:
     # Put zip file to the last, b/c it is possible wrongly detected as zip (I guess it means: as tar or gzip)
-    extractors = {
+    extractors: Dict[str, Type[BaseExtractor]] = {
         "tar": TarExtractor,
         "gzip": GzipExtractor,
         "zip": ZipExtractor,
@@ -251,14 +284,12 @@ class Extractor:
 
     @classmethod
     def _get_magic_number_max_length(cls):
-        magic_number_max_length = 0
-        for extractor in cls.extractors.values():
-            if hasattr(extractor, "magic_number"):
-                magic_number_length = len(extractor.magic_number)
-                magic_number_max_length = (
-                    magic_number_length if magic_number_length > magic_number_max_length else magic_number_max_length
-                )
-        return magic_number_max_length
+        return max(
+            len(extractor_magic_number)
+            for extractor in cls.extractors.values()
+            if issubclass(extractor, MagicNumberBaseExtractor)
+            for extractor_magic_number in extractor.magic_numbers
+        )
 
     @staticmethod
     def _read_magic_number(path: Union[Path, str], magic_number_length: int):
diff --git a/tests/test_extract.py b/tests/test_extract.py
index 943d3a14ba7..186d65fd0ba 100644
--- a/tests/test_extract.py
+++ b/tests/test_extract.py
@@ -1,4 +1,5 @@
 import os
+import zipfile
 
 import pytest
 
@@ -183,3 +184,20 @@ def test_tar_extract_insecure_files(
     for record in caplog.records:
         assert record.levelname == "ERROR"
         assert error_log in record.msg
+
+
+def test_is_zipfile_false_positive(tmpdir):
+    # We should have less false positives than zipfile.is_zipfile
+    # We do that by checking only the magic number
+    not_a_zip_file = tmpdir / "not_a_zip_file"
+    # From: https://github.com/python/cpython/pull/5053
+    data = (
+        b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00"
+        b"\x00\x02\x08\x06\x00\x00\x00\x99\x81\xb6'\x00\x00\x00\x15I"
+        b"DATx\x01\x01\n\x00\xf5\xff\x00PK\x05\x06\x00PK\x06\x06\x07"
+        b"\xac\x01N\xc6|a\r\x00\x00\x00\x00IEND\xaeB`\x82"
+    )
+    with not_a_zip_file.open("wb") as f:
+        f.write(data)
+    assert zipfile.is_zipfile(str(not_a_zip_file))  # is a false positive for `zipfile`
+    assert not ZipExtractor.is_extractable(not_a_zip_file)  # but we're right