Skip to content

Commit c80645d

Browse files
committed
Test that oversized blocks get written correctly
1 parent 39d31e0 commit c80645d

File tree

2 files changed

+35
-15
lines changed

2 files changed

+35
-15
lines changed

src/isal/igzip_threaded.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,10 @@ def open(filename, mode="rb", compresslevel=igzip._COMPRESS_LEVEL_TRADEOFF,
6767
gzip_file = io.BufferedReader(
6868
_ThreadedGzipReader(binary_file, block_size=block_size))
6969
else:
70-
# Deflating random data results in an output a little larger than the
71-
# input. Making the output buffer 10% larger is sufficient overkill.
72-
compress_buffer_size = block_size + max(
73-
block_size // 10, 500)
7470
gzip_file = io.BufferedWriter(
7571
_ThreadedGzipWriter(
7672
fp=binary_file,
77-
buffer_size=compress_buffer_size,
73+
block_size=block_size,
7874
level=compresslevel,
7975
threads=threads
8076
),
@@ -201,15 +197,19 @@ def __init__(self,
201197
level: int = isal_zlib.ISAL_DEFAULT_COMPRESSION,
202198
threads: int = 1,
203199
queue_size: int = 1,
204-
buffer_size: int = 1024 * 1024,
200+
block_size: int = 1024 * 1024,
205201
):
206202
self.lock = threading.Lock()
207203
self.exception: Optional[Exception] = None
208204
self.raw = fp
209205
self.level = level
210206
self.previous_block = b""
207+
# Deflating random data results in an output a little larger than the
208+
# input. Making the output buffer 10% larger is sufficient overkill.
209+
compress_buffer_size = block_size + max(block_size // 10, 500)
210+
self.block_size = block_size
211211
self.compressors: List[isal_zlib._ParallelCompress] = [
212-
isal_zlib._ParallelCompress(buffersize=buffer_size,
212+
isal_zlib._ParallelCompress(buffersize=compress_buffer_size,
213213
level=level) for _ in range(threads)
214214
]
215215
if threads > 1:
@@ -273,8 +273,19 @@ def write(self, b) -> int:
273273
with self.lock:
274274
if self.exception:
275275
raise self.exception
276-
index = self.index
276+
length = b.nbytes if isinstance(b, memoryview) else len(b)
277+
if length > self.block_size:
278+
# write smaller chunks and return the result
279+
memview = memoryview(b)
280+
start = 0
281+
total_written = 0
282+
while start < length:
283+
total_written += self.write(
284+
memview[start:start+self.block_size])
285+
start += self.block_size
286+
return total_written
277287
data = bytes(b)
288+
index = self.index
278289
zdict = memoryview(self.previous_block)[-DEFLATE_WINDOW_SIZE:]
279290
self.previous_block = data
280291
self.index += 1

tests/test_igzip_threaded.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -78,20 +78,29 @@ def test_threaded_read_error():
7878
@pytest.mark.timeout(5)
7979
@pytest.mark.parametrize("threads", [1, 3])
8080
def test_threaded_write_oversized_block_no_error(threads):
81-
with igzip_threaded.open(
82-
io.BytesIO(), "wb", compresslevel=3, threads=threads,
83-
block_size=8 * 1024
84-
) as writer:
85-
writer.write(os.urandom(1024 * 64))
81+
# Random bytes are incompressible, and therefore are guaranteed to
82+
# trigger a buffer overflow when larger than block size unless handled
83+
# correctly.
84+
data = os.urandom(1024 * 63) # not a multiple of block_size
85+
with tempfile.NamedTemporaryFile(mode="wb", delete=False) as tmp:
86+
with igzip_threaded.open(
87+
tmp, "wb", compresslevel=3, threads=threads,
88+
block_size=8 * 1024
89+
) as writer:
90+
writer.write(data)
91+
with gzip.open(tmp.name, "rb") as gzipped:
92+
decompressed = gzipped.read()
93+
assert data == decompressed
8694

8795

8896
@pytest.mark.timeout(5)
8997
@pytest.mark.parametrize("threads", [1, 3])
9098
def test_threaded_write_error(threads):
9199
f = igzip_threaded._ThreadedGzipWriter(
92100
fp=io.BytesIO(), level=3,
93-
threads=threads, buffer_size=8 * 1024)
94-
# Bypass the write method which should not allow this.
101+
threads=threads, block_size=8 * 1024)
102+
# Bypass the write method which should not allow blocks larger than
103+
# block_size.
95104
f.input_queues[0].put((os.urandom(1024 * 64), b""))
96105
with pytest.raises(OverflowError) as error:
97106
f.close()

0 commit comments

Comments
 (0)