
Commit 13e9adc

Merge pull request #61 from p2p-ld/zero-length-dimensinos
Support leading zero length dimensions
2 parents 1dab66f + e39e3d2 commit 13e9adc

File tree

12 files changed (+1207, -951 lines)

docs/changelog.md

Lines changed: 22 additions & 0 deletions
@@ -1,5 +1,27 @@
 # Changelog
 
+## Upcoming
+
+**Bugfix**
+
+- [#60](https://github.com/p2p-ld/numpydantic/issues/60), [#61](https://github.com/p2p-ld/numpydantic/pulls/61) -
+  Support the edge case of leading zero-length dimensions while round-tripping to JSON
+
+**Added**
+
+- [#61](https://github.com/p2p-ld/numpydantic/pulls/61) -
+  Serialize `shape` in `round_trip` mode, and attempt to reshape when validating from a round trip.
+  This supports the edge case of a leading zero-length dimension.
+- Since the `dtype` of an array of objects can't be detected when there are no objects,
+  allow `Any` to be used in validation to indicate that the dtype could be anything / is unknowable.
+
+**Testing**
+
+- Use static lists rather than generators in test case `merged_product` collections,
+  avoiding the footgun of exhausting a generator once and then having it appear as an
+  empty collection on the next iteration, which would cause tests to invisibly not run.
 
 ## 1.*
 
 ### 1.6.*

docs/serialization.md

Lines changed: 17 additions & 0 deletions
@@ -135,6 +135,23 @@ DaskJsonDict.is_valid(roundtrip_json)
 NumpyJsonDict.is_valid(roundtrip_json)
 ```
 
+````{note} Zero-length dimensions
+Zero-length dimensions are a special case in many array frameworks.
+When serializing to JSON, in most cases, zero-length dimensions are lost.
+
+Numpy has some odd edge cases with zero-length arrays,
+e.g. an empty array can have a non-zero shape:
+
+```python
+>>> np.random.rand(0, 3)
+array([], shape=(0, 3), dtype=float64)
+```
+
+When round-tripping, if the array framework supports reshape operations,
+the `shape` of the array will be preserved in the JSON form,
+and during validation the interface will attempt to reshape the input array to the stored shape.
+````
+
 #### Controlling paths
 
 When possible, the full content of the array is omitted in favor
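To make the note in this hunk concrete, here is a minimal sketch of round-tripping a leading zero-length dimension through a pydantic model. It assumes numpydantic's `NDArray`/`Shape` annotation syntax and pydantic's `round_trip` serialization flag; the `ImageBatch` model name is hypothetical, not part of the library.

```python
import numpy as np
from pydantic import BaseModel

from numpydantic import NDArray, Shape


class ImageBatch(BaseModel):
    # the leading dimension may be zero, e.g. an empty batch of RGB rows
    array: NDArray[Shape["* x, 3 y"], float]


model = ImageBatch(array=np.random.rand(0, 3))

# in round_trip mode the interface stores dtype and shape alongside the (empty) value
json_str = model.model_dump_json(round_trip=True)

# on validation, the stored shape is used to reshape the empty value back to (0, 3)
restored = ImageBatch.model_validate_json(json_str)
assert restored.array.shape == (0, 3)
```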

pdm.lock

Lines changed: 1033 additions & 916 deletions
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 2 additions & 1 deletion
@@ -147,6 +147,7 @@ markers = [
     "union: union dtypes",
     "pipe_union: union dtypes specified with a pipe",
     "scalar: scalar values",
+    "model: dtype annotations are a pydantic model"
 ]
 
 [tool.black]
@@ -190,7 +191,7 @@ select = [
     "T100",
 ]
 ignore = [
-    "ANN101", "ANN102", "ANN401", "ANN204",
+    "ANN401", "ANN204",
     # explicit strict arg for zip
     "B905",
     # builtin type annotations

src/numpydantic/interface/dask.py

Lines changed: 27 additions & 12 deletions
@@ -33,11 +33,14 @@ class DaskJsonDict(JsonDict):
     name: str
     chunks: Iterable[tuple[int, ...]]
     dtype: str
+    shape: Union[tuple[int, ...], None] = None
     value: list
 
     def to_array_input(self) -> DaskArray:
         """Construct a dask array"""
         np_array = np.array(self.value, dtype=self.dtype)
+        if self.shape is not None and np_array.shape != self.shape:
+            np_array = self.reshape_input(np_array, self.shape)
         array = from_array(
             np_array,
             name=self.name,
@@ -75,28 +78,37 @@ def before_validation(self, array: DaskArray) -> NDArrayType:
         Try and coerce dicts that should be model objects into the model objects
         """
         try:
-            if issubclass(self.dtype, BaseModel) and isinstance(
-                array.reshape(-1)[0].compute(), dict
-            ):
+            if issubclass(self.dtype, BaseModel):
+                flat_array = array.reshape(-1)
+                if len(flat_array) == 0:
+                    return array
 
-                def _chunked_to_model(array: np.ndarray) -> np.ndarray:
-                    def _vectorized_to_model(item: Union[dict, BaseModel]) -> BaseModel:
-                        if not isinstance(item, self.dtype):
-                            return self.dtype(**item)
-                        else:  # pragma: no cover
-                            return item
+                if isinstance(flat_array[0].compute(), dict):
 
-                    return np.vectorize(_vectorized_to_model)(array)
+                    def _chunked_to_model(array: np.ndarray) -> np.ndarray:
+                        def _vectorized_to_model(
+                            item: Union[dict, BaseModel],
+                        ) -> BaseModel:
+                            if not isinstance(item, self.dtype):
+                                return self.dtype(**item)
+                            else:  # pragma: no cover
+                                return item
 
-                array = array.map_blocks(_chunked_to_model, dtype=self.dtype)
+                        return np.vectorize(_vectorized_to_model)(array)
+
+                    array = array.map_blocks(_chunked_to_model, dtype=self.dtype)
         except TypeError:
             # fine, dtype isn't a type
             pass
         return array
 
     def get_object_dtype(self, array: NDArrayType) -> DtypeType:
         """Dask arrays require a compute() call to retrieve a single value"""
-        return type(array.reshape(-1)[0].compute())
+        flat_array = array.reshape(-1)
+        if len(flat_array) == 0:
+            return Any
+        else:
+            return type(flat_array[0].compute())
 
     @classmethod
     def enabled(cls) -> bool:
@@ -121,12 +133,15 @@ def to_json(
         """
         np_array = np.array(array)
         as_json = np_array.tolist()
+        if not isinstance(as_json, list):
+            as_json = [as_json]
         if info.round_trip:
             as_json = DaskJsonDict(
                 type=cls.name,
                 value=as_json,
                 name=array.name,
                 chunks=array.chunks,
                 dtype=str(np_array.dtype),
+                shape=array.shape,
             )
         return as_json
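The guards added in this file hinge on the flattened array having at least one element to inspect. A rough standalone illustration of the empty case, assuming only public dask and numpy behavior (this is not the interface code itself, and the example array is arbitrary):

```python
import dask.array as da
import numpy as np

# a dask array with a leading zero-length dimension has nothing to inspect
arr = da.from_array(np.random.rand(0, 3))
flat = arr.reshape(-1)

# the flattened array has length 0, so there is no first element whose
# type (or dict-ness) could be checked via flat[0].compute()
assert len(flat) == 0

# this is why before_validation returns the array unchanged and
# get_object_dtype reports `Any` when the flattened array is empty
```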

src/numpydantic/interface/interface.py

Lines changed: 28 additions & 2 deletions
@@ -137,6 +137,25 @@ def handle_input(cls: Type[U], value: Union[dict, U, W]) -> Union[V, W]:
             value = value.to_array_input()
         return value
 
+    @staticmethod
+    def reshape_input(value: T, shape: tuple[int, ...]) -> T:
+        """
+        If a `reshape` value is present on the array, and the array shape doesn't match,
+        attempt to reshape it.
+        """
+        if value.shape != shape:
+            try:
+                value = value.reshape(shape)
+            except ValueError:
+                warnings.warn(
+                    f"Input data has shape {value.shape}, "
+                    f"but the roundtrip form specifies {shape}, "
+                    f"and {value.shape} can't be cast to {shape}. "
+                    f"Attempting to proceed with validation without reshaping.",
+                    stacklevel=1,
+                )
+        return value
+
 
 class MarkedJson(BaseModel):
     """
@@ -270,7 +289,7 @@ def before_validation(self, array: Any) -> NDArrayType:
 
     def get_dtype(self, array: NDArrayType) -> DtypeType:
         """
-        Get the dtype from the input array
+        Get the dtype from the input array.
         """
         if hasattr(array.dtype, "type") and array.dtype.type is np.object_:
             return self.get_object_dtype(array)
@@ -281,8 +300,15 @@ def get_object_dtype(self, array: NDArrayType) -> DtypeType:
         """
         When an array contains an object, get the dtype of the object contained
         by the array.
+
+        If this method returns `Any`, the dtype validation passes -
+        used for e.g. empty arrays for which the dtype of the array can't be determined
+        (since there are no objects).
         """
-        return type(array.ravel()[0])
+        try:
+            return type(array.ravel()[0])
+        except IndexError:
+            return Any
 
     def validate_dtype(self, dtype: DtypeType) -> bool:
         """

src/numpydantic/interface/numpy.py

Lines changed: 15 additions & 3 deletions
@@ -28,12 +28,17 @@ class NumpyJsonDict(JsonDict):
     type: Literal["numpy"]
     dtype: str
     value: list
+    # allow shape to be None for backwards compatibility.
+    shape: Union[tuple[int, ...], None] = None
 
     def to_array_input(self) -> ndarray:
         """
         Construct a numpy array
         """
-        return np.array(self.value, dtype=self.dtype)
+        array = np.array(self.value, dtype=self.dtype)
+        if self.shape is not None and array.shape != self.shape:
+            array = self.reshape_input(array, self.shape)
+        return array
 
 
 class NumpyInterface(Interface):
@@ -82,7 +87,11 @@ def before_validation(self, array: Any) -> ndarray:
             array = np.array(array)
 
         try:
-            if issubclass(self.dtype, BaseModel) and isinstance(array.flat[0], dict):
+            if (
+                issubclass(self.dtype, BaseModel)
+                and len(array) > 0
+                and isinstance(array.flat[0], dict)
+            ):
                 array = np.vectorize(lambda x: self.dtype(**x))(array)
         except TypeError:
             # fine, dtype isn't a type
@@ -110,6 +119,9 @@ def to_json(
 
         if info.round_trip:
             json_array = NumpyJsonDict(
-                type=cls.name, dtype=str(array.dtype), value=json_array
+                type=cls.name,
+                dtype=str(array.dtype),
+                value=json_array,
+                shape=array.shape,
             )
         return json_array
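A quick plain-numpy sketch of why the serialized `shape` field is needed at all (separate from the interface code above):

```python
import numpy as np

# serializing an array with a leading zero-length dimension loses that shape:
# there are no rows, so the JSON value is just an empty list
original = np.random.rand(0, 3)
as_json_value = original.tolist()
assert as_json_value == []

# rebuilding from the value alone collapses the shape to (0,)
rebuilt = np.array(as_json_value, dtype="float64")
assert rebuilt.shape == (0,)

# the stored shape lets the interface reshape the empty array back to (0, 3)
assert rebuilt.reshape((0, 3)).shape == (0, 3)
```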

src/numpydantic/serialization.py

Lines changed: 0 additions & 1 deletion
@@ -137,7 +137,6 @@ def relative_path(self: Path, other: Path, walk_up: bool = True) -> Path:
     References:
         https://github.com/python/cpython/blob/8a2baedc4bcb606da937e4e066b4b3a18961cace/Lib/pathlib/_abc.py#L244-L270
     """
-    # pdb.set_trace()
     if not isinstance(other, Path):  # pragma: no cover - ripped from cpython
         other = Path(other)
     self_parts = self.parts

src/numpydantic/testing/cases.py

Lines changed: 25 additions & 4 deletions
@@ -96,6 +96,9 @@ class SubClass(BasicModel):
         marks={"scalar"},
         passes=False,
     ),
+    ValidationCase(
+        annotation_shape=("*", 3), shape=(0, 3), id="zero-length", passes=True
+    ),
 )
 """
 Base Shape cases
@@ -162,16 +165,32 @@ class SubClass(BasicModel):
         marks={"np_str", "str", "tuple"},
     ),
     ValidationCase(
-        annotation_dtype=BasicModel, dtype=BasicModel, passes=True, id="model-model"
+        annotation_dtype=BasicModel,
+        dtype=BasicModel,
+        passes=True,
+        id="model-model",
+        marks={"model"},
     ),
     ValidationCase(
-        annotation_dtype=BasicModel, dtype=BadModel, passes=False, id="model-badmodel"
+        annotation_dtype=BasicModel,
+        dtype=BadModel,
+        passes=False,
+        id="model-badmodel",
+        marks={"model"},
     ),
     ValidationCase(
-        annotation_dtype=BasicModel, dtype=int, passes=False, id="model-int"
+        annotation_dtype=BasicModel,
+        dtype=int,
+        passes=False,
+        id="model-int",
+        marks={"model"},
    ),
     ValidationCase(
-        annotation_dtype=BasicModel, dtype=SubClass, passes=True, id="model-subclass"
+        annotation_dtype=BasicModel,
+        dtype=SubClass,
+        passes=True,
+        id="model-subclass",
+        marks={"model"},
     ),
     ValidationCase(
         annotation_dtype=UNION_TYPE,
@@ -305,3 +324,5 @@ class SubClass(BasicModel):
 """
 Merged product of all cases, but only those that pass
 """
+
+ZERO_LENGTH_CASES_PASSING = [c for c in ALL_CASES_PASSING if "zero-length" in c.id]

src/numpydantic/testing/helpers.py

Lines changed: 4 additions & 3 deletions
@@ -6,7 +6,6 @@
 from pathlib import Path
 from typing import (
     TYPE_CHECKING,
-    Generator,
     List,
     Literal,
     Optional,
@@ -303,7 +302,7 @@ def merge_cases(*args: ValidationCase) -> ValidationCase:
 
 def merged_product(
     *args: Sequence[ValidationCase], conditions: dict = None
-) -> Generator[ValidationCase, None, None]:
+) -> list[ValidationCase]:
     """
     Generator for the product of the iterators of validation cases,
     merging each tuple, and respecting if they should be :meth:`.ValidationCase.skip`
@@ -341,6 +340,7 @@ def merged_product(
 
     """
     iterator = product(*args)
+    cases = []
     for case_tuple in iterator:
         case = merge_cases(*case_tuple)
         if case.skip():
@@ -349,4 +349,5 @@ def merged_product(
         matching = all([getattr(case, k, None) == v for k, v in conditions.items()])
         if not matching:
             continue
-        yield case
+        cases.append(case)
+    return cases
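For context on why the generator was replaced with a list, here is a standalone illustration of the exhaustion footgun described in the changelog (not library code):

```python
# a generator is exhausted after one pass; a second pass silently yields nothing,
# which is how a parametrized test collection could look empty and tests not run
cases = (n * n for n in range(3))
assert list(cases) == [0, 1, 4]
assert list(cases) == []  # exhausted - appears to be an empty collection

# a static list can be iterated any number of times
cases = [n * n for n in range(3)]
assert list(cases) == [0, 1, 4]
assert list(cases) == [0, 1, 4]
```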
