Skip to content

Commit 02f6d4f

Browse files
committed
Fix shared memory queue not being shut down, which caused load_with_workers to hang forever on exit
1 parent 33e42b6 commit 02f6d4f

File tree

4 files changed

+70
-2
lines changed

4 files changed

+70
-2
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ dependencies = [
1111

1212
[dependency-groups]
1313
dev = [
14+
"pytest-timeout>=2.4.0",
1415
"pytest>=8.4.0",
1516
"tqdm>=4.67.1",
1617
]

tests/test_loader.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from multiprocessing.managers import SharedMemoryManager
33

44
import numpy as np
5+
import pytest
56
import tinygrad
67
import tqdm
78

@@ -96,3 +97,22 @@ def test_share_memory_enabled():
9697
assert y.numpy().shape == label_size
9798
count += 1
9899
assert count == n
100+
101+
102+
@pytest.mark.timeout(10)
103+
def test_generator_early_stops_queue_not_shutdown():
104+
data_size = (5,)
105+
label_size = (4,)
106+
num_worker = 4
107+
108+
def forever_gen():
109+
while True:
110+
yield 1
111+
112+
loader = RandomLoader(data_size=data_size, label_size=label_size)
113+
with load_with_workers(
114+
loader, forever_gen(), num_worker, shared_memory_enabled=True
115+
) as generator:
116+
for i, _ in enumerate(tqdm.tqdm(generator)):
117+
if i > 10:
118+
break

tinyloader/loader.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,41 @@
1717
class Loader(abc.ABC):
1818
@abc.abstractmethod
1919
def make_request(self, item: typing.Any) -> typing.Any:
20+
"""Called to make data loading request to potentially passing to the workers. Ideally the return value should
21+
be easily picklable, otherwise it might be very slow.
22+
23+
:param item: The item to generate the loading request for
24+
:return: a picklable value for the worker process or the current process to load
25+
"""
2026
raise NotImplementedError
2127

2228
@abc.abstractmethod
2329
def load(self, request: typing.Any) -> tuple[np.typing.NDArray, ...]:
30+
"""Called to load data for the given item. Potentially called from a worker process.
31+
32+
:param request: Request for loading the data
33+
:return: The loaded data, which should be a tuple of numpy ndarrays
34+
"""
2435
raise NotImplementedError
2536

2637
@abc.abstractmethod
2738
def post_process(
2839
self, response: tuple[np.typing.NDArray, ...]
2940
) -> tuple[tinygrad.Tensor, ...]:
41+
"""Called to convert numpy's ndarray returned from the `load` method into tinygrad's Tensor for training or
42+
testing purposes. This method will be called from the process which invokes the loading generator.
43+
44+
:param response: Response ndarray values returned by the `load` method
45+
:return: A tuple of tinygrad Tensor for training / testing or other purpose
46+
"""
3047
raise NotImplementedError
3148

49+
def shutdown(self):
50+
"""Called to shutdown resources associated with the loader. Like, abort async operations or release files and
51+
so on.
52+
53+
"""
54+
3255

3356
@dataclasses.dataclass(frozen=True)
3457
class SharedBuffer:
@@ -112,6 +135,9 @@ def push(self, shared_buffer: SharedBuffer):
112135
self._queue.put(shared_buffer.index)
113136
logger.debug("Pushed shared buffer %s", shared_buffer)
114137

138+
def shutdown(self):
139+
self._queue.shutdown(immediate=True)
140+
115141

116142
class SharedMemoryShim(Loader):
117143
def __init__(
@@ -181,6 +207,10 @@ def __reduce__(self):
181207
# avoid pickling SharedMemoryManager, only care about the underlying loader in `load` method anyway
182208
return self.__class__, (self.loader, None, 0), None
183209

210+
def shutdown(self):
211+
for mem_pool in self._memory_pools.values():
212+
mem_pool.shutdown()
213+
184214

185215
def load(
186216
loader: Loader, items: typing.Sequence[typing.Any]
@@ -244,4 +274,7 @@ def generate() -> typing.Generator[tuple[tinygrad.Tensor, ...], None, None]:
244274
),
245275
)
246276

247-
yield generate()
277+
try:
278+
yield generate()
279+
finally:
280+
actual_loader.shutdown()

uv.lock

Lines changed: 15 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)