Skip to content

Commit 9cec5f6

Browse files
Thomas Zilio (Thomas-Z)
authored and committed
fix: Include dtype information in the Partitioning configuration.
1 parent a84fae8 commit 9cec5f6

File tree

4 files changed

+57
-18
lines changed

4 files changed

+57
-18
lines changed

zcollection/partitioning/abc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,6 @@ def unique_and_check_monotony(arr: ArrayLike) -> tuple[NDArray, NDArray]:
106106
107107
Args:
108108
arr: Array of elements.
109-
is_delayed: If True, the array is delayed.
110109
Returns:
111110
Tuple of unique elements and their indices.
112111
"""
@@ -331,12 +330,13 @@ def get_config(self) -> dict[str, Any]:
331330
Returns:
332331
The configuration of the partitioning scheme.
333332
"""
334-
config: dict[str, str | None] = {'id': self.ID}
333+
config: dict[str, str | tuple[str, ...] | None] = {'id': self.ID}
335334
slots: Generator[tuple[str, ...]] = (getattr(
336335
_class, '__slots__',
337336
()) for _class in reversed(self.__class__.__mro__))
338337
config.update((attr, getattr(self, attr)) for _class in slots
339338
for attr in _class if not attr.startswith('_'))
339+
config['dtype'] = self._dtype
340340
return config
341341

342342
@classmethod

zcollection/partitioning/date.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,3 +217,10 @@ def decode(
217217
py_datetime: datetime.datetime = datetime64.astype('M8[s]').item()
218218
return tuple((UNITS[ix], getattr(py_datetime, self._attrs[ix]))
219219
for ix in self._index)
220+
221+
def get_config(self) -> dict[str, Any]:
222+
config = super().get_config()
223+
224+
# dtypes are computed automatically by this partitioning, so drop them from the config
225+
config.pop('dtype')
226+
return config

zcollection/partitioning/tests/test_date.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -166,23 +166,43 @@ def test_construction() -> None:
166166
Date(('dates', ), 'W')
167167

168168

169-
def test_config():
169+
RESOLUTION_DTYPE_TEST_SET = [
170+
('Y', (('year', 'uint16'), )),
171+
('M', (('year', 'uint16'), ('month', 'uint8'))),
172+
('D', (('year', 'uint16'), ('month', 'uint8'), ('day', 'uint8'))),
173+
('h', (('year', 'uint16'), ('month', 'uint8'), ('day', 'uint8'),
174+
('hour', 'uint8'))),
175+
('m', (('year', 'uint16'), ('month', 'uint8'), ('day', 'uint8'),
176+
('hour', 'uint8'), ('minute', 'uint8'))),
177+
('s', (('year', 'uint16'), ('month', 'uint8'), ('day', 'uint8'),
178+
('hour', 'uint8'), ('minute', 'uint8'), ('second', 'uint8')))
179+
]
180+
181+
182+
@pytest.mark.parametrize('resolution, dtype', RESOLUTION_DTYPE_TEST_SET)
183+
def test_config(resolution, dtype):
170184
"""Test the configuration of the Date class."""
171-
partitioning = Date(('dates', ), 'D')
172-
assert partitioning.dtype() == (('year', 'uint16'), ('month', 'uint8'),
173-
('day', 'uint8'))
185+
partitioning = Date(variables=('dates', ), resolution=resolution)
186+
assert partitioning.dtype() == dtype
187+
174188
config = partitioning.get_config()
175-
partitioning = get_codecs(config)
176-
assert isinstance(partitioning, Date)
189+
other = get_codecs(config)
190+
191+
assert isinstance(other, Date)
192+
assert other.variables == ('dates', )
193+
assert other.dtype() == dtype
177194

178195

179-
def test_pickle():
196+
@pytest.mark.parametrize('resolution, dtype', RESOLUTION_DTYPE_TEST_SET)
197+
def test_pickle(resolution, dtype):
180198
"""Test the pickling of the Date class."""
181-
partitioning = Date(('dates', ), 'D')
199+
partitioning = Date(('dates', ), resolution=resolution)
182200
other = pickle.loads(pickle.dumps(partitioning))
201+
183202
assert isinstance(other, Date)
184-
assert other.resolution == 'D'
203+
assert other.resolution == resolution
185204
assert other.variables == ('dates', )
205+
assert other.dtype() == dtype
186206

187207

188208
@pytest.mark.parametrize('delayed', [False, True])

zcollection/partitioning/tests/test_sequence.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -113,20 +113,32 @@ def test_split_dataset(
113113
list(partitioning.split_dataset(zds, 'num_lines'))
114114

115115

116-
def test_config() -> None:
116+
VARIABLES_DTYPE_TEST_SET = [(('a', ), None), (('a', ), ('uint8', )),
117+
(('a', 'b'), None),
118+
(('a', 'b'), ('int8', 'int16'))]
119+
120+
121+
@pytest.mark.parametrize('variables, dtype', VARIABLES_DTYPE_TEST_SET)
122+
def test_config(variables, dtype) -> None:
117123
"""Test the configuration of the Sequence class."""
118-
partitioning = Sequence(('cycle_number', 'pass_number'))
124+
partitioning = Sequence(variables=variables, dtype=dtype)
125+
119126
config = partitioning.get_config()
120-
partitioning = get_codecs(config) # type: ignore[assignment]
121-
assert isinstance(partitioning, Sequence)
127+
other = get_codecs(config) # type: ignore[assignment]
122128

129+
assert isinstance(other, Sequence)
130+
assert other.dtype() == partitioning.dtype()
123131

124-
def test_pickle() -> None:
132+
133+
@pytest.mark.parametrize('variables, dtype', VARIABLES_DTYPE_TEST_SET)
134+
def test_pickle(variables, dtype) -> None:
125135
"""Test the pickling of the Sequence class."""
126-
partitioning = Sequence(('cycle_number', 'pass_number'))
136+
partitioning = Sequence(variables=variables, dtype=dtype)
137+
127138
other = pickle.loads(pickle.dumps(partitioning))
139+
128140
assert isinstance(other, Sequence)
129-
assert other.variables == ('cycle_number', 'pass_number')
141+
assert other.dtype() == partitioning.dtype()
130142

131143

132144
# pylint: disable=protected-access

0 commit comments

Comments
 (0)