Skip to content

Commit b43c488

Browse files
committed
Better shutdown handling
1 parent 6cfb481 commit b43c488

File tree

1 file changed

+17
-1
lines changed

1 file changed

+17
-1
lines changed

tinyloader/loader.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import logging
55
import multiprocessing
66
import queue
7+
import signal
78
import typing
89
from multiprocessing.managers import SharedMemoryManager
910
from multiprocessing.shared_memory import SharedMemory
@@ -226,6 +227,11 @@ def load(
226227
)
227228

228229

230+
# ref: https://stackoverflow.com/a/6191991
231+
def init_worker():
232+
signal.signal(signal.SIGINT, signal.SIG_IGN)
233+
234+
229235
@contextlib.contextmanager
230236
def load_with_workers(
231237
loader: Loader,
@@ -250,7 +256,10 @@ def load_with_workers(
250256
shared_memory_ctx = contextlib.nullcontext()
251257
if shared_memory_enabled:
252258
shared_memory_ctx = SharedMemoryManager()
253-
with shared_memory_ctx as smm, multiprocessing.Pool(num_worker) as pool:
259+
with (
260+
shared_memory_ctx as smm,
261+
multiprocessing.Pool(num_worker, init_worker) as pool,
262+
):
254263
items_iter = iter(items)
255264

256265
actual_loader = loader
@@ -276,5 +285,12 @@ def generate() -> typing.Generator[tuple[tinygrad.Tensor, ...], None, None]:
276285

277286
try:
278287
yield generate()
288+
except KeyboardInterrupt:
289+
actual_loader.shutdown()
290+
pool.terminate()
291+
pool.join()
292+
raise
279293
finally:
280294
actual_loader.shutdown()
295+
pool.close()
296+
pool.join()

0 commit comments

Comments
 (0)