|
class Loader(abc.ABC):
    """Abstract interface for loading data, potentially from worker processes.

    Subclasses implement request construction (:meth:`make_request`), the actual
    load (:meth:`load`) and conversion to tensors (:meth:`post_process`); they may
    optionally override :meth:`shutdown` to release resources.
    """

    @abc.abstractmethod
    def make_request(self, item: typing.Any) -> typing.Any:
        """Build a data-loading request for *item*, potentially passed to a worker process.
        Ideally the return value should be cheaply picklable, otherwise dispatch may be
        very slow.

        :param item: The item to generate the loading request for
        :return: A picklable value for the worker process (or the current process) to load
        """
        raise NotImplementedError

    @abc.abstractmethod
    def load(self, request: typing.Any) -> tuple[np.typing.NDArray, ...]:
        """Load the data described by *request*. Potentially called from a worker process.

        :param request: Request for loading the data, as produced by :meth:`make_request`
        :return: The loaded data, as a tuple of numpy ndarrays
        """
        raise NotImplementedError

    @abc.abstractmethod
    def post_process(
        self, response: tuple[np.typing.NDArray, ...]
    ) -> tuple[tinygrad.Tensor, ...]:
        """Convert the numpy ndarrays returned from :meth:`load` into tinygrad Tensors
        for training or testing. Called from the process which invokes the loading
        generator.

        :param response: Response ndarray values returned by the :meth:`load` method
        :return: A tuple of tinygrad Tensors for training / testing or other purposes
        """
        raise NotImplementedError

    def shutdown(self) -> None:
        """Release resources associated with the loader, e.g. abort asynchronous
        operations or close open files. The default implementation does nothing.
        """
32 | 55 |
|
33 | 56 | @dataclasses.dataclass(frozen=True) |
34 | 57 | class SharedBuffer: |
@@ -112,6 +135,9 @@ def push(self, shared_buffer: SharedBuffer): |
112 | 135 | self._queue.put(shared_buffer.index) |
113 | 136 | logger.debug("Pushed shared buffer %s", shared_buffer) |
114 | 137 |
|
    def shutdown(self):
        """Immediately shut down the underlying index queue.

        NOTE(review): ``Queue.shutdown(immediate=True)`` exists on ``queue.Queue``
        only from Python 3.13 — confirm ``self._queue`` is such a queue (or an
        API-compatible one) and that 3.13+ is the minimum supported version.
        """
        self._queue.shutdown(immediate=True)
115 | 141 |
|
116 | 142 | class SharedMemoryShim(Loader): |
117 | 143 | def __init__( |
@@ -181,6 +207,10 @@ def __reduce__(self): |
181 | 207 | # avoid pickling SharedMemoryManager, only care about the underlying loader in `load` method anyway |
182 | 208 | return self.__class__, (self.loader, None, 0), None |
183 | 209 |
|
| 210 | + def shutdown(self): |
| 211 | + for mem_pool in self._memory_pools.values(): |
| 212 | + mem_pool.shutdown() |
| 213 | + |
184 | 214 |
|
185 | 215 | def load( |
186 | 216 | loader: Loader, items: typing.Sequence[typing.Any] |
@@ -244,4 +274,7 @@ def generate() -> typing.Generator[tuple[tinygrad.Tensor, ...], None, None]: |
244 | 274 | ), |
245 | 275 | ) |
246 | 276 |
|
247 | | - yield generate() |
| 277 | + try: |
| 278 | + yield generate() |
| 279 | + finally: |
| 280 | + actual_loader.shutdown() |
0 commit comments