1818
1919import dataclasses
2020import typing
21- from typing import Any , Callable , ClassVar , Generic , Iterator , Optional , Tuple , Type , TypeVar , Union
21+ from typing import Any , Callable , ClassVar , Generic , Iterator , Optional , Set , Tuple , Type , TypeVar , Union
2222
2323from dataclass_array import field_utils
2424from dataclass_array import shape_parsing
3636import typing_extensions
3737from typing_extensions import Annotated , Literal , TypeAlias # pylint: disable=g-multiple-import
3838
39-
4039lazy = enp .lazy
4140
4241# TODO(pytype): Should use `dca.typing.DcT` but bound does not work across
@@ -133,6 +132,11 @@ class Square(DataclassArray):
133132 # overwrite them.
134133 __dca_params__ : ClassVar [DataclassParams ] = DataclassParams ()
135134
135+ # TODO(epot): Could be removed with py3.10 and using `kw_only=True`
136+ # Fields defined here will be forwarded with `.replace`
137+ # TODO(py39): Replace Set -> set
138+ __dca_non_init_fields__ : ClassVar [Set [str ]] = set ()
139+
136140 _shape : Shape
137141 _xnp : enp .NpModule
138142
@@ -148,6 +152,9 @@ def __init_subclass__(cls, **kwargs):
148152 # convertions, we cache the type annotations here.
149153 cls ._dca_fields_metadata : Optional [dict [str , _ArrayFieldMetadata ]] = None
150154
155+ # Normalize the `cls.__dca_non_init_fields__`
156+ cls .__dca_non_init_fields__ = set (cls .__dca_non_init_fields__ )
157+
151158 def __post_init__ (self ) -> None :
152159 """Validate and normalize inputs."""
153160 cls = type (self )
@@ -346,8 +353,32 @@ def map_field(
346353
347354 # ====== Dataclass/Conversion utils ======
348355
349- # TODO(pytype): Could be removed once there's a way of annotating this.
350- replace = edc .dataclass_utils .replace
356+ def replace (self : _DcT , ** kwargs : Any ) -> _DcT :
357+ """Alias for `dataclasses.replace`."""
358+ init_kwargs = {
359+ k : v for k , v in kwargs .items () if k not in self .__dca_non_init_fields__
360+ }
361+ non_init_kwargs = {
362+ k : v for k , v in kwargs .items () if k in self .__dca_non_init_fields__
363+ }
364+
365+ # Create the new object
366+ new_self = dataclasses .replace (self , ** init_kwargs )
367+
368+ # Additionally forward the non-init kwargs
369+ # `dataclasses.field(init=False) kwargs are required because `init=True`
370+ # creates conflicts:
371+ # * Inheritance fails with non-default argument 'K' follows default argument
372+ # * Pytype complains too
373+ # TODO(py310): Cleanup using `dataclass(kw_only)`
374+ assert new_self is not self
375+ for k in self .__dca_non_init_fields__ :
376+ if k in non_init_kwargs :
377+ v = non_init_kwargs [k ]
378+ else :
379+ v = getattr (self , k )
380+ new_self ._setattr (k , v ) # pylint: disable=protected-access
381+ return new_self
351382
352383 def as_np (self : _DcT ) -> _DcT :
353384 """Returns the instance as containing `np.ndarray`."""
@@ -398,10 +429,12 @@ def _all_array_fields(self) -> dict[str, _ArrayField]:
398429 try :
399430 hints = typing_extensions .get_type_hints (cls , include_extras = True )
400431 except Exception as e : # pylint: disable=broad-except
401- epy .reraise (
402- e ,
403- prefix = f'Could not infer typing annotation of { cls .__name__ } '
404- f'defined in { cls .__module__ } ' )
432+ msg = (f'Could not infer typing annotation of { cls .__qualname__ } '
433+ f'defined in { cls .__module__ } :\n ' )
434+ lines = [f' * { k } : { v !r} ' for k , v in cls .__annotations__ .items ()]
435+ lines = '\n ' .join (lines )
436+
437+ epy .reraise (e , prefix = msg + lines + '\n ' )
405438
406439 dca_fields_metadata = {
407440 f .name : _make_field_metadata (f , hints )
@@ -603,11 +636,14 @@ def tree_unflatten(
603636 self = cls (** array_field_kwargs , ** init_fields )
604637 # Currently it's not clear how to handle non-init fields so raise an error
605638 if non_init_fields :
606- if set (non_init_fields ) != { 'fig_config' } :
639+ if set (non_init_fields ) - self . __dca_non_init_fields__ :
607640 raise ValueError (
608- '`dca.DataclassArray` with init=False field not supported yet.' )
641+ '`dca.DataclassArray` field with init=False should be explicitly '
642+ 'specified in `__dca_non_init_fields__` for them to be '
643+ 'propagated by `tree_map`.' )
609644 # TODO(py310): Delete once dataclass supports `kw_only=True`
610- self ._setattr ('fig_config' , non_init_fields ['fig_config' ]) # pylint: disable=protected-access
645+ for k , v in non_init_fields .items ():
646+ self ._setattr (k , v ) # pylint: disable=protected-access
611647 return self
612648
613649 def _setattr (self , name : str , value : Any ) -> None :
0 commit comments