Skip to content

Commit 7624853

Browse files
ConchylicultorThe dataclass_array Authors
authored andcommitted
Add @dca.dataclass_array decorator to customize dca params. Change default values
PiperOrigin-RevId: 475563717
1 parent c2cb66d commit 7624853

File tree

4 files changed

+67
-16
lines changed

4 files changed

+67
-16
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
2323

2424
## [Unreleased]
2525

26+
* Changed: By default, dataclass_array do not cast and broadcast inputs
27+
anymore.
28+
* Added: `@dca.dataclass_array` to customize the `dca.DataclassArray` params
29+
2630
## [1.1.0] - 2022-08-15
2731

2832
* Added: Array types can be imported directly from `dataclass_array.typing`

dataclass_array/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@
3232
from dataclass_array import typing
3333
# TODO(epot): Rename array_field -> field internally
3434
from dataclass_array.array_dataclass import array_field as field
35+
from dataclass_array.array_dataclass import dataclass_array
3536
from dataclass_array.array_dataclass import DataclassArray
36-
from dataclass_array.array_dataclass import DataclassParams
3737
from dataclass_array.ops import stack
3838
from dataclass_array.vectorization import vectorize_method
3939

dataclass_array/array_dataclass.py

Lines changed: 55 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,18 +54,64 @@
5454
class DataclassParams:
5555
"""Params controlling the DataclassArray behavior.
5656
57-
Saved in `cls.__dca_params__`.
57+
Set by `@dca.dataclass_array`. Saved in `cls.__dca_params__`.
5858
5959
Attributes:
60-
broadcast: If `False`, disable input broadcasting
61-
cast_dtype: If `False`, do not auto-cast inputs `dtype`
62-
cast_list: If `False`, do not auto-cast lists to `xnp.ndarray`
60+
broadcast: If `True`, enable input broadcasting
61+
cast_dtype: If `True`, auto-cast inputs `dtype`
62+
cast_list: If `True`, auto-cast lists to `xnp.ndarray`
6363
"""
64-
broadcast: bool = True
65-
cast_dtype: bool = True
64+
# If modifying this, make sure to modify `@dataclass_array` too!
65+
broadcast: bool = False
66+
cast_dtype: bool = False
6667
cast_list: bool = True
6768

6869

70+
def dataclass_array(
71+
*,
72+
# If modifying this, make sure to modify `DataclassParams` too!
73+
broadcast: bool = False,
74+
cast_dtype: bool = False,
75+
cast_list: bool = True,
76+
) -> Callable[[type[_DcT]], type[_DcT]]:
77+
"""Optional decorator to customize `dca.DataclassArray` params.
78+
79+
Usage:
80+
81+
```python
82+
@dca.dataclass_array()
83+
@dataclasses.dataclass(frozen=True)
84+
class MyDataclass(dca.DataclassArray):
85+
...
86+
```
87+
88+
This decorator has to be added in addition of inheriting from
89+
`dca.DataclassArray`.
90+
91+
Args:
92+
broadcast: If `True`, enable input broadcasting
93+
cast_dtype: If `True`, auto-cast inputs `dtype`
94+
cast_list: If `True`, auto-cast lists to `xnp.ndarray`
95+
96+
Returns:
97+
decorator: The decorator which will apply the options to the dataclass
98+
"""
99+
100+
def decorator(cls):
101+
if not issubclass(cls, DataclassArray):
102+
raise TypeError(
103+
'`@dca.dataclass_array` can only be applied on `dca.DataclassArray`. '
104+
f'Got: {cls}')
105+
cls.__dca_params__ = DataclassParams(
106+
broadcast=broadcast,
107+
cast_dtype=cast_dtype,
108+
cast_list=cast_list,
109+
)
110+
return cls
111+
112+
return decorator
113+
114+
69115
class MetaDataclassArray(type):
70116
"""DataclassArray metaclass."""
71117

@@ -545,8 +591,9 @@ def _broadcast_field(f: _ArrayField) -> None:
545591
return
546592
elif not self.__dca_params__.broadcast: # Broadcasing disabled
547593
raise ValueError(
548-
f'{type(self).__qualname__} has `broadcast=False`. Cannot '
549-
f'broadcast {f.name} from {f.full_shape} to {final_shape}')
594+
f'{type(self).__qualname__} has `broadcast=False`. '
595+
f'Cannot broadcast {f.name} from {f.full_shape} to {final_shape}. '
596+
f'To enable broadcast, use `@dca.dataclass_array(broadcast=True)`.')
550597
self._setattr(f.name, f.broadcast_to(final_shape))
551598

552599
self._map_field(

dataclass_array/array_dataclass_test.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,24 +34,28 @@
3434
# TODO(epot): Test dtype `complex`, `str`
3535

3636

37+
@dca.dataclass_array(broadcast=True, cast_dtype=True)
3738
@dataclasses.dataclass(frozen=True)
3839
class Point(dca.DataclassArray):
3940
x: f32['*shape']
4041
y: f32['*shape']
4142

4243

44+
@dca.dataclass_array(broadcast=True, cast_dtype=True)
4345
@dataclasses.dataclass(frozen=True)
4446
class PointWrapper(dca.DataclassArray):
4547
pts: Point
4648
rgb: f32['*shape 3']
4749

4850

51+
@dca.dataclass_array(broadcast=True, cast_dtype=True)
4952
@dataclasses.dataclass(frozen=True)
5053
class Isometrie(dca.DataclassArray):
5154
r: f32['... 3 3']
5255
t: i32[..., 2]
5356

5457

58+
@dca.dataclass_array(broadcast=True, cast_dtype=True)
5559
@dataclasses.dataclass(frozen=True)
5660
class Nested(dca.DataclassArray):
5761
# pytype: disable=annotation-type-mismatch
@@ -61,6 +65,7 @@ class Nested(dca.DataclassArray):
6165
# pytype: enable=annotation-type-mismatch
6266

6367

68+
@dca.dataclass_array(broadcast=True, cast_dtype=True)
6469
@dataclasses.dataclass(frozen=True)
6570
class WithStatic(dca.DataclassArray):
6671
"""Mix of static and array fields."""
@@ -541,8 +546,6 @@ def test_dataclass_params_no_cast(xnp: enp.NpModule):
541546

542547
@dataclasses.dataclass(frozen=True)
543548
class PointNoCast(dca.DataclassArray):
544-
__dca_params__ = dca.DataclassParams(cast_dtype=False)
545-
546549
x: FloatArray['*shape']
547550
y: IntArray['*shape']
548551

@@ -564,10 +567,9 @@ class PointNoCast(dca.DataclassArray):
564567
@enp.testing.parametrize_xnp()
565568
def test_dataclass_params_no_list(xnp: enp.NpModule):
566569

570+
@dca.dataclass_array(cast_list=False)
567571
@dataclasses.dataclass(frozen=True)
568572
class PointNoList(dca.DataclassArray):
569-
__dca_params__ = dca.DataclassParams(cast_list=False)
570-
571573
x: FloatArray['*shape']
572574
y: IntArray['*shape']
573575

@@ -583,15 +585,13 @@ def test_dataclass_params_no_broadcast(xnp: enp.NpModule):
583585

584586
@dataclasses.dataclass(frozen=True)
585587
class PointNoBroadcast(dca.DataclassArray):
586-
__dca_params__ = dca.DataclassParams(broadcast=False)
587-
588588
x: FloatArray['*shape']
589589
y: IntArray['*shape']
590590

591591
with pytest.raises(ValueError, match='Cannot broadcast'):
592592
PointNoBroadcast(
593593
x=xnp.array(1, dtype=np.float16),
594-
y=xnp.array([1, 2, 3], dtype=np.float16),
594+
y=xnp.array([1, 2, 3], dtype=np.int32),
595595
)
596596

597597

0 commit comments

Comments
 (0)