Commit 3f3586a

CodyCBakerPhD and rly authored
Add base methods for iterator serialization (#924)
Co-authored-by: Ryan Ly <[email protected]>
1 parent ca7722f commit 3f3586a

File tree

4 files changed: 96 additions & 22 deletions

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@
 ### New features and minor improvements
 - Increase raw data chunk cache size for reading HDF5 files from 1 MiB to 32 MiB. @bendichter, @rly [#925](https://github.com/hdmf-dev/hdmf/pull/925)
 - Increase default chunk size for `GenericDataChunkIterator` from 1 MB to 10 MB. @bendichter, @rly [#925](https://github.com/hdmf-dev/hdmf/pull/925)
+- Added the magic `__reduce__` method as well as two private semi-abstract helper methods to enable pickling of the `GenericDataChunkIterator`. @codycbakerphd [#924](https://github.com/hdmf-dev/hdmf/pull/924)
 
 ## HDMF 3.8.1 (July 25, 2023)
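For reference, the 32 MiB raw data chunk cache mentioned in the first bullet corresponds to HDF5's rdcc_nbytes setting. A minimal sketch of the equivalent when opening a file directly with h5py; the file name "data.h5" and dataset path "/some/dataset" are hypothetical placeholders:

import h5py

# Open with a 32 MiB raw data chunk cache, matching HDMF's new read default.
with h5py.File("data.h5", "r", rdcc_nbytes=32 * 1024**2) as f:
    dataset = f["/some/dataset"]
    print(dataset.chunks)  # chunk layout that the cache serves reads from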

src/hdmf/backends/io.py

Lines changed: 1 addition & 2 deletions
@@ -74,8 +74,7 @@ def read(self, **kwargs):
 
         return container
 
-    @docval({'name': 'container', 'type': Container, 'doc': 'the Container object to write'},
-            allow_extra=True)
+    @docval({'name': 'container', 'type': Container, 'doc': 'the Container object to write'}, allow_extra=True)
     def write(self, **kwargs):
         """Write a container to the IO source."""
         container = popargs('container', kwargs)

src/hdmf/data_utils.py

Lines changed: 35 additions & 19 deletions
@@ -3,7 +3,7 @@
 from abc import ABCMeta, abstractmethod
 from collections.abc import Iterable
 from warnings import warn
-from typing import Tuple
+from typing import Tuple, Callable
 from itertools import product, chain
 
 import h5py
@@ -190,9 +190,10 @@ def __init__(self, **kwargs):
         HDF5 recommends chunk size in the range of 2 to 16 MB for optimal cloud performance.
         https://youtu.be/rcS5vt-mKok?t=621
         """
-        buffer_gb, buffer_shape, chunk_mb, chunk_shape, self.display_progress, self.progress_bar_options = getargs(
+        buffer_gb, buffer_shape, chunk_mb, chunk_shape, self.display_progress, progress_bar_options = getargs(
             "buffer_gb", "buffer_shape", "chunk_mb", "chunk_shape", "display_progress", "progress_bar_options", kwargs
         )
+        self.progress_bar_options = progress_bar_options or dict()
 
         if buffer_gb is None and buffer_shape is None:
             buffer_gb = 1.0
@@ -264,15 +265,13 @@ def __init__(self, **kwargs):
             )
 
         if self.display_progress:
-            if self.progress_bar_options is None:
-                self.progress_bar_options = dict()
-
             try:
                 from tqdm import tqdm
 
                 if "total" in self.progress_bar_options:
                     warn("Option 'total' in 'progress_bar_options' is not allowed to be over-written! Ignoring.")
                     self.progress_bar_options.pop("total")
+
                 self.progress_bar = tqdm(total=self.num_buffers, **self.progress_bar_options)
             except ImportError:
                 warn(
@@ -345,12 +344,6 @@ def _get_default_buffer_shape(self, **kwargs) -> Tuple[int, ...]:
             ]
         )
 
-    def recommended_chunk_shape(self) -> Tuple[int, ...]:
-        return self.chunk_shape
-
-    def recommended_data_shape(self) -> Tuple[int, ...]:
-        return self.maxshape
-
     def __iter__(self):
         return self
 
@@ -371,6 +364,11 @@ def __next__(self):
             self.progress_bar.write("\n")  # Allows text to be written to new lines after completion
         raise StopIteration
 
+    def __reduce__(self) -> Tuple[Callable, Iterable]:
+        instance_constructor = self._from_dict
+        initialization_args = (self._to_dict(),)
+        return (instance_constructor, initialization_args)
+
     @abstractmethod
     def _get_data(self, selection: Tuple[slice]) -> np.ndarray:
         """
@@ -391,24 +389,42 @@ def _get_data(self, selection: Tuple[slice]) -> np.ndarray:
         """
         raise NotImplementedError("The data fetching method has not been built for this DataChunkIterator!")
 
-    @property
-    def maxshape(self) -> Tuple[int, ...]:
-        return self._maxshape
-
     @abstractmethod
     def _get_maxshape(self) -> Tuple[int, ...]:
         """Retrieve the maximum bounds of the data shape using minimal I/O."""
         raise NotImplementedError("The setter for the maxshape property has not been built for this DataChunkIterator!")
 
-    @property
-    def dtype(self) -> np.dtype:
-        return self._dtype
-
     @abstractmethod
     def _get_dtype(self) -> np.dtype:
         """Retrieve the dtype of the data using minimal I/O."""
         raise NotImplementedError("The setter for the internal dtype has not been built for this DataChunkIterator!")
 
+    def _to_dict(self) -> dict:
+        """Optional method to add in child classes to enable pickling (required for multiprocessing)."""
+        raise NotImplementedError(
+            "The `._to_dict()` method for pickling has not been defined for this DataChunkIterator!"
+        )
+
+    @staticmethod
+    def _from_dict(self) -> Callable:
+        """Optional method to add in child classes to enable pickling (required for multiprocessing)."""
+        raise NotImplementedError(
+            "The `._from_dict()` method for pickling has not been defined for this DataChunkIterator!"
+        )
+
+    def recommended_chunk_shape(self) -> Tuple[int, ...]:
+        return self.chunk_shape
+
+    def recommended_data_shape(self) -> Tuple[int, ...]:
+        return self.maxshape
+
+    @property
+    def maxshape(self) -> Tuple[int, ...]:
+        return self._maxshape
+    @property
+    def dtype(self) -> np.dtype:
+        return self._dtype
+
 
 class DataChunkIterator(AbstractDataChunkIterator):
     """

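For context on the `__reduce__` addition above: during pickle.dumps, Python calls the object's __reduce__ and stores the returned (callable, args) pair; pickle.loads then calls callable(*args) to rebuild the instance, which here becomes _from_dict(_to_dict()). A minimal standalone sketch of the same contract, using a hypothetical Box class purely for illustration:

import pickle


class Box:
    """Toy stand-in illustrating the (callable, args) contract used above."""

    def __init__(self, payload: dict):
        self.payload = payload

    def _to_dict(self) -> dict:
        # Collect everything needed to rebuild the instance.
        return dict(payload=self.payload)

    @staticmethod
    def _from_dict(dictionary: dict) -> "Box":
        # Reconstruct a fresh instance from the serialized state.
        return Box(payload=dictionary["payload"])

    def __reduce__(self):
        # pickle.dumps stores (Box._from_dict, (state,));
        # pickle.loads then calls Box._from_dict(state).
        return (self._from_dict, (self._to_dict(),))


roundtripped = pickle.loads(pickle.dumps(Box(payload={"a": 1})))
assert roundtripped.payload == {"a": 1}

Routing reconstruction through a plain dict of state is presumably why the base class leaves _to_dict/_from_dict to subclasses: each subclass decides which attributes can be serialized and how to reopen anything that cannot.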
tests/unit/utils_test/test_core_GenericDataChunkIterator.py

Lines changed: 59 additions & 1 deletion
@@ -1,12 +1,14 @@
 import unittest
+import pickle
 import numpy as np
 from pathlib import Path
 from tempfile import mkdtemp
 from shutil import rmtree
-from typing import Tuple, Iterable
+from typing import Tuple, Iterable, Callable
 from sys import version_info
 
 import h5py
+from numpy.testing import assert_array_equal
 
 from hdmf.data_utils import GenericDataChunkIterator
 from hdmf.testing import TestCase
@@ -18,6 +20,30 @@
     TQDM_INSTALLED = False
 
 
+class TestPickleableNumpyArrayDataChunkIterator(GenericDataChunkIterator):
+    def __init__(self, array: np.ndarray, **kwargs):
+        self.array = array
+        self._kwargs = kwargs
+        super().__init__(**kwargs)
+
+    def _get_data(self, selection) -> np.ndarray:
+        return self.array[selection]
+
+    def _get_maxshape(self) -> Tuple[int, ...]:
+        return self.array.shape
+
+    def _get_dtype(self) -> np.dtype:
+        return self.array.dtype
+
+    def _to_dict(self) -> dict:
+        return dict(array=pickle.dumps(self.array), kwargs=self._kwargs)
+
+    @staticmethod
+    def _from_dict(dictionary: dict) -> Callable:
+        array = pickle.loads(dictionary["array"])
+        return TestPickleableNumpyArrayDataChunkIterator(array=array, **dictionary["kwargs"])
+
+
 class GenericDataChunkIteratorTests(TestCase):
     class TestNumpyArrayDataChunkIterator(GenericDataChunkIterator):
         def __init__(self, array: np.ndarray, **kwargs):
@@ -204,6 +230,29 @@ def test_progress_bar_assertion(self):
             progress_bar_options=dict(total=5),
         )
 
+    def test_private_to_dict_assertion(self):
+        with self.assertRaisesWith(
+            exc_type=NotImplementedError,
+            exc_msg="The `._to_dict()` method for pickling has not been defined for this DataChunkIterator!"
+        ):
+            iterator = self.TestNumpyArrayDataChunkIterator(array=self.test_array)
+            _ = iterator._to_dict()
+
+    def test_private_from_dict_assertion(self):
+        with self.assertRaisesWith(
+            exc_type=NotImplementedError,
+            exc_msg="The `._from_dict()` method for pickling has not been defined for this DataChunkIterator!"
+        ):
+            _ = self.TestNumpyArrayDataChunkIterator._from_dict(dict())
+
+    def test_direct_pickle_assertion(self):
+        with self.assertRaisesWith(
+            exc_type=NotImplementedError,
+            exc_msg="The `._to_dict()` method for pickling has not been defined for this DataChunkIterator!"
+        ):
+            iterator = self.TestNumpyArrayDataChunkIterator(array=self.test_array)
+            _ = pickle.dumps(iterator)
+
     def test_maxshape_attribute_contains_int_type(self):
         """Motivated by issues described in https://github.com/hdmf-dev/hdmf/pull/780 & 781 regarding return types."""
         self.check_all_of_iterable_is_python_int(
@@ -377,3 +426,12 @@ def test_tqdm_not_installed(self):
             display_progress=True,
         )
         self.assertFalse(dci.display_progress)
+
+    def test_pickle(self):
+        pre_dump_iterator = TestPickleableNumpyArrayDataChunkIterator(array=self.test_array)
+        post_dump_iterator = pickle.loads(pickle.dumps(pre_dump_iterator))
+
+        assert isinstance(post_dump_iterator, TestPickleableNumpyArrayDataChunkIterator)
+        assert post_dump_iterator.chunk_shape == pre_dump_iterator.chunk_shape
+        assert post_dump_iterator.buffer_shape == pre_dump_iterator.buffer_shape
+        assert_array_equal(post_dump_iterator.array, pre_dump_iterator.array)
