Skip to content

Commit c0c8308

Browse files
ConchylicultorThe dataclass_array Authors
authored andcommitted
Forward non-init fields in .replace, vectorized_method, tree_map
Fix `cam.fig_config` not correctly forwarded PiperOrigin-RevId: 469680739
1 parent 7feefc1 commit c0c8308

File tree

3 files changed

+99
-11
lines changed

3 files changed

+99
-11
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
2525

2626
* Added: Syntax to specify the shape of the DataclassArray (e.g. `MyRay['h
2727
w']`).
28+
* Fixed: Correctly forward non-init fields in `.replace`, `tree_map`,
29+
`@dca.vectorize_method`
2830

2931
## [1.0.0] - 2022-08-08
3032

dataclass_array/array_dataclass.py

Lines changed: 47 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import dataclasses
2020
import typing
21-
from typing import Any, Callable, ClassVar, Generic, Iterator, Optional, Tuple, Type, TypeVar, Union
21+
from typing import Any, Callable, ClassVar, Generic, Iterator, Optional, Set, Tuple, Type, TypeVar, Union
2222

2323
from dataclass_array import field_utils
2424
from dataclass_array import shape_parsing
@@ -36,7 +36,6 @@
3636
import typing_extensions
3737
from typing_extensions import Annotated, Literal, TypeAlias # pylint: disable=g-multiple-import
3838

39-
4039
lazy = enp.lazy
4140

4241
# TODO(pytype): Should use `dca.typing.DcT` but bound does not work across
@@ -133,6 +132,11 @@ class Square(DataclassArray):
133132
# overwrite them.
134133
__dca_params__: ClassVar[DataclassParams] = DataclassParams()
135134

135+
# TODO(epot): Could be removed with py3.10 and using `kw_only=True`
136+
# Fields defined here will be forwarded with `.replace`
137+
# TODO(py39): Replace Set -> set
138+
__dca_non_init_fields__: ClassVar[Set[str]] = set()
139+
136140
_shape: Shape
137141
_xnp: enp.NpModule
138142

@@ -148,6 +152,9 @@ def __init_subclass__(cls, **kwargs):
148152
# convertions, we cache the type annotations here.
149153
cls._dca_fields_metadata: Optional[dict[str, _ArrayFieldMetadata]] = None
150154

155+
# Normalize the `cls.__dca_non_init_fields__`
156+
cls.__dca_non_init_fields__ = set(cls.__dca_non_init_fields__)
157+
151158
def __post_init__(self) -> None:
152159
"""Validate and normalize inputs."""
153160
cls = type(self)
@@ -346,8 +353,32 @@ def map_field(
346353

347354
# ====== Dataclass/Conversion utils ======
348355

349-
# TODO(pytype): Could be removed once there's a way of annotating this.
350-
replace = edc.dataclass_utils.replace
356+
def replace(self: _DcT, **kwargs: Any) -> _DcT:
357+
"""Alias for `dataclasses.replace`."""
358+
init_kwargs = {
359+
k: v for k, v in kwargs.items() if k not in self.__dca_non_init_fields__
360+
}
361+
non_init_kwargs = {
362+
k: v for k, v in kwargs.items() if k in self.__dca_non_init_fields__
363+
}
364+
365+
# Create the new object
366+
new_self = dataclasses.replace(self, **init_kwargs)
367+
368+
# Additionally forward the non-init kwargs
369+
# `dataclasses.field(init=False) kwargs are required because `init=True`
370+
# creates conflicts:
371+
# * Inheritance fails with non-default argument 'K' follows default argument
372+
# * Pytype complains too
373+
# TODO(py310): Cleanup using `dataclass(kw_only)`
374+
assert new_self is not self
375+
for k in self.__dca_non_init_fields__:
376+
if k in non_init_kwargs:
377+
v = non_init_kwargs[k]
378+
else:
379+
v = getattr(self, k)
380+
new_self._setattr(k, v) # pylint: disable=protected-access
381+
return new_self
351382

352383
def as_np(self: _DcT) -> _DcT:
353384
"""Returns the instance as containing `np.ndarray`."""
@@ -398,10 +429,12 @@ def _all_array_fields(self) -> dict[str, _ArrayField]:
398429
try:
399430
hints = typing_extensions.get_type_hints(cls, include_extras=True)
400431
except Exception as e: # pylint: disable=broad-except
401-
epy.reraise(
402-
e,
403-
prefix=f'Could not infer typing annotation of {cls.__name__} '
404-
f'defined in {cls.__module__}')
432+
msg = (f'Could not infer typing annotation of {cls.__qualname__} '
433+
f'defined in {cls.__module__}:\n')
434+
lines = [f' * {k}: {v!r}' for k, v in cls.__annotations__.items()]
435+
lines = '\n'.join(lines)
436+
437+
epy.reraise(e, prefix=msg + lines + '\n')
405438

406439
dca_fields_metadata = {
407440
f.name: _make_field_metadata(f, hints)
@@ -603,11 +636,14 @@ def tree_unflatten(
603636
self = cls(**array_field_kwargs, **init_fields)
604637
# Currently it's not clear how to handle non-init fields so raise an error
605638
if non_init_fields:
606-
if set(non_init_fields) != {'fig_config'}:
639+
if set(non_init_fields) - self.__dca_non_init_fields__:
607640
raise ValueError(
608-
'`dca.DataclassArray` with init=False field not supported yet.')
641+
'`dca.DataclassArray` field with init=False should be explicitly '
642+
'specified in `__dca_non_init_fields__` for them to be '
643+
'propagated by `tree_map`.')
609644
# TODO(py310): Delete once dataclass supports `kw_only=True`
610-
self._setattr('fig_config', non_init_fields['fig_config']) # pylint: disable=protected-access
645+
for k, v in non_init_fields.items():
646+
self._setattr(k, v) # pylint: disable=protected-access
611647
return self
612648

613649
def _setattr(self, name: str, value: Any) -> None:

dataclass_array/vectorization_test.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,17 @@
1616

1717
from __future__ import annotations
1818

19+
import dataclasses
20+
1921
import dataclass_array as dca
2022
from dataclass_array import vectorization
2123
from dataclass_array.utils import inspect_utils
2224
from dataclass_array.utils import np_utils
2325
from etils import enp
26+
from etils.array_types import FloatArray
27+
import jax
2428
import pytest
29+
import tensorflow.experimental.numpy as tnp
2530

2631
H = 2
2732
W = 3
@@ -141,3 +146,48 @@ def fn(self, arg):
141146
bound_args,
142147
map_non_static=lambda fn, args: args.map(fn),
143148
)
149+
150+
151+
@enp.testing.parametrize_xnp()
152+
def test_replace_dca(xnp: enp.NpModule):
153+
154+
# Ensure that the non-init static fields are correctly forwarded.
155+
156+
@dataclasses.dataclass(frozen=True)
157+
class DataclassWithNonInit(dca.DataclassArray):
158+
"""Dataclass with a non-init (static) field."""
159+
__dca_non_init_fields__ = ('x',)
160+
161+
y: FloatArray['*batch']
162+
x: int = dataclasses.field(default=1, init=False)
163+
164+
@dca.vectorize_method
165+
def fn(self):
166+
assert not self.shape
167+
assert self.x == 5
168+
return self
169+
170+
a = DataclassWithNonInit(y=[1, 0, 0]).as_xnp(xnp)
171+
assert a.shape == (3,)
172+
assert a.x == 1
173+
174+
# Replace supported
175+
a = a.replace(x=5)
176+
assert a.shape == (3,)
177+
assert a.x == 5
178+
179+
a = a.replace(y=a.y + 1)
180+
assert a.shape == (3,)
181+
assert a.x == 5
182+
183+
# Vectorization supported
184+
if xnp != tnp:
185+
a = a.fn()
186+
assert a.xnp is xnp
187+
assert a.shape == (3,)
188+
assert a.x == 5
189+
190+
# Tree-map supported
191+
a = jax.tree_util.tree_map(lambda x: x, a)
192+
assert a.shape == (3,)
193+
assert a.x == 5

0 commit comments

Comments
 (0)