Skip to content

Commit 7feefc1

Browse files
Conchylicultor authored and The dataclass_array Authors committed
Support typing annotation shape for dataclass array (e.g. ray: Ray['*batch_shape _ _'])
PiperOrigin-RevId: 466946356
1 parent 43695c3 commit 7feefc1

File tree

10 files changed

+194
-40
lines changed

10 files changed

+194
-40
lines changed

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,11 +23,13 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
2323

2424
## [Unreleased]
2525

26+
* Added: Syntax to specify the shape of the DataclassArray (e.g. `MyRay['h
27+
w']`).
28+
2629
## [1.0.0] - 2022-08-08
2730

2831
* Initial release
2932

30-
3133
[Unreleased]: https://github.com/google-research/dataclass_array/compare/v1.0.0...HEAD
3234
[1.0.0]: https://github.com/google-research/dataclass_array/compare/v0.1.0...v1.0.0
3335
[0.1.0]: https://github.com/google-research/dataclass_array/releases/tag/v0.1.0

README.md

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -77,15 +77,18 @@ A `DataclassArray` has 2 types of fields:
7777
class MyArray(dca.DataclassArray):
7878
# Array fields
7979
a: FloatArray['*batch_shape 3'] # Defined by `etils.array_types`
80-
b: Ray # Nested DataclassArray (inner shape == `()`)
80+
b: FloatArray['*batch_shape _ _'] # Dynamic shape
81+
c: Ray # Nested DataclassArray (equivalent to `Ray['*batch_shape']`)
82+
d: Ray['*batch_shape 6']
8183

8284
# Array fields explicitly defined
83-
c: Any = dca.field(shape=(3,), dtype=np.float32)
84-
d: Ray = dca.field(shape=(3,), dtype=Ray) # Nested DataclassArray
85+
e: Any = dca.field(shape=(3,), dtype=np.float32)
86+
f: Any = dca.field(shape=(None, None), dtype=np.float32) # Dynamic shape
87+
g: Ray = dca.field(shape=(3,), dtype=Ray) # Nested DataclassArray
8588

8689
# Static field (everything not defined as above)
87-
e: float
88-
f: np.array
90+
static0: float
91+
static1: np.array
8992
```
9093

9194
### Vectorization

dataclass_array/array_dataclass.py

Lines changed: 29 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -20,9 +20,10 @@
2020
import typing
2121
from typing import Any, Callable, ClassVar, Generic, Iterator, Optional, Tuple, Type, TypeVar, Union
2222

23+
from dataclass_array import field_utils
2324
from dataclass_array import shape_parsing
2425
from dataclass_array import type_parsing
25-
from dataclass_array.typing import Axes, DcOrArray, DcOrArrayT, DTypeArg, Shape # pylint: disable=g-multiple-import
26+
from dataclass_array.typing import Axes, DcOrArray, DcOrArrayT, DTypeArg, DynamicShape, Shape # pylint: disable=g-multiple-import
2627
from dataclass_array.utils import np_utils
2728
from dataclass_array.utils import py_utils
2829
import einops
@@ -32,7 +33,8 @@
3233
from etils import epy
3334
from etils.array_types import Array
3435
import numpy as np
35-
from typing_extensions import Literal, TypeAlias # pylint: disable=g-multiple-import
36+
import typing_extensions
37+
from typing_extensions import Annotated, Literal, TypeAlias # pylint: disable=g-multiple-import
3638

3739

3840
lazy = enp.lazy
@@ -65,7 +67,16 @@ class DataclassParams:
6567
cast_list: bool = True
6668

6769

68-
class DataclassArray:
70+
class MetaDataclassArray(type):
71+
"""DataclassArray metaclass."""
72+
73+
# TODO(b/204422756): We cannot use `__class_getitem__` due to b/204422756
74+
def __getitem__(cls, spec):
75+
# Not clear how this would interact if cls is also a `Generic`
76+
return Annotated[cls, field_utils.ShapeAnnotation(spec)]
77+
78+
79+
class DataclassArray(metaclass=MetaDataclassArray):
6980
"""Dataclass which behaves like an array.
7081
7182
Usage:
@@ -385,7 +396,7 @@ def _all_array_fields(self) -> dict[str, _ArrayField]:
385396
if cls._dca_fields_metadata is None: # pylint: disable=protected-access
386397
# At this point, `ForwardRef` should have been resolved.
387398
try:
388-
hints = typing.get_type_hints(cls)
399+
hints = typing_extensions.get_type_hints(cls, include_extras=True)
389400
except Exception as e: # pylint: disable=broad-except
390401
epy.reraise(
391402
e,
@@ -706,7 +717,7 @@ class _ArrayFieldMetadata:
706717
dtype: Type of the array. Can be `array_types.dtypes.DType` or
707718
`dca.DataclassArray` for nested arrays.
708719
"""
709-
inner_shape_non_static: Shape
720+
inner_shape_non_static: DynamicShape
710721
dtype: Union[array_types.dtypes.DType, type[DataclassArray]]
711722

712723
def __post_init__(self):
@@ -739,7 +750,7 @@ class _ArrayField(_ArrayFieldMetadata, Generic[DcOrArrayT]):
739750
host: Dataclass instance who this field is attached too
740751
"""
741752
name: str
742-
host: DataclassArray
753+
host: DataclassArray = dataclasses.field(repr=False)
743754

744755
@property
745756
def qualname(self) -> str:
@@ -851,18 +862,18 @@ def _make_field_metadata(
851862
def _type_to_field_metadata(hint: TypeAlias) -> Optional[_ArrayFieldMetadata]:
852863
"""Converts type hint to extract `inner_shape`, `dtype`."""
853864
array_type = type_parsing.get_array_type(hint)
854-
if isinstance(array_type, type) and issubclass(array_type, DataclassArray):
855-
# TODO(epot): Should support `ray: Ray[..., 3]` ?
856-
return _ArrayFieldMetadata(inner_shape_non_static=(), dtype=array_type)
865+
866+
if isinstance(array_type, field_utils.DataclassWithShape):
867+
dtype = array_type.cls
857868
elif isinstance(array_type, array_types.ArrayAliasMeta):
858-
assert array_type is not None
859-
try:
860-
return _ArrayFieldMetadata(
861-
inner_shape_non_static=shape_parsing.get_inner_shape(
862-
array_type.shape),
863-
dtype=array_type.dtype,
864-
)
865-
except Exception as e: # pylint: disable=broad-except
866-
epy.reraise(e, prefix=f'Invalid shape annotation {hint}.')
869+
dtype = array_type.dtype
867870
else: # Not a supported type: Static field
868871
return None
872+
873+
try:
874+
return _ArrayFieldMetadata(
875+
inner_shape_non_static=shape_parsing.get_inner_shape(array_type.shape),
876+
dtype=dtype,
877+
)
878+
except Exception as e: # pylint: disable=broad-except
879+
epy.reraise(e, prefix=f'Invalid shape annotation {hint}.')

dataclass_array/array_dataclass_test.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -56,7 +56,7 @@ class Isometrie(dca.DataclassArray):
5656
class Nested(dca.DataclassArray):
5757
# pytype: disable=annotation-type-mismatch
5858
iso: Isometrie
59-
iso_batched: Isometrie = dca.field(shape=(3, 7), dtype=Isometrie)
59+
iso_batched: Isometrie['*batch_shape 3 7']
6060
pt: Point = dca.field(shape=(3,), dtype=Point)
6161
# pytype: enable=annotation-type-mismatch
6262

@@ -645,3 +645,12 @@ class PointDynamicShape(dca.DataclassArray):
645645
x=xnp.zeros(batch_shape + (2, 3), dtype=np.float32),
646646
y=xnp.zeros(batch_shape + (2, 1), dtype=np.int32), # < 2 != 3
647647
)
648+
649+
650+
def test_class_getitem():
651+
assert Point == Point # pylint: disable=comparison-with-itself
652+
assert Point[''] == Point['']
653+
assert Point['h w'] == Point['h w']
654+
assert Point[''] != Point
655+
assert Point[''] != Point['h w']
656+
assert Point[''] != Isometrie['']

dataclass_array/field_utils.py

Lines changed: 67 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,67 @@
1+
# Copyright 2022 The dataclass_array Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Field utils."""
16+
17+
from __future__ import annotations
18+
19+
import dataclasses
20+
21+
from dataclass_array import array_dataclass
22+
from dataclass_array.typing import TypeAlias
23+
from etils import epy
24+
from etils.array_types import Array
25+
import typing_extensions
26+
27+
28+
@dataclasses.dataclass(eq=True, frozen=True)
29+
class ShapeAnnotation:
30+
"""Annotations for `Ray[''] == Annotated[Ray, ShapeAnnotation('')]`."""
31+
shape: str
32+
33+
def __post_init__(self):
34+
# Normalize shape
35+
# Might be a cleaner way to do this
36+
super().__setattr__('shape', Array[self.shape].shape)
37+
38+
39+
@dataclasses.dataclass(eq=True, frozen=True)
40+
class DataclassWithShape:
41+
"""Structure which represent `Ray['h w']`."""
42+
cls: type[array_dataclass.DataclassArray]
43+
shape: str
44+
45+
@classmethod
46+
def from_hint(cls, hint: TypeAlias) -> DataclassWithShape:
47+
"""Factory to create the `DataclassWithShape` from `MyDca['h w']`."""
48+
assert cls.is_dca(hint)
49+
50+
# Extract the shape
51+
shape = '...'
52+
if typing_extensions.get_origin(hint) is typing_extensions.Annotated:
53+
shapes = [a for a in hint.__metadata__ if isinstance(a, ShapeAnnotation)] # pytype: disable=attribute-error
54+
if len(shapes) > 1:
55+
raise ValueError(f'Conflicting annotations for {hint}')
56+
elif len(shapes) == 1:
57+
(shape,) = shapes
58+
shape = shape.shape
59+
60+
hint = hint.__origin__
61+
return cls(cls=hint, shape=shape)
62+
63+
@classmethod
64+
def is_dca(cls, hint: TypeAlias) -> bool:
65+
if typing_extensions.get_origin(hint) is typing_extensions.Annotated:
66+
hint = hint.__origin__
67+
return epy.issubclass(hint, array_dataclass.DataclassArray)

dataclass_array/shape_parsing.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -91,7 +91,8 @@ def get_inner_shape(shape_str: str) -> _Shape:
9191
# TODO(epot): Support `_` & `None` dim
9292
if not shape or not isinstance(shape[0], _VarDim):
9393
raise ValueError(
94-
'Shape should start by `...` or `*shape` (e.g. `f32[\'*shape 3\']`)')
94+
f'Shape {shape_str!r} should start by `...` or `*shape` (e.g. '
95+
'`f32[\'*shape 3\']`)')
9596

9697
inner_shape = shape[1:]
9798
if not all(isinstance(dim, (int, type(None))) for dim in inner_shape):

dataclass_array/shape_parsing_test.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -86,7 +86,7 @@ def test_get_inner_shape(shape_str, shape_tuple):
8686
],
8787
)
8888
def test_get_inner_shape_failure_first_dim(shape_str: str):
89-
with pytest.raises(ValueError, match='Shape should start'):
89+
with pytest.raises(ValueError, match='Shape .* should start'):
9090
shape_parsing.get_inner_shape(shape_str)
9191

9292

dataclass_array/type_parsing.py

Lines changed: 18 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -22,10 +22,11 @@
2222
from typing import Any, Callable, Optional
2323

2424
from dataclass_array import array_dataclass
25+
from dataclass_array import field_utils
26+
from dataclass_array.typing import TypeAlias
2527
from etils import array_types as array_types_lib
2628
import typing_extensions # TODO(py38): Remove
2729

28-
TypeAlias = Any # TODO(py39): Use real alias once 3.7 is dropped.
2930
_LeafFn = Callable[[TypeAlias], None]
3031

3132
_NoneType = type(None)
@@ -74,15 +75,15 @@ def _collect_leaf_types(hint):
7475
return all_types
7576

7677

77-
def get_array_type(hint: TypeAlias) -> Optional[type[Any]]:
78+
def get_array_type(hint: TypeAlias) -> Optional[Any]:
7879
"""Returns the array type, or `None` if no type was detected.
7980
8081
Example:
8182
8283
```python
8384
get_array_type(f32[..., 3]) -> f32[..., 3]
84-
get_array_type(dca.Ray) -> dca.Ray
85-
get_array_type(Optional[dca.Ray]) -> dca.Ray
85+
get_array_type(dca.Ray) -> dca.Ray['...']
86+
get_array_type(Optional[dca.Ray]) -> dca.Ray['...']
8687
get_array_type(dca.Ray | dca.Camera | None) -> dca.DataclassArray
8788
get_array_type(Any) -> None # Any not an array type
8889
get_array_type(dca.Ray | int) -> None # int not an array type
@@ -106,9 +107,8 @@ def get_array_type(hint: TypeAlias) -> Optional[type[Any]]:
106107
array_types = []
107108
other_types = []
108109
for leaf in leaf_types:
109-
if (isinstance(leaf, type) and
110-
issubclass(leaf, array_dataclass.DataclassArray)):
111-
dc_types.append(leaf)
110+
if field_utils.DataclassWithShape.is_dca(leaf):
111+
dc_types.append(field_utils.DataclassWithShape.from_hint(leaf))
112112
elif isinstance(leaf, array_types_lib.ArrayAliasMeta):
113113
array_types.append(leaf)
114114
else:
@@ -122,7 +122,17 @@ def get_array_type(hint: TypeAlias) -> Optional[type[Any]]:
122122
'this feature.')
123123
if dc_types:
124124
if len(dc_types) > 1:
125-
return array_dataclass.DataclassArray
125+
# Validate the inner shape
126+
common_shapes = {x.shape for x in dc_types}
127+
if len(common_shapes) != 1:
128+
raise NotImplementedError(
129+
f'{hint} mix dataclass with different inner shape. Please open an '
130+
'issue if you need this feature.')
131+
(common_shape,) = common_shapes
132+
return field_utils.DataclassWithShape(
133+
cls=array_dataclass.DataclassArray,
134+
shape=common_shape,
135+
)
126136
else:
127137
return dc_types[0]
128138
if array_types:

dataclass_array/type_parsing_test.py

Lines changed: 52 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525
import pytest
2626

2727

28+
_DS = dca.field_utils.DataclassWithShape
2829
Ray = dca.testing.Ray
2930

3031

@@ -39,7 +40,10 @@ class Camera(dca.DataclassArray):
3940
[
4041
(int, [int]),
4142
(Ray, [Ray]),
43+
(Ray['h w'], [Ray['h w']]),
44+
(Ray[..., 3], [Ray[..., 3]]),
4245
(Union[Ray, int], [Ray, int]),
46+
(Union[Ray['h w'], int], [Ray['h w'], int]),
4347
(Union[Ray, int, None], [Ray, int, None]),
4448
(Optional[Ray], [Ray, None]),
4549
(Optional[Union[Ray, int]], [Ray, int, None]),
@@ -55,10 +59,12 @@ def test_get_leaf_types(hint, expected):
5559
'hint, expected',
5660
[
5761
(int, None),
58-
(Ray, Ray),
59-
(Optional[Ray], Ray),
60-
(Union[Ray, Camera], dca.DataclassArray),
61-
(Union[Ray, Camera, None], dca.DataclassArray),
62+
(Ray, _DS(Ray, '...')),
63+
(Ray['h w'], _DS(Ray, 'h w')),
64+
(Ray[..., 3], _DS(Ray, '... 3')),
65+
(Optional[Ray], _DS(Ray, '...')),
66+
(Union[Ray, Camera], _DS(dca.DataclassArray, '...')),
67+
(Union[Ray, Camera, None], _DS(dca.DataclassArray, '...')),
6268
(Union[Ray, int], None),
6369
(Union[Ray, int, None], None),
6470
(Union[f32[3, 3], int, None], None),
@@ -72,9 +78,51 @@ def test_get_array_type(hint, expected):
7278
assert type_parsing.get_array_type(hint) == expected
7379

7480

81+
@pytest.mark.parametrize(
82+
'hint, expected',
83+
[
84+
(Ray, _DS(Ray, '...')),
85+
(Ray['h w'], _DS(Ray, 'h w')),
86+
(Ray[..., 3], _DS(Ray, '... 3')),
87+
],
88+
)
89+
def test_from_hint(hint, expected):
90+
assert dca.field_utils.DataclassWithShape.from_hint(hint) == expected
91+
92+
7593
def test_get_array_type_error():
7694
with pytest.raises(NotImplementedError):
7795
type_parsing.get_array_type(Union[Ray, f32[3, 3]])
7896

7997
with pytest.raises(NotImplementedError):
8098
type_parsing.get_array_type(Union[FloatArray[..., 3], f32[3, 3]])
99+
100+
101+
@pytest.mark.parametrize(
102+
'hint, expected',
103+
[
104+
(
105+
Ray,
106+
dca.array_dataclass._ArrayFieldMetadata(
107+
inner_shape_non_static=(),
108+
dtype=Ray,
109+
),
110+
),
111+
(
112+
Ray[..., 3],
113+
dca.array_dataclass._ArrayFieldMetadata(
114+
inner_shape_non_static=(3,),
115+
dtype=Ray,
116+
),
117+
),
118+
(
119+
Ray['*shape 4 _'],
120+
dca.array_dataclass._ArrayFieldMetadata(
121+
inner_shape_non_static=(4, None),
122+
dtype=Ray,
123+
),
124+
),
125+
],
126+
)
127+
def test_type_to_field_metadata(hint, expected):
128+
assert dca.array_dataclass._type_to_field_metadata(hint) == expected

0 commit comments

Comments
 (0)