1717from __future__ import annotations
1818
1919import dataclasses
20+ import functools
2021import typing
2122from typing import Any , Callable , ClassVar , Generic , Iterator , Optional , Set , Tuple , Type , TypeVar , Union
2223
4849_IndicesArg = Union [_IndiceItem , _Indices ]
4950
5051_METADATA_KEY = 'dca_field'
52+ _DUMMY_ARRAY_FIELD = '_dca_dummy_array'
5153
5254
5355@dataclasses .dataclass (frozen = True )
@@ -192,7 +194,7 @@ class Square(DataclassArray):
192194 def __init_subclass__ (cls , ** kwargs ):
193195 super ().__init_subclass__ (** kwargs )
194196 # TODO(epot): Could have smart __repr__ which display types if array have
195- # too many values.
197+ # too many values (maybe directly in `edc.field(repr=...)`) .
196198 edc .dataclass (kw_only = True , repr = True )(cls )
197199 cls ._dca_tree_map_registered = False
198200 # Typing annotations have to be lazily evaluated (to support
@@ -210,24 +212,16 @@ def __post_init__(self) -> None:
210212 """Validate and normalize inputs."""
211213 cls = type (self )
212214
213- # Make sure the dataclass was registered and frozen
214- if not dataclasses .is_dataclass (cls ) or not cls .__dataclass_params__ .frozen : # pytype: disable=attribute-error
215- raise ValueError (
216- '`dca.DataclassArray` need to be @dataclasses.dataclass(frozen=True)'
217- )
215+ # First time, we perform additional check & updates
216+ if cls ._dca_fields_metadata is None : # pylint: disable=protected-access
217+ _init_cls (self )
218218
219219 # Register the tree_map here instead of `__init_subclass__` as `jax` may
220220 # not have been registered yet during import
221221 if enp .lazy .has_jax and not cls ._dca_tree_map_registered : # pylint: disable=protected-access
222222 enp .lazy .jax .tree_util .register_pytree_node_class (cls )
223223 cls ._dca_tree_map_registered = True # pylint: disable=protected-access
224224
225- if not self ._all_array_fields :
226- raise ValueError (
227- f'{ self .__class__ .__qualname__ } should have at least one '
228- '`dca.array_field`'
229- )
230-
231225 # Validate and normalize array fields
232226 # * Maybe cast (list, np) -> xnp
233227 # * Maybe cast dtype
@@ -246,8 +240,9 @@ def __post_init__(self) -> None:
246240
247241 if xnp is None : # No values
248242 # Inside `jax.tree_utils`, tree-def can be created with `None` values.
243+ # Inside `jax.vmap`, tree can be created with `object()` sentinel values.
249244 assert shape is None
250- xnp = np
245+ xnp = None
251246
252247 # Cache results
253248 # Should the state be stored in a separate object to avoid collisions ?
@@ -418,6 +413,8 @@ def replace(self: _DcT, **kwargs: Any) -> _DcT:
418413 # Create the new object
419414 new_self = dataclasses .replace (self , ** init_kwargs )
420415
416+ # TODO(epot): Could try to unify logic below with `tree_unflatten`
417+
421418 # Additionally forward the non-init kwargs
422419 # `dataclasses.field(init=False) kwargs are required because `init=True`
423420 # creates conflicts:
@@ -449,7 +446,13 @@ def as_xnp(self: _DcT, xnp: enp.NpModule) -> _DcT:
449446 """Returns the instance as containing `xnp.ndarray`."""
450447 if xnp is self .xnp : # No-op
451448 return self
452- return self .map_field (xnp .asarray )
449+
450+ # Update all children
451+ new_self = self ._map_field (
452+ array_fn = lambda f : xnp .asarray (f .value ),
453+ dc_fn = lambda f : f .value .as_xnp (xnp ),
454+ )
455+ return new_self
453456
454457 # ====== Internal ======
455458
@@ -474,53 +477,26 @@ def xnp(self) -> enp.NpModule:
@epy.cached_property
def _all_array_fields(self) -> dict[str, _ArrayField]:
  """Map every array field name to its `_ArrayField` (`None` values included).

  The per-field metadata is read from `self._dca_fields_metadata`; each entry
  is expanded with `to_dict()` and bound to this instance through `host=self`.
  """
  fields_by_name = {}
  for field_name, metadata in self._dca_fields_metadata.items():  # pylint: disable=protected-access
    fields_by_name[field_name] = _ArrayField(
        name=field_name,
        host=self,
        **metadata.to_dict(),  # pylint: disable=not-a-mapping
    )
  return fields_by_name
512488
@epy.cached_property
def _array_fields(self) -> list[_ArrayField]:
  """Active array fields, i.e. those whose value is not missing.

  Static fields are included as long as they hold a value.
  """
  active_fields = []
  for array_field in self._all_array_fields.values():
    # Skip the fields without value (e.g. `None`)
    if array_field.is_value_missing:
      continue
    active_fields.append(array_field)
  return active_fields
520496
521497 def _cast_xnp_dtype_inplace (self ) -> Optional [enp .NpModule ]:
522498 """Validate `xnp` are consistent and cast `np` -> `xnp` in-place."""
523- if not self ._array_fields : # No fields have been defined.
499+ if not self ._array_fields : # All fields are `None` / `object`
524500 return None
525501
526502 # Validate the dtype
@@ -566,6 +542,8 @@ def _cast_field(f: _ArrayField) -> None:
566542 def _broadcast_shape_inplace (self ) -> Optional [Shape ]:
567543 """Validate the shapes are consistent and broadcast values in-place."""
568544 if not self ._array_fields : # No fields have been defined.
545+ # This can happen internally in jax, which applies some
546+ # `tree_map(lambda x: sentinel)`.
569547 return None
570548
571549 # First collect all shapes and compute the final shape.
@@ -721,6 +699,109 @@ def assert_same_xnp(self, x: Union[Array[...], DataclassArray]) -> None:
721699 )
722700
723701
702+ def _init_cls (self : DataclassArray ) -> None :
703+ """Setup the class the first time the instance is called.
704+
705+ This will:
706+
707+ * Validate the `@dataclass(frozen=True)` is correctly applied
708+ * Extract the types annotations, detect which fields are arrays or static,
709+ and store the result in `_dca_fields_metadata`
710+ * For static `DataclassArray` (class with only static fields), it will
711+ add a dummy array field for compatibility with `.xnp`/`.shape` (so
712+ methods works correctly and return the right shape/xnp when nested)
713+
714+ Args:
715+ self: The dataclass to initialize
716+ """
717+ cls = type (self )
718+
719+ # Make sure the dataclass was registered and frozen
720+ if not dataclasses .is_dataclass (cls ) or not cls .__dataclass_params__ .frozen : # pytype: disable=attribute-error
721+ raise ValueError (
722+ '`dca.DataclassArray` need to be @dataclasses.dataclass(frozen=True)'
723+ )
724+
725+ # The first time, compute typing annotations & metadata
726+ # At this point, `ForwardRef` should have been resolved.
727+ try :
728+ hints = typing_extensions .get_type_hints (cls , include_extras = True )
729+ except Exception as e : # pylint: disable=broad-except
730+ msg = (
731+ f'Could not infer typing annotation of { cls .__qualname__ } '
732+ f'defined in { cls .__module__ } :\n '
733+ )
734+ lines = [f' * { k } : { v !r} ' for k , v in cls .__annotations__ .items ()]
735+ lines = '\n ' .join (lines )
736+
737+ epy .reraise (e , prefix = msg + lines + '\n ' )
738+
739+ # TODO(epot): Remove restriction once pytype supports `datclass_transform`
740+ # and `dca` automatically apply the `@dataclasses.dataclass`
741+ if _DUMMY_ARRAY_FIELD in cls .__dataclass_fields__ : # pytype: disable=attribute-error
742+ raise NotImplementedError (
743+ 'Suclassing of DataclassArray with no array field is not supported '
744+ 'after an instance of the class was created. Error raised for '
745+ f'{ cls .__qualname__ } '
746+ )
747+
748+ dca_fields_metadata = {
749+ f .name : _make_field_metadata (f , hints ) for f in dataclasses .fields (cls )
750+ }
751+ dca_fields_metadata = { # Filter `None` values (static fields)
752+ k : v for k , v in dca_fields_metadata .items () if v is not None
753+ }
754+ if not dca_fields_metadata :
755+ # DataclassArray without any array fields
756+ # Hack: To support `.xnp`, `.shape`, we add a dummy empty field which
757+ # is propagated by the various ops.
758+ dca_fields_metadata [_DUMMY_ARRAY_FIELD ] = _ArrayFieldMetadata (
759+ inner_shape_non_static = (),
760+ dtype = np .float32 ,
761+ )
762+ default_dummy_array = np .zeros ((), dtype = np .float32 )
763+ _add_field_to_dataclass (
764+ cls , _DUMMY_ARRAY_FIELD , default = default_dummy_array
765+ )
766+ # Because we're in `__init__`, so also update the current call
767+ self ._setattr (_DUMMY_ARRAY_FIELD , default_dummy_array ) # pylint: disable=protected-access
768+
769+ cls ._dca_fields_metadata = ( # pylint: disable=protected-access
770+ dca_fields_metadata
771+ )
772+
773+
774+ def _add_field_to_dataclass (cls , name : str , default : Any ) -> None :
775+ """Add a new field to the given dataclass."""
776+ # Make sure to not update the parent class
777+ # Otherwise we could even accidentally update `dca.DataclassArray`
778+ if '__dataclass_fields__' not in cls .__dict__ :
779+ # TODO(epot): Remove the limitation once `dataclasses.dataclass` is
780+ # automatically applied
781+ raise ValueError (
782+ f'{ cls .__name__ } is not a `@dataclasses.dataclass(frozen=True)`'
783+ )
784+ assert name not in cls .__dataclass_fields__ # pytype: disable=attribute-error
785+
786+ # Ideally, we want init=False, so sub-dataclass ignore this field
787+ # but this makes `.replace` fail
788+ field = dataclasses .field (default = default , init = True , repr = False )
789+ field .__set_name__ (cls , name )
790+ field .name = name
791+ field .type = Any
792+ field ._field_type = dataclasses ._FIELD # pylint: disable=protected-access # pytype: disable=module-attr
793+ cls .__dataclass_fields__ [name ] = field # pytype: disable=attribute-error
794+
795+ original_init = cls .__init__
796+
797+ @functools .wraps (original_init )
798+ def new_init (self , ** kwargs : Any ):
799+ self ._setattr (name , kwargs .pop (name , default )) # pylint: disable=protected-access
800+ return original_init (self , ** kwargs )
801+
802+ cls .__init__ = new_init
803+
804+
724805def _infer_xnp (xnps : dict [enp .NpModule , list [str ]]) -> enp .NpModule :
725806 """Extract the `xnp` module."""
726807 non_np_xnps = set (xnps ) - {np } # jnp, tnp take precedence on `np`
@@ -913,7 +994,7 @@ def is_value_missing(self) -> bool:
913994 elif (
914995 isinstance (self .value , DataclassArray ) and not self .value ._array_fields # pylint: disable=protected-access
915996 ):
916- # Nested dataclass case (if all attributes are `None `, so no active
997+ # Nested dataclass case (if all attributes are `object `, so no active
917998 # array fields)
918999 return True
9191000 return False
0 commit comments