diff --git a/src/zlib_ng/gzip_ng_threaded.py b/src/zlib_ng/gzip_ng_threaded.py index 8a7e54a..a4cdfa4 100644 --- a/src/zlib_ng/gzip_ng_threaded.py +++ b/src/zlib_ng/gzip_ng_threaded.py @@ -11,6 +11,7 @@ import os import queue import struct +import sys import threading from typing import List, Optional, Tuple @@ -98,9 +99,11 @@ def __init__(self, filename, queue_size=2, block_size=1024 * 1024): self.exception = None self.buffer = io.BytesIO() self.block_size = block_size - self.worker = threading.Thread(target=self._decompress) + # Using a daemon thread prevents programs freezing on error. + self.worker = threading.Thread(target=self._decompress, daemon=True) self._closed = False - self.running = False + self.running = True + self.worker.start() def _check_closed(self, msg=None): if self._closed: @@ -122,21 +125,11 @@ def _decompress(self): block_queue.put(data, timeout=0.05) break except queue.Full: - pass - - def _start(self): - if not self.running: - self.running = True - self.worker.start() - - def _stop(self): - if self.running: - self.running = False - self.worker.join() + if sys.is_finalizing(): + return def readinto(self, b): self._check_closed() - self._start() result = self.buffer.readinto(b) if result == 0: while True: @@ -164,7 +157,8 @@ def tell(self) -> int: def close(self) -> None: if self._closed: return - self._stop() + self.running = False + self.worker.join() self.fileobj.close() if self.closefd: self.raw.close() @@ -240,9 +234,10 @@ def __init__(self, queue.Queue(queue_size) for _ in range(threads)] self.output_queues: List[queue.Queue[Tuple[bytes, int, int]]] = [ queue.Queue(queue_size) for _ in range(threads)] - self.output_worker = threading.Thread(target=self._write) + # Using daemon threads prevents a program freezing on error. + self.output_worker = threading.Thread(target=self._write, daemon=True) self.compression_workers = [ - threading.Thread(target=self._compress, args=(i,)) + threading.Thread(target=self._compress, args=(i,), daemon=True) for i in range(threads) ] elif threads == 1: @@ -250,7 +245,7 @@ def __init__(self, self.output_queues = [] self.compression_workers = [] self.output_worker = threading.Thread( - target=self._compress_and_write) + target=self._compress_and_write, daemon=True) else: raise ValueError(f"threads should be at least 1, got {threads}") self.threads = threads @@ -261,6 +256,7 @@ def __init__(self, self.raw, self.closefd = open_as_binary_stream(filename, mode) self._closed = False self._write_gzip_header() + self.start() def _check_closed(self, msg=None): if self._closed: @@ -283,24 +279,21 @@ def _write_gzip_header(self): self.raw.write(struct.pack( "BBBBIBB", magic1, magic2, method, flags, mtime, os, xfl)) - def _start(self): - if not self.running: - self.running = True - self.output_worker.start() - for worker in self.compression_workers: - worker.start() + def start(self): + self.running = True + self.output_worker.start() + for worker in self.compression_workers: + worker.start() def stop(self): """Stop, but do not care for remaining work""" - if self.running: - self.running = False - for worker in self.compression_workers: - worker.join() - self.output_worker.join() + self.running = False + for worker in self.compression_workers: + worker.join() + self.output_worker.join() def write(self, b) -> int: self._check_closed() - self._start() with self.lock: if self.exception: raise self.exception diff --git a/tests/test_gzip_ng_threaded.py b/tests/test_gzip_ng_threaded.py index 1a0a5a8..7ed06de 100644 --- a/tests/test_gzip_ng_threaded.py +++ b/tests/test_gzip_ng_threaded.py @@ -105,7 +105,6 @@ def test_threaded_write_error(threads): threads=threads, block_size=8 * 1024) # Bypass the write method which should not allow blocks larger than # block_size. - f._start() f.input_queues[0].put((os.urandom(1024 * 64), b"")) with pytest.raises(OverflowError) as error: f.close()