Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 86 additions & 8 deletions src/crystal/util/netzipfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,18 @@

from __future__ import annotations

import bz2
import io
import lzma
import struct
from typing import BinaryIO, Literal, Protocol
import zlib

try:
import compression.zstd as _zstd # Python 3.14+
except ImportError:
_zstd = None # type: ignore[assignment]


# === Public API ===

Expand Down Expand Up @@ -61,11 +68,13 @@ def open(self, entry_name: str, mode: Literal['r'] = 'r') -> BinaryIO:
"""
Opens the named entry for reading.

Handles both stored (uncompressed) and deflate-compressed entries.
Handles both stored (uncompressed) and compressed entries.

Raises:
* KeyError -- if entry_name is not found in the zip file.
* ValueError -- if the entry uses an unsupported compression method.
* RuntimeError -- if the entry uses a compression method whose
supporting module is not available (e.g. zstandard on Python <3.14).
"""
entry = self._entries.get(entry_name)
if entry is None:
Expand Down Expand Up @@ -359,8 +368,34 @@ def _open_entry_data(open_range: OpenRangeCallable, entry: _CdEntry) -> BinaryIO
)
elif entry.compression_method == 8:
return io.BufferedReader(
_DeflateReader(
_LimitedReader(stream, entry.compressed_size) # type: ignore[arg-type]
_CompressedReader(
_LimitedReader(stream, entry.compressed_size), # type: ignore[arg-type]
zlib.decompressobj(-15) # wbits=-15: raw deflate (no zlib header), as used in ZIP
)
)
elif entry.compression_method == 12:
return io.BufferedReader(
_CompressedReader(
_LimitedReader(stream, entry.compressed_size), # type: ignore[arg-type]
bz2.BZ2Decompressor()
)
)
elif entry.compression_method == 14:
return io.BufferedReader(
_CompressedReader(
_LimitedReader(stream, entry.compressed_size), # type: ignore[arg-type]
_LzmaZipDecompressor()
)
)
elif entry.compression_method == 93:
if _zstd is None:
raise RuntimeError(
'Zstandard compression requires Python 3.14+ (compression.zstd module)'
)
return io.BufferedReader(
_CompressedReader(
_LimitedReader(stream, entry.compressed_size), # type: ignore[arg-type]
_zstd.ZstdDecompressor()
)
)
else:
Expand Down Expand Up @@ -392,14 +427,17 @@ def readinto(self, b: bytearray | memoryview) -> int: # type: ignore[override]
return actual


class _DeflateReader(io.RawIOBase):
class _Decompressor(Protocol):
def decompress(self, data: bytes) -> bytes: ...


class _CompressedReader(io.RawIOBase):
# Chunk size for reading compressed data from the underlying stream
_READ_CHUNK = 65536

def __init__(self, raw: BinaryIO) -> None:
def __init__(self, raw: BinaryIO, decompressor: _Decompressor) -> None:
self._raw = raw
# wbits=-15: raw deflate (no zlib header), as used in ZIP
self._decompressor = zlib.decompressobj(-15)
self._decompressor = decompressor
self._buf = bytearray() # decompressed bytes not yet consumed by caller
self._buf_pos = 0 # read cursor into _buf; advance instead of slicing
self._done = False # True once decompressor is exhausted
Expand All @@ -424,7 +462,9 @@ def readinto(self, b: bytearray | memoryview) -> int: # type: ignore[override]
self._buf.extend(self._decompressor.decompress(chunk))
else:
# No more compressed data; flush any remaining decompressed bytes
self._buf.extend(self._decompressor.flush())
# (e.g. zlib.decompressobj requires an explicit flush call)
if hasattr(self._decompressor, 'flush'):
self._buf.extend(self._decompressor.flush()) # type: ignore[attr-defined]
self._done = True

# Copy min(want, len(self._buf)) bytes from the start of self._buf to b.
Expand All @@ -438,3 +478,41 @@ def readinto(self, b: bytearray | memoryview) -> int: # type: ignore[override]
b[:] = memoryview(self._buf)[:want]
self._buf_pos = want
return want


class _LzmaZipDecompressor:
"""
Handles the ZIP LZMA stream format, which prepends a 4-byte header
(2 bytes version + 2 bytes properties size) and the LZMA properties
before the raw LZMA compressed data.

Mirrors the LZMADecompressor class used in Python's own zipfile module.
"""

def __init__(self) -> None:
self._decomp: lzma.LZMADecompressor | None = None
self._unconsumed = b''

def decompress(self, data: bytes) -> bytes:
if self._decomp is None:
self._unconsumed += data
if len(self._unconsumed) <= 4:
return b''
(psize,) = struct.unpack('<H', self._unconsumed[2:4])
if len(self._unconsumed) <= 4 + psize:
return b''
# NOTE: lzma._decode_filter_properties is a private CPython function
# used here following the same pattern as Python's own zipfile module
# (see Lib/zipfile/__init__.py LZMADecompressor class).
self._decomp = lzma.LZMADecompressor(
lzma.FORMAT_RAW,
filters=[lzma._decode_filter_properties( # type: ignore[attr-defined]
lzma.FILTER_LZMA1, self._unconsumed[4:4 + psize]
)],
)
data = self._unconsumed[4 + psize:]
self._unconsumed = b'' # release buffered header bytes
return self._decomp.decompress(data)



60 changes: 60 additions & 0 deletions tests/test_netzipfile.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Unit tests for netzipfile module."""

import sys
from typing import BinaryIO

from crystal.util import netzipfile
Expand Down Expand Up @@ -40,6 +41,37 @@ def test_can_read_entry_given_deflate_compressed_zip() -> None:
assert nzf.open('compressed.txt').read() == content


def test_can_read_entry_given_bzip2_compressed_zip() -> None:
content = b'The quick brown fox jumps over the lazy dog' * 100
entries = {'compressed.txt': content}
data = _create_bzip2_zip(entries)
nzf = NetZipFile(_create_open_range_func(data))

assert nzf.open('compressed.txt').read() == content


def test_can_read_entry_given_lzma_compressed_zip() -> None:
content = b'The quick brown fox jumps over the lazy dog' * 100
entries = {'compressed.txt': content}
data = _create_lzma_zip(entries)
nzf = NetZipFile(_create_open_range_func(data))

assert nzf.open('compressed.txt').read() == content


@pytest.mark.skipif(
sys.version_info < (3, 14),
reason='compression.zstd is only available in Python 3.14+',
)
def test_can_read_entry_given_zstd_compressed_zip() -> None:
content = b'The quick brown fox jumps over the lazy dog' * 100
entries = {'compressed.txt': content}
data = _create_zstd_zip(entries)
nzf = NetZipFile(_create_open_range_func(data))

assert nzf.open('compressed.txt').read() == content


def test_raises_key_error_when_try_read_missing_entry() -> None:
entries = {'000': b'Alpha'}
data = _create_zip64_stored_zip(entries)
Expand Down Expand Up @@ -211,6 +243,34 @@ def _create_deflate_zip(entries: dict[str, bytes]) -> bytes:
return buf.getvalue()


def _create_bzip2_zip(entries: dict[str, bytes]) -> bytes:
"""Build a bzip2-compressed zip file in memory."""
buf = io.BytesIO()
with zipfile.ZipFile(buf, 'w', compression=zipfile.ZIP_BZIP2, allowZip64=False) as zf:
for (name, data) in entries.items():
zf.writestr(name, data)
return buf.getvalue()


def _create_lzma_zip(entries: dict[str, bytes]) -> bytes:
"""Build an LZMA-compressed zip file in memory."""
buf = io.BytesIO()
with zipfile.ZipFile(buf, 'w', compression=zipfile.ZIP_LZMA, allowZip64=False) as zf:
for (name, data) in entries.items():
zf.writestr(name, data)
return buf.getvalue()


def _create_zstd_zip(entries: dict[str, bytes]) -> bytes:
"""Build a zstandard-compressed zip file in memory."""
buf = io.BytesIO()
compression: int = zipfile.ZIP_ZSTANDARD # type: ignore[attr-defined]
with zipfile.ZipFile(buf, 'w', compression=compression, allowZip64=False) as zf:
for (name, data) in entries.items():
zf.writestr(name, data)
return buf.getvalue()


def _create_open_range_func(data: bytes) -> OpenRangeCallable:
"""
Create an OpenRangeCallable that reads from an in-memory bytes buffer.
Expand Down
Loading