Skip to content

Commit da04a51

Browse files
Conchylicultor, The dataclass_array Authors
authored and committed
Auto apply @dataclass to dataclass_array
PiperOrigin-RevId: 516194659
1 parent 4c54d15 commit da04a51

File tree

9 files changed

+56
-61
lines changed

9 files changed

+56
-61
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
2626
* **Add `torch` support!**
2727
* Add `.cpu()`, `.cuda()`, `.to()` methods to move the dataclass from
2828
devices when using torch.
29+
* **Breaking**: `@dataclass(frozen=True)` is now automatically applied
2930

3031
## [1.3.0] - 2023-01-16
3132

README.md

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,6 @@ import dataclass_array as dca
3030
from dataclass_array.typing import FloatArray
3131

3232

33-
@dataclasses.dataclass(frozen=True)
3433
class Ray(dca.DataclassArray):
3534
pos: FloatArray['*batch_shape 3']
3635
dir: FloatArray['*batch_shape 3']
@@ -75,7 +74,6 @@ A `DataclassArray` has 2 types of fields:
7574
Static fields are also ignored in `jax.tree_map`.
7675

7776
```python
78-
@dataclasses.dataclass(frozen=True)
7977
class MyArray(dca.DataclassArray):
8078
# Array fields
8179
a: FloatArray['*batch_shape 3'] # Defined by `etils.array_types`
@@ -102,7 +100,6 @@ batching:
102100
2. Decorate the method with `dca.vectorize_method`
103101

104102
```python
105-
@dataclasses.dataclass(frozen=True)
106103
class Camera(dca.DataclassArray):
107104
K: FloatArray['*batch_shape 4 4']
108105
resolution = tuple[int, int]
@@ -133,7 +130,7 @@ rays.shape == (num_cams, h, w)
133130
* Instead of vectorizing a single axis, `@dca.vectorize_method` will vectorize
134131
over `*self.shape` (not just `self.shape[0]`). This is like if `vmap` was
135132
applied to `self.flatten()`
136-
* When multiple arguments, axis with dimension `1` are brodcasted.
133+
* When multiple arguments, axis with dimension `1` are broadcasted.
137134

138135
For example, with `__matmul__(self, x: T) -> T`:
139136

dataclass_array/array_dataclass.py

Lines changed: 54 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -86,7 +86,6 @@ def dataclass_array(
8686
8787
```python
8888
@dca.dataclass_array()
89-
@dataclasses.dataclass(frozen=True)
9089
class MyDataclass(dca.DataclassArray):
9190
...
9291
```
@@ -119,6 +118,31 @@ def decorator(cls):
119118
return decorator
120119

121120

121+
def array_field(
122+
shape: Shape,
123+
dtype: DTypeArg = float,
124+
**field_kwargs,
125+
) -> dataclasses.Field[DcOrArray]:
126+
"""Dataclass array field.
127+
128+
See `dca.DataclassArray` for example.
129+
130+
Args:
131+
shape: Inner shape of the field
132+
dtype: Type of the field
133+
**field_kwargs: Args forwarded to `dataclasses.field`
134+
135+
Returns:
136+
The dataclass field.
137+
"""
138+
# TODO(epot): Validate shape, dtype
139+
dca_field = _ArrayFieldMetadata(
140+
inner_shape_non_static=shape,
141+
dtype=dtype,
142+
)
143+
return dataclasses.field(**field_kwargs, metadata={_METADATA_KEY: dca_field})
144+
145+
122146
class MetaDataclassArray(type):
123147
"""DataclassArray metaclass."""
124148

@@ -128,13 +152,21 @@ def __getitem__(cls, spec):
128152
return Annotated[cls, field_utils.ShapeAnnotation(spec)]
129153

130154

155+
@typing_extensions.dataclass_transform( # pytype: disable=not-supported-yet
156+
kw_only_default=True,
157+
# TODO(b/272524683):Restore field specifier
158+
# field_specifiers=(
159+
# dataclasses.Field,
160+
# dataclasses.field,
161+
# array_field,
162+
# ),
163+
)
131164
class DataclassArray(metaclass=MetaDataclassArray):
132165
"""Dataclass which behaves like an array.
133166
134167
Usage:
135168
136169
```python
137-
@dataclasses.dataclass
138170
class Square(DataclassArray):
139171
pos: f32['*shape 2']
140172
scale: f32['*shape']
@@ -179,7 +211,6 @@ class Square(DataclassArray):
179211
180212
Field which do not satisfy any of the above conditions are static (including
181213
field annotated with `field: np.ndarray` or similar).
182-
183214
"""
184215

185216
# Child class inherit the default params by default, but can also
@@ -194,8 +225,21 @@ class Square(DataclassArray):
194225
_shape: Shape
195226
_xnp: enp.NpModule
196227

197-
def __init_subclass__(cls, **kwargs):
228+
def __init_subclass__(
229+
cls,
230+
frozen=True,
231+
**kwargs,
232+
):
198233
super().__init_subclass__(**kwargs)
234+
235+
if not frozen:
236+
raise ValueError(f'{cls} cannot be `frozen=False`.')
237+
238+
# Apply dataclass (in-place)
239+
if not typing.TYPE_CHECKING:
240+
# TODO(b/227290126): Create pytype issues
241+
dataclasses.dataclass(frozen=True)(cls)
242+
199243
# TODO(epot): Could have smart __repr__ which display types if array have
200244
# too many values (maybe directly in `edc.field(repr=...)`).
201245
edc.dataclass(kw_only=True, repr=True, auto_cast=False)(cls)
@@ -212,6 +256,11 @@ def __init_subclass__(cls, **kwargs):
212256
# `__dca_non_init_fields__` (fields should be merged from `.mro()`)
213257
cls.__dca_non_init_fields__ = set(cls.__dca_non_init_fields__)
214258

259+
if typing.TYPE_CHECKING:
260+
# TODO(b/242839979): pytype do not support PEP 681 -- Data Class Transforms
261+
def __init__(self, **kwargs):
262+
pass
263+
215264
def __post_init__(self) -> None:
216265
"""Validate and normalize inputs."""
217266
cls = type(self)
@@ -755,7 +804,6 @@ def _init_cls(self: DataclassArray) -> None:
755804
756805
This will:
757806
758-
* Validate the `@dataclass(frozen=True)` is correctly applied
759807
* Extract the types annotations, detect which fields are arrays or static,
760808
and store the result in `_dca_fields_metadata`
761809
* For static `DataclassArray` (class with only static fields), it will
@@ -767,12 +815,6 @@ def _init_cls(self: DataclassArray) -> None:
767815
"""
768816
cls = type(self)
769817

770-
# Make sure the dataclass was registered and frozen
771-
if not dataclasses.is_dataclass(cls) or not cls.__dataclass_params__.frozen: # pytype: disable=attribute-error
772-
raise ValueError(
773-
'`dca.DataclassArray` need to be @dataclasses.dataclass(frozen=True)'
774-
)
775-
776818
# The first time, compute typing annotations & metadata
777819
# At this point, `ForwardRef` should have been resolved.
778820
try:
@@ -797,7 +839,7 @@ def _init_cls(self: DataclassArray) -> None:
797839
)
798840

799841
dca_fields_metadata = {
800-
f.name: _make_field_metadata(f, hints) for f in dataclasses.fields(cls)
842+
f.name: _make_field_metadata(f, hints) for f in dataclasses.fields(cls) # pytype: disable=wrong-arg-types
801843
}
802844
dca_fields_metadata = { # Filter `None` values (static fields)
803845
k: v for k, v in dca_fields_metadata.items() if v is not None
@@ -909,31 +951,6 @@ class _TreeMetadata:
909951
non_array_field_kwargs: dict[str, Any]
910952

911953

912-
def array_field(
913-
shape: Shape,
914-
dtype: DTypeArg = float,
915-
**field_kwargs,
916-
) -> dataclasses.Field[DcOrArray]:
917-
"""Dataclass array field.
918-
919-
See `dca.DataclassArray` for example.
920-
921-
Args:
922-
shape: Inner shape of the field
923-
dtype: Type of the field
924-
**field_kwargs: Args forwarded to `dataclasses.field`
925-
926-
Returns:
927-
The dataclass field.
928-
"""
929-
# TODO(epot): Validate shape, dtype
930-
dca_field = _ArrayFieldMetadata(
931-
inner_shape_non_static=shape,
932-
dtype=dtype,
933-
)
934-
return dataclasses.field(**field_kwargs, metadata={_METADATA_KEY: dca_field})
935-
936-
937954
# TODO(epot): Should refactor `_ArrayField` in `_DataclassArrayField` and
938955
# `_ArrayField` depending on whether dtype is `DataclassArray` or not.
939956
# Alternativelly, maybe should create a `DcArrayDType` dtype instead.

dataclass_array/array_dataclass_test.py

Lines changed: 0 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,6 @@
1616

1717
from __future__ import annotations
1818

19-
import dataclasses
2019
from typing import Any
2120

2221
import dataclass_array as dca
@@ -49,7 +48,6 @@ def assert_val(
4948

5049

5150
@dca.dataclass_array(broadcast=True, cast_dtype=True)
52-
@dataclasses.dataclass(frozen=True)
5351
class Point(dca.DataclassArray):
5452
x: f32['*shape']
5553
y: f32['*shape']
@@ -77,7 +75,6 @@ def assert_val(p: Point, shape: Shape, xnp: enp.NpModule = None):
7775

7876

7977
@dca.dataclass_array(broadcast=True, cast_dtype=True)
80-
@dataclasses.dataclass(frozen=True)
8178
class Isometrie(dca.DataclassArray):
8279
r: f32['... 3 3']
8380
t: i32[..., 2]
@@ -105,7 +102,6 @@ def assert_val(p: Isometrie, shape: Shape, xnp: enp.NpModule = None):
105102

106103

107104
@dca.dataclass_array(broadcast=True, cast_dtype=True)
108-
@dataclasses.dataclass(frozen=True)
109105
class Nested(dca.DataclassArray):
110106
# pytype: disable=annotation-type-mismatch
111107
iso: Isometrie
@@ -142,7 +138,6 @@ def assert_val(p: Nested, shape: Shape, xnp: enp.NpModule = None):
142138
Isometrie.assert_val(p.iso_batched, shape=shape + (3, 7), xnp=xnp)
143139

144140

145-
@dataclasses.dataclass(frozen=True)
146141
class OnlyStatic(dca.DataclassArray):
147142
"""Dataclass with no array fields."""
148143

@@ -172,7 +167,6 @@ def assert_val(p: OnlyStatic, shape: Shape, xnp: enp.NpModule = None):
172167
assert p.y == 1
173168

174169

175-
@dataclasses.dataclass(frozen=True)
176170
class NestedOnlyStatic(dca.DataclassArray):
177171
"""Dataclass with only nested array fields."""
178172

@@ -198,7 +192,6 @@ def assert_val(p: NestedOnlyStatic, shape: Shape, xnp: enp.NpModule = None):
198192

199193

200194
@dca.dataclass_array(broadcast=True, cast_dtype=True)
201-
@dataclasses.dataclass(frozen=True)
202195
class WithStatic(dca.DataclassArray):
203196
"""Mix of static and array fields."""
204197

@@ -413,7 +406,6 @@ def test_wrong_input_type():
413406
)
414407

415408
@dca.dataclass_array(broadcast=True, cast_dtype=True)
416-
@dataclasses.dataclass(frozen=True)
417409
class PointWrapper(dca.DataclassArray):
418410
pts: Point
419411
rgb: f32['*shape 3']
@@ -642,7 +634,6 @@ def fn(p: WithStatic) -> WithStatic:
642634

643635
@enp.testing.parametrize_xnp()
644636
def test_dataclass_params_no_cast(xnp: enp.NpModule):
645-
@dataclasses.dataclass(frozen=True)
646637
class PointNoCast(dca.DataclassArray):
647638
x: FloatArray['*shape']
648639
y: IntArray['*shape']
@@ -665,7 +656,6 @@ class PointNoCast(dca.DataclassArray):
665656
@enp.testing.parametrize_xnp()
666657
def test_dataclass_params_no_list(xnp: enp.NpModule):
667658
@dca.dataclass_array(cast_list=False)
668-
@dataclasses.dataclass(frozen=True)
669659
class PointNoList(dca.DataclassArray):
670660
x: FloatArray['*shape']
671661
y: IntArray['*shape']
@@ -679,7 +669,6 @@ class PointNoList(dca.DataclassArray):
679669

680670
@enp.testing.parametrize_xnp()
681671
def test_dataclass_params_no_broadcast(xnp: enp.NpModule):
682-
@dataclasses.dataclass(frozen=True)
683672
class PointNoBroadcast(dca.DataclassArray):
684673
x: FloatArray['*shape']
685674
y: IntArray['*shape']
@@ -694,7 +683,6 @@ class PointNoBroadcast(dca.DataclassArray):
694683
@enp.testing.parametrize_xnp()
695684
@pytest.mark.parametrize('batch_shape', [(), (1, 3)])
696685
def test_dataclass_none_shape(xnp: enp.NpModule, batch_shape: Shape):
697-
@dataclasses.dataclass(frozen=True)
698686
class PointDynamicShape(dca.DataclassArray):
699687
x: FloatArray[..., None, None]
700688
y: IntArray['... 3 _']

dataclass_array/import_test.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -20,14 +20,12 @@
2020

2121
from __future__ import annotations
2222

23-
import dataclasses
2423
import sys
2524

2625
import dataclass_array as dca
2726
from etils import enp
2827

2928

30-
@dataclasses.dataclass(frozen=True)
3129
class A(dca.DataclassArray):
3230
x: dca.typing.f32['*s']
3331

dataclass_array/testing.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,6 @@
1616

1717
from __future__ import annotations
1818

19-
import dataclasses
2019
import functools
2120
from typing import Any, Optional
2221

@@ -27,7 +26,6 @@
2726
import numpy as np
2827

2928

30-
@dataclasses.dataclass(frozen=True)
3129
class Ray(array_dataclass.DataclassArray):
3230
"""Dummy dataclass array for testing."""
3331

dataclass_array/type_parsing_test.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,6 @@
1616

1717
from __future__ import annotations
1818

19-
import dataclasses
2019
from typing import Optional, List, Union
2120

2221
import dataclass_array as dca
@@ -28,7 +27,6 @@
2827
Ray = dca.testing.Ray
2928

3029

31-
@dataclasses.dataclass(frozen=True)
3230
class Camera(dca.DataclassArray):
3331
pos: FloatArray[..., 3]
3432
dir: FloatArray[..., 3]

dataclass_array/vectorization.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -100,7 +100,6 @@ def vectorize_method(
100100
Example:
101101
102102
```
103-
@dataclasses.dataclass(frozen=True)
104103
class Point3d(dca.DataclassArray):
105104
p: f32['*shape 3']
106105

dataclass_array/vectorization_test.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -150,7 +150,6 @@ def fn(self, arg):
150150
def test_replace_dca(xnp: enp.NpModule):
151151
# Ensure that the non-init static fields are correctly forwarded.
152152

153-
@dataclasses.dataclass(frozen=True)
154153
class DataclassWithNonInit(dca.DataclassArray):
155154
"""Dataclass with a non-init (static) field."""
156155

0 commit comments

Comments
 (0)