Skip to content

Commit 740bce0

Browse files
ConchylicultorThe dataclass_array Authors
authored and committed
Add static dataclass_array support
PiperOrigin-RevId: 494705204
1 parent 94ae70e commit 740bce0

File tree

3 files changed

+197
-49
lines changed

3 files changed

+197
-49
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
2323

2424
## [Unreleased]
2525

26+
* Added: Support for static `dca.DataclassArray` (dataclasses with only
27+
static fields).
28+
2629
## [1.2.1] - 2022-11-24
2730

2831
* Fixed: Compatibility with `edc.dataclass(auto_cast=True)` (fix the `'type'

dataclass_array/array_dataclass.py

Lines changed: 126 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from __future__ import annotations
1818

1919
import dataclasses
20+
import functools
2021
import typing
2122
from typing import Any, Callable, ClassVar, Generic, Iterator, Optional, Set, Tuple, Type, TypeVar, Union
2223

@@ -48,6 +49,7 @@
4849
_IndicesArg = Union[_IndiceItem, _Indices]
4950

5051
_METADATA_KEY = 'dca_field'
52+
_DUMMY_ARRAY_FIELD = '_dca_dummy_array'
5153

5254

5355
@dataclasses.dataclass(frozen=True)
@@ -192,7 +194,7 @@ class Square(DataclassArray):
192194
def __init_subclass__(cls, **kwargs):
193195
super().__init_subclass__(**kwargs)
194196
# TODO(epot): Could have smart __repr__ which display types if array have
195-
# too many values.
197+
# too many values (maybe directly in `edc.field(repr=...)`).
196198
edc.dataclass(kw_only=True, repr=True)(cls)
197199
cls._dca_tree_map_registered = False
198200
# Typing annotations have to be lazily evaluated (to support
@@ -210,24 +212,16 @@ def __post_init__(self) -> None:
210212
"""Validate and normalize inputs."""
211213
cls = type(self)
212214

213-
# Make sure the dataclass was registered and frozen
214-
if not dataclasses.is_dataclass(cls) or not cls.__dataclass_params__.frozen: # pytype: disable=attribute-error
215-
raise ValueError(
216-
'`dca.DataclassArray` need to be @dataclasses.dataclass(frozen=True)'
217-
)
215+
# First time, we perform additional check & updates
216+
if cls._dca_fields_metadata is None: # pylint: disable=protected-access
217+
_init_cls(self)
218218

219219
# Register the tree_map here instead of `__init_subclass__` as `jax` may
220220
# not have been registered yet during import
221221
if enp.lazy.has_jax and not cls._dca_tree_map_registered: # pylint: disable=protected-access
222222
enp.lazy.jax.tree_util.register_pytree_node_class(cls)
223223
cls._dca_tree_map_registered = True # pylint: disable=protected-access
224224

225-
if not self._all_array_fields:
226-
raise ValueError(
227-
f'{self.__class__.__qualname__} should have at least one '
228-
'`dca.array_field`'
229-
)
230-
231225
# Validate and normalize array fields
232226
# * Maybe cast (list, np) -> xnp
233227
# * Maybe cast dtype
@@ -246,8 +240,9 @@ def __post_init__(self) -> None:
246240

247241
if xnp is None: # No values
248242
# Inside `jax.tree_utils`, tree-def can be created with `None` values.
243+
# Inside `jax.vmap`, tree can be created with `object()` sentinel values.
249244
assert shape is None
250-
xnp = np
245+
xnp = None
251246

252247
# Cache results
253248
# Should the state be stored in a separate object to avoid collisions ?
@@ -418,6 +413,8 @@ def replace(self: _DcT, **kwargs: Any) -> _DcT:
418413
# Create the new object
419414
new_self = dataclasses.replace(self, **init_kwargs)
420415

416+
# TODO(epot): Could try to unify the logic below with `tree_unflatten`
417+
421418
# Additionally forward the non-init kwargs
422419
# `dataclasses.field(init=False) kwargs are required because `init=True`
423420
# creates conflicts:
@@ -449,7 +446,13 @@ def as_xnp(self: _DcT, xnp: enp.NpModule) -> _DcT:
449446
"""Returns the instance as containing `xnp.ndarray`."""
450447
if xnp is self.xnp: # No-op
451448
return self
452-
return self.map_field(xnp.asarray)
449+
450+
# Update all children
451+
new_self = self._map_field(
452+
array_fn=lambda f: xnp.asarray(f.value),
453+
dc_fn=lambda f: f.value.as_xnp(xnp),
454+
)
455+
return new_self
453456

454457
# ====== Internal ======
455458

@@ -474,53 +477,26 @@ def xnp(self) -> enp.NpModule:
474477
@epy.cached_property
475478
def _all_array_fields(self) -> dict[str, _ArrayField]:
476479
"""All array fields, including `None` values."""
477-
cls = type(self)
478-
479-
# The first time, compute typing annotations & metadata
480-
if cls._dca_fields_metadata is None: # pylint: disable=protected-access
481-
# At this point, `ForwardRef` should have been resolved.
482-
try:
483-
hints = typing_extensions.get_type_hints(cls, include_extras=True)
484-
except Exception as e: # pylint: disable=broad-except
485-
msg = (
486-
f'Could not infer typing annotation of {cls.__qualname__} '
487-
f'defined in {cls.__module__}:\n'
488-
)
489-
lines = [f' * {k}: {v!r}' for k, v in cls.__annotations__.items()]
490-
lines = '\n'.join(lines)
491-
492-
epy.reraise(e, prefix=msg + lines + '\n')
493-
494-
dca_fields_metadata = {
495-
f.name: _make_field_metadata(f, hints)
496-
for f in dataclasses.fields(self)
497-
}
498-
cls._dca_fields_metadata = ( # pylint: disable=protected-access
499-
{ # Filter `None` values
500-
k: v for k, v in dca_fields_metadata.items() if v is not None
501-
}
502-
)
503-
504480
return { # pylint: disable=g-complex-comprehension
505481
name: _ArrayField(
506482
name=name,
507483
host=self,
508484
**field_metadata.to_dict(), # pylint: disable=not-a-mapping
509485
)
510-
for name, field_metadata in cls._dca_fields_metadata.items() # pylint: disable=protected-access
486+
for name, field_metadata in self._dca_fields_metadata.items() # pylint: disable=protected-access
511487
}
512488

513489
@epy.cached_property
514490
def _array_fields(self) -> list[_ArrayField]:
515-
"""All active array fields (non-None)."""
491+
"""All active array fields (non-None), including static ones."""
516492
# Filter `None` values
517493
return [
518494
f for f in self._all_array_fields.values() if not f.is_value_missing
519495
]
520496

521497
def _cast_xnp_dtype_inplace(self) -> Optional[enp.NpModule]:
522498
"""Validate `xnp` are consistent and cast `np` -> `xnp` in-place."""
523-
if not self._array_fields: # No fields have been defined.
499+
if not self._array_fields: # All fields are `None` / `object`
524500
return None
525501

526502
# Validate the dtype
@@ -566,6 +542,8 @@ def _cast_field(f: _ArrayField) -> None:
566542
def _broadcast_shape_inplace(self) -> Optional[Shape]:
567543
"""Validate the shapes are consistent and broadcast values in-place."""
568544
if not self._array_fields: # No fields have been defined.
545+
# This can be the case internally by jax which apply some
546+
# `tree_map(lambda x: sentinel)`.
569547
return None
570548

571549
# First collect all shapes and compute the final shape.
@@ -721,6 +699,109 @@ def assert_same_xnp(self, x: Union[Array[...], DataclassArray]) -> None:
721699
)
722700

723701

702+
def _init_cls(self: DataclassArray) -> None:
  """One-time class setup, triggered by the first instance being built.

  Steps performed on `type(self)`:

  * Check the class is a `@dataclasses.dataclass(frozen=True)`.
  * Resolve the type annotations and build the per-field metadata. Fields
    for which `_make_field_metadata` returns `None` (static fields) are
    dropped. The result is stored in `cls._dca_fields_metadata`.
  * When no array field remains (static `DataclassArray`), register a dummy
    scalar `float32` array field on the class so `.xnp` / `.shape` keep
    working, and set it on the instance currently being constructed.

  Args:
    self: The instance whose class should be initialized.

  Raises:
    ValueError: If the class is not a frozen dataclass.
    NotImplementedError: If the class already carries the dummy array field
      (e.g. inherited from a parent that was initialized earlier).
  """
  cls = type(self)

  # The class must have been registered as a frozen dataclass.
  is_frozen_dataclass = (
      dataclasses.is_dataclass(cls) and cls.__dataclass_params__.frozen  # pytype: disable=attribute-error
  )
  if not is_frozen_dataclass:
    raise ValueError(
        '`dca.DataclassArray` need to be @dataclasses.dataclass(frozen=True)'
    )

  # Resolve the annotations. `ForwardRef` are expected to be resolvable by
  # the time the first instance is created.
  try:
    hints = typing_extensions.get_type_hints(cls, include_extras=True)
  except Exception as e:  # pylint: disable=broad-except
    header = (
        f'Could not infer typing annotation of {cls.__qualname__} '
        f'defined in {cls.__module__}:\n'
    )
    # List the raw annotations to help locating the faulty one.
    annotations_str = '\n'.join(
        f' * {k}: {v!r}' for k, v in cls.__annotations__.items()
    )
    epy.reraise(e, prefix=header + annotations_str + '\n')

  # TODO(epot): Remove restriction once pytype supports `dataclass_transform`
  # and `dca` automatically apply the `@dataclasses.dataclass`
  if _DUMMY_ARRAY_FIELD in cls.__dataclass_fields__:  # pytype: disable=attribute-error
    raise NotImplementedError(
        'Suclassing of DataclassArray with no array field is not supported '
        'after an instance of the class was created. Error raised for '
        f'{cls.__qualname__}'
    )

  # Collect the metadata of every field, skipping `None` (static fields).
  array_metadata = {}
  for dc_field in dataclasses.fields(cls):
    metadata = _make_field_metadata(dc_field, hints)
    if metadata is None:
      continue
    array_metadata[dc_field.name] = metadata

  if not array_metadata:
    # DataclassArray without any array fields.
    # Hack: To support `.xnp`, `.shape`, a dummy empty field is added, which
    # is propagated by the various ops.
    array_metadata[_DUMMY_ARRAY_FIELD] = _ArrayFieldMetadata(
        inner_shape_non_static=(),
        dtype=np.float32,
    )
    dummy_value = np.zeros((), dtype=np.float32)
    _add_field_to_dataclass(cls, _DUMMY_ARRAY_FIELD, default=dummy_value)
    # This runs while `self` is being initialized, so the instance under
    # construction has to receive the value explicitly too.
    self._setattr(_DUMMY_ARRAY_FIELD, dummy_value)  # pylint: disable=protected-access

  cls._dca_fields_metadata = array_metadata  # pylint: disable=protected-access
772+
773+
774+
def _add_field_to_dataclass(cls, name: str, default: Any) -> None:
  """Registers an extra keyword field on an already-built dataclass.

  The field is inserted in `cls.__dataclass_fields__` and `cls.__init__` is
  wrapped so the new keyword is consumed (falling back to `default`) before
  the original `__init__` runs.

  Args:
    cls: Dataclass to update in-place. `__dataclass_fields__` must be defined
      on `cls` itself, not only inherited.
    name: Name of the field to add. Must not already be a field of `cls`.
    default: Value used when `name` is not passed to `__init__`.

  Raises:
    ValueError: If `cls` does not define its own `__dataclass_fields__`.
  """
  # Only look at the attributes defined on `cls` itself: mutating an inherited
  # `__dataclass_fields__` would update the parent class (possibly even
  # `dca.DataclassArray`).
  if '__dataclass_fields__' not in vars(cls):
    # TODO(epot): Remove the limitation once `dataclasses.dataclass` is
    # automatically applied
    raise ValueError(
        f'{cls.__name__} is not a `@dataclasses.dataclass(frozen=True)`'
    )
  cls_fields = cls.__dataclass_fields__  # pytype: disable=attribute-error
  assert name not in cls_fields

  # Ideally, `init=False` would be used so sub-dataclasses ignore this field,
  # but this makes `.replace` fail.
  extra_field = dataclasses.field(default=default, init=True, repr=False)
  # Fill by hand the attributes `@dataclasses.dataclass` normally sets.
  extra_field.__set_name__(cls, name)
  extra_field.name = name
  extra_field.type = Any
  extra_field._field_type = dataclasses._FIELD  # pylint: disable=protected-access  # pytype: disable=module-attr
  cls_fields[name] = extra_field

  wrapped_init = cls.__init__

  @functools.wraps(wrapped_init)
  def init_with_extra_field(self, **kwargs: Any):
    # The wrapped `__init__` does not know about `name`: pop it and assign
    # it directly before delegating the remaining kwargs.
    value = kwargs.pop(name, default)
    self._setattr(name, value)  # pylint: disable=protected-access
    return wrapped_init(self, **kwargs)

  cls.__init__ = init_with_extra_field
803+
804+
724805
def _infer_xnp(xnps: dict[enp.NpModule, list[str]]) -> enp.NpModule:
725806
"""Extract the `xnp` module."""
726807
non_np_xnps = set(xnps) - {np} # jnp, tnp take precedence on `np`
@@ -913,7 +994,7 @@ def is_value_missing(self) -> bool:
913994
elif (
914995
isinstance(self.value, DataclassArray) and not self.value._array_fields # pylint: disable=protected-access
915996
):
916-
# Nested dataclass case (if all attributes are `None`, so no active
997+
# Nested dataclass case (if all attributes are `object`, so no active
917998
# array fields)
918999
return True
9191000
return False

0 commit comments

Comments
 (0)