Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 63 additions & 32 deletions src/datasets/utils/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
import lzma
import os
import shutil
import struct
import tarfile
import warnings
import zipfile
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional, Union
from typing import Dict, List, Optional, Type, Union

from .. import config
from .filelock import FileLock
Expand Down Expand Up @@ -61,7 +62,7 @@ def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None


class MagicNumberBaseExtractor(BaseExtractor, ABC):
magic_number = b""
magic_numbers: List[bytes] = []

@staticmethod
def read_magic_number(path: Union[Path, str], magic_number_length: int):
Expand All @@ -71,11 +72,12 @@ def read_magic_number(path: Union[Path, str], magic_number_length: int):
@classmethod
def is_extractable(cls, path: Union[Path, str], magic_number: bytes = b"") -> bool:
if not magic_number:
magic_number_length = max(len(cls_magic_number) for cls_magic_number in cls.magic_numbers)
try:
magic_number = cls.read_magic_number(path, len(cls.magic_number))
magic_number = cls.read_magic_number(path, magic_number_length)
except OSError:
return False
return magic_number.startswith(cls.magic_number)
return any(magic_number.startswith(cls_magic_number) for cls_magic_number in cls.magic_numbers)


class TarExtractor(BaseExtractor):
Expand Down Expand Up @@ -128,7 +130,7 @@ def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None


class GzipExtractor(MagicNumberBaseExtractor):
magic_number = b"\x1F\x8B"
magic_numbers = [b"\x1F\x8B"]

@staticmethod
def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None:
Expand All @@ -137,10 +139,49 @@ def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None
shutil.copyfileobj(gzip_file, extracted_file)


class ZipExtractor(BaseExtractor):
class ZipExtractor(MagicNumberBaseExtractor):
magic_numbers = [
b"PK\x03\x04",
b"PK\x05\x06", # empty archive
b"PK\x07\x08", # spanned archive
]

@classmethod
def is_extractable(cls, path: Union[Path, str], **kwargs) -> bool:
return zipfile.is_zipfile(path)
def is_extractable(cls, path: Union[Path, str], magic_number: bytes = b"") -> bool:
if super().is_extractable(path, magic_number=magic_number):
return True
try:
# Alternative version of zipfile.is_zipfile that has fewer false positives, but misses executable zip archives.
# From: https://github.com/python/cpython/pull/5053
from zipfile import (
_CD_SIGNATURE,
_ECD_DISK_NUMBER,
_ECD_DISK_START,
_ECD_ENTRIES_TOTAL,
_ECD_OFFSET,
_ECD_SIZE,
_EndRecData,
sizeCentralDir,
stringCentralDir,
structCentralDir,
)

with open(path, "rb") as fp:
endrec = _EndRecData(fp)
if endrec:
if endrec[_ECD_ENTRIES_TOTAL] == 0 and endrec[_ECD_SIZE] == 0 and endrec[_ECD_OFFSET] == 0:
return True # Empty zipfiles are still zipfiles
elif endrec[_ECD_DISK_NUMBER] == endrec[_ECD_DISK_START]:
fp.seek(endrec[_ECD_OFFSET]) # Central directory is on the same disk
if fp.tell() == endrec[_ECD_OFFSET] and endrec[_ECD_SIZE] >= sizeCentralDir:
data = fp.read(sizeCentralDir) # CD is where we expect it to be
if len(data) == sizeCentralDir:
centdir = struct.unpack(structCentralDir, data) # CD is the right size
if centdir[_CD_SIGNATURE] == stringCentralDir:
return True # First central directory entry has correct magic number
return False
except Exception: # catch all errors in case future python versions change the zipfile internals
return False
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please note that this function, as it is, could return `None`. To fix this:

Suggested change
return False
return False
return False


@staticmethod
def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None:
Expand All @@ -151,7 +192,7 @@ def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None


class XzExtractor(MagicNumberBaseExtractor):
magic_number = b"\xFD\x37\x7A\x58\x5A\x00"
magic_numbers = [b"\xFD\x37\x7A\x58\x5A\x00"]

@staticmethod
def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None:
Expand All @@ -160,16 +201,8 @@ def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None
shutil.copyfileobj(compressed_file, extracted_file)


class RarExtractor(BaseExtractor):
RAR_ID = b"Rar!\x1a\x07\x00"
RAR5_ID = b"Rar!\x1a\x07\x01\x00"

@classmethod
def is_extractable(cls, path: Union[Path, str], **kwargs) -> bool:
"""https://github.com/markokr/rarfile/blob/master/rarfile.py"""
with open(path, "rb") as f:
magic_number = f.read(len(cls.RAR5_ID))
return magic_number == cls.RAR5_ID or magic_number.startswith(cls.RAR_ID)
class RarExtractor(MagicNumberBaseExtractor):
magic_numbers = [b"Rar!\x1a\x07\x00", b"Rar!\x1a\x07\x01\x00"] # RAR_ID # RAR5_ID

@staticmethod
def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None:
Expand All @@ -184,7 +217,7 @@ def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None


class ZstdExtractor(MagicNumberBaseExtractor):
magic_number = b"\x28\xb5\x2F\xFD"
magic_numbers = [b"\x28\xb5\x2F\xFD"]

@staticmethod
def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None:
Expand All @@ -198,7 +231,7 @@ def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None


class Bzip2Extractor(MagicNumberBaseExtractor):
magic_number = b"\x42\x5A\x68"
magic_numbers = [b"\x42\x5A\x68"]

@staticmethod
def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None:
Expand All @@ -208,7 +241,7 @@ def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None


class SevenZipExtractor(MagicNumberBaseExtractor):
magic_number = b"\x37\x7A\xBC\xAF\x27\x1C"
magic_numbers = [b"\x37\x7A\xBC\xAF\x27\x1C"]

@staticmethod
def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None:
Expand All @@ -222,7 +255,7 @@ def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None


class Lz4Extractor(MagicNumberBaseExtractor):
magic_number = b"\x04\x22\x4D\x18"
magic_numbers = [b"\x04\x22\x4D\x18"]

@staticmethod
def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None:
Expand All @@ -237,7 +270,7 @@ def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None

class Extractor:
# Put zip file to the last, b/c it is possible wrongly detected as zip (I guess it means: as tar or gzip)
extractors = {
extractors: Dict[str, Type[BaseExtractor]] = {
"tar": TarExtractor,
"gzip": GzipExtractor,
"zip": ZipExtractor,
Expand All @@ -251,14 +284,12 @@ class Extractor:

@classmethod
def _get_magic_number_max_length(cls):
magic_number_max_length = 0
for extractor in cls.extractors.values():
if hasattr(extractor, "magic_number"):
magic_number_length = len(extractor.magic_number)
magic_number_max_length = (
magic_number_length if magic_number_length > magic_number_max_length else magic_number_max_length
)
return magic_number_max_length
return max(
len(extractor_magic_number)
for extractor in cls.extractors.values()
if issubclass(extractor, MagicNumberBaseExtractor)
for extractor_magic_number in extractor.magic_numbers
)

@staticmethod
def _read_magic_number(path: Union[Path, str], magic_number_length: int):
Expand Down
18 changes: 18 additions & 0 deletions tests/test_extract.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import zipfile

import pytest

Expand Down Expand Up @@ -183,3 +184,20 @@ def test_tar_extract_insecure_files(
for record in caplog.records:
assert record.levelname == "ERROR"
assert error_log in record.msg


def test_is_zipfile_false_positive(tmpdir):
    """Check that ``ZipExtractor.is_extractable`` rejects a file that fools ``zipfile.is_zipfile``."""
    # PNG payload taken from https://github.com/python/cpython/pull/5053.
    # Its image data embeds zip-like "PK" signatures.
    png_payload = (
        b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00"
        b"\x00\x02\x08\x06\x00\x00\x00\x99\x81\xb6'\x00\x00\x00\x15I"
        b"DATx\x01\x01\n\x00\xf5\xff\x00PK\x05\x06\x00PK\x06\x06\x07"
        b"\xac\x01N\xc6|a\r\x00\x00\x00\x00IEND\xaeB`\x82"
    )
    fake_zip_path = tmpdir / "not_a_zip_file"
    with fake_zip_path.open("wb") as stream:
        stream.write(png_payload)

    # The standard-library check reports a false positive for this file ...
    stdlib_verdict = zipfile.is_zipfile(str(fake_zip_path))
    assert stdlib_verdict
    # ... while our detection should produce fewer false positives and reject it.
    assert not ZipExtractor.is_extractable(fake_zip_path)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test passes because `not None` is `True`.