Commit 5406382

zstdfile: ensure we do not read more than size / IO_BLOCK_SIZE
In the previous implementation, _ZtsdFileReader.read could produce output of arbitrary size, which can cause memory spikes while decompressing a file. Instead, we should use a ZstdDecompressor.stream_reader, which decompresses incrementally into a fixed-size output buffer.
1 parent b9e21b2 commit 5406382
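
For context, a minimal sketch (not part of this commit) of the bounded-read behaviour the commit message describes, using the zstandard package directly; the sample payload and the 64 KiB chunk size are illustrative values, not anything taken from rohmu:

import io
import zstandard as zstd

# Compress a payload much larger than the chunks we will read back.
raw = b"example payload " * 100_000
compressed = zstd.ZstdCompressor().compress(raw)

# stream_reader decompresses incrementally: each read(n) returns at most n
# bytes, so peak memory stays bounded regardless of the decompressed size.
reader = zstd.ZstdDecompressor().stream_reader(io.BytesIO(compressed), read_size=1024 * 1024)
out = bytearray()
while True:
    chunk = reader.read(64 * 1024)  # never more than 64 KiB per call
    if not chunk:
        break
    out.extend(chunk)
assert bytes(out) == raw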

File tree

2 files changed: +39 −16 lines


rohmu/zstdfile.py

Lines changed: 13 additions & 16 deletions
@@ -43,31 +43,28 @@ def writable(self) -> bool:
 
 class _ZtsdFileReader(FileWrap):
     def __init__(self, next_fp: FileLike) -> None:
-        self._zstd = zstd.ZstdDecompressor().decompressobj()
+        self._stream = zstd.ZstdDecompressor().stream_reader(
+            next_fp,  # type: ignore[arg-type]
+            read_size=IO_BLOCK_SIZE,
+            read_across_frames=True,
+        )
         super().__init__(next_fp)
-        self._done = False
 
     def close(self) -> None:
         if self.closed:
             return
+        self._stream.close()
         super().close()
 
     def read(self, size: Optional[int] = -1) -> bytes:
-        # NOTE: size arg is ignored, random size output is returned
         self._check_not_closed()
-        while not self._done:
-            compressed = self.next_fp.read(IO_BLOCK_SIZE)
-            if not compressed:
-                self._done = True
-                output = self._zstd.flush() or b""
-            else:
-                output = self._zstd.decompress(compressed)
-
-            if output:
-                self.offset += len(output)
-                return output
-
-        return b""
+        if size == 0:
+            return b""
+
+        read_size = size if size and size > 0 else IO_BLOCK_SIZE
+        data = self._stream.read(read_size)
+        self.offset += len(data)
+        return data
 
     def readable(self) -> bool:
         return True
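
An aside on the two parameters passed above: read_size controls how many compressed bytes are pulled from the source file per internal read, and read_across_frames=True makes the reader continue past zstd frame boundaries instead of reporting EOF after the first frame. A small illustrative sketch of the latter (not rohmu code; the sample data is invented):

import io
import zstandard as zstd

# Two independent zstd frames concatenated in one stream, as can happen when a
# file is produced by several separate compression passes.
frames = zstd.ZstdCompressor().compress(b"first ") + zstd.ZstdCompressor().compress(b"second")

reader = zstd.ZstdDecompressor().stream_reader(io.BytesIO(frames), read_across_frames=True)
out = b""
while True:
    chunk = reader.read(4)  # small reads, still bounded per call
    if not chunk:
        break
    out += chunk
assert out == b"first second"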

test/test_zstdfile.py

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+# Copyright (c) 2025 Aiven, Helsinki, Finland. https://aiven.io/
+# See LICENSE for details
+from rohmu import zstdfile
+
+import io
+
+
+def test_compress_and_decompress() -> None:
+    """Test basic compression and decompression"""
+    original_data = b"Hello, World! " * 10_000
+
+    compressed_buffer = io.BytesIO()
+    with zstdfile.open(compressed_buffer, "wb", level=3) as zf:
+        written = zf.write(original_data)
+    assert written == len(original_data)
+
+    compressed_buffer.seek(0)
+    decompressed_data = b""
+    with zstdfile.open(compressed_buffer, "rb") as zf:
+        chunk = zf.read(512)
+        assert len(chunk) <= 512
+        while chunk:
+            decompressed_data += chunk
+            chunk = zf.read(512)
+
+    assert decompressed_data == original_data
