Commit c8f5835
zstdfile: ensure we do not read more than size / IO_BLOCK_SIZE
In the previous implementation, _ZtsdFileReader.read could produce output of arbitrary size, which can cause memory spikes while decompressing a file. Instead, use ZstdDecompressor.stream_reader, which decompresses incrementally into a fixed-size output buffer.
1 parent b9e21b2 commit c8f5835
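
For context, here is a minimal sketch of the pattern the new reader relies on, using the python-zstandard API directly rather than through rohmu. The file name example.zst, the 64 KiB chunk size, and the byte-counting loop are illustrative assumptions, not taken from the commit.

import zstandard as zstd

CHUNK_SIZE = 64 * 1024  # illustrative; analogous to rohmu's IO_BLOCK_SIZE

total = 0
with open("example.zst", "rb") as compressed:  # hypothetical input file
    with zstd.ZstdDecompressor().stream_reader(
        compressed,
        read_size=CHUNK_SIZE,      # compressed bytes pulled from the source per iteration
        read_across_frames=True,   # keep reading past frame boundaries in multi-frame files
    ) as reader:
        while True:
            # Each read() returns at most CHUNK_SIZE decompressed bytes,
            # so memory use stays bounded regardless of the compression ratio.
            chunk = reader.read(CHUNK_SIZE)
            if not chunk:
                break
            total += len(chunk)
print(f"decompressed {total} bytes")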

2 files changed (+36, -15 lines)


rohmu/zstdfile.py

Lines changed: 10 additions & 15 deletions
@@ -43,31 +43,26 @@ def writable(self) -> bool:
 
 class _ZtsdFileReader(FileWrap):
     def __init__(self, next_fp: FileLike) -> None:
-        self._zstd = zstd.ZstdDecompressor().decompressobj()
+        self._stream = zstd.ZstdDecompressor().stream_reader(
+            next_fp,
+            read_size=IO_BLOCK_SIZE,
+            read_across_frames=True,
+        )
         super().__init__(next_fp)
         self._done = False
 
     def close(self) -> None:
         if self.closed:
             return
+        self._stream.close()
         super().close()
 
     def read(self, size: Optional[int] = -1) -> bytes:
-        # NOTE: size arg is ignored, random size output is returned
         self._check_not_closed()
-        while not self._done:
-            compressed = self.next_fp.read(IO_BLOCK_SIZE)
-            if not compressed:
-                self._done = True
-                output = self._zstd.flush() or b""
-            else:
-                output = self._zstd.decompress(compressed)
-
-            if output:
-                self.offset += len(output)
-                return output
-
-        return b""
+        read_size = size if size and size > 0 else IO_BLOCK_SIZE
+        data = self._stream.read(read_size)
+        self.offset += len(data)
+        return data
 
     def readable(self) -> bool:
         return True

test/test_zstdfile.py

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+# Copyright (c) 2025 Aiven, Helsinki, Finland. https://aiven.io/
+# See LICENSE for details
+from rohmu import zstdfile
+
+import io
+
+
+def test_compress_and_decompress() -> None:
+    """Test basic compression and decompression"""
+    original_data = b"Hello, World! " * 10_000
+
+    compressed_buffer = io.BytesIO()
+    with zstdfile.open(compressed_buffer, "wb", level=3) as zf:
+        written = zf.write(original_data)
+        assert written == len(original_data)
+
+    compressed_buffer.seek(0)
+    decompressed_data = b""
+    with zstdfile.open(compressed_buffer, "rb") as zf:
+        chunk = zf.read(512)
+        assert len(chunk) <= 512
+        while chunk:
+            decompressed_data += chunk
+            chunk = zf.read(512)
+
+    assert decompressed_data == original_data
