@@ -86,7 +86,6 @@ def dataclass_array(
8686
8787 ```python
8888 @dca.dataclass_array()
89- @dataclasses.dataclass(frozen=True)
9089 class MyDataclass(dca.DataclassArray):
9190 ...
9291 ```
@@ -119,6 +118,31 @@ def decorator(cls):
119118 return decorator
120119
121120
121+ def array_field (
122+ shape : Shape ,
123+ dtype : DTypeArg = float ,
124+ ** field_kwargs ,
125+ ) -> dataclasses .Field [DcOrArray ]:
126+ """Dataclass array field.
127+
128+ See `dca.DataclassArray` for example.
129+
130+ Args:
131+ shape: Inner shape of the field
132+ dtype: Type of the field
133+ **field_kwargs: Args forwarded to `dataclasses.field`
134+
135+ Returns:
136+ The dataclass field.
137+ """
138+ # TODO(epot): Validate shape, dtype
139+ dca_field = _ArrayFieldMetadata (
140+ inner_shape_non_static = shape ,
141+ dtype = dtype ,
142+ )
143+ return dataclasses .field (** field_kwargs , metadata = {_METADATA_KEY : dca_field })
144+
145+
122146class MetaDataclassArray (type ):
123147 """DataclassArray metaclass."""
124148
@@ -128,13 +152,21 @@ def __getitem__(cls, spec):
128152 return Annotated [cls , field_utils .ShapeAnnotation (spec )]
129153
130154
155+ @typing_extensions .dataclass_transform ( # pytype: disable=not-supported-yet
156+ kw_only_default = True ,
157+ # TODO(b/272524683):Restore field specifier
158+ # field_specifiers=(
159+ # dataclasses.Field,
160+ # dataclasses.field,
161+ # array_field,
162+ # ),
163+ )
131164class DataclassArray (metaclass = MetaDataclassArray ):
132165 """Dataclass which behaves like an array.
133166
134167 Usage:
135168
136169 ```python
137- @dataclasses.dataclass
138170 class Square(DataclassArray):
139171 pos: f32['*shape 2']
140172 scale: f32['*shape']
@@ -179,7 +211,6 @@ class Square(DataclassArray):
179211
180212 Field which do not satisfy any of the above conditions are static (including
181213 field annotated with `field: np.ndarray` or similar).
182-
183214 """
184215
185216 # Child class inherit the default params by default, but can also
@@ -194,8 +225,21 @@ class Square(DataclassArray):
194225 _shape : Shape
195226 _xnp : enp .NpModule
196227
197- def __init_subclass__ (cls , ** kwargs ):
228+ def __init_subclass__ (
229+ cls ,
230+ frozen = True ,
231+ ** kwargs ,
232+ ):
198233 super ().__init_subclass__ (** kwargs )
234+
235+ if not frozen :
236+ raise ValueError (f'{ cls } cannot be `frozen=False`.' )
237+
238+ # Apply dataclass (in-place)
239+ if not typing .TYPE_CHECKING :
240+ # TODO(b/227290126): Create pytype issues
241+ dataclasses .dataclass (frozen = True )(cls )
242+
199243 # TODO(epot): Could have smart __repr__ which display types if array have
200244 # too many values (maybe directly in `edc.field(repr=...)`).
201245 edc .dataclass (kw_only = True , repr = True , auto_cast = False )(cls )
@@ -212,6 +256,11 @@ def __init_subclass__(cls, **kwargs):
212256 # `__dca_non_init_fields__` (fields should be merged from `.mro()`)
213257 cls .__dca_non_init_fields__ = set (cls .__dca_non_init_fields__ )
214258
259+ if typing .TYPE_CHECKING :
260+ # TODO(b/242839979): pytype do not support PEP 681 -- Data Class Transforms
261+ def __init__ (self , ** kwargs ):
262+ pass
263+
215264 def __post_init__ (self ) -> None :
216265 """Validate and normalize inputs."""
217266 cls = type (self )
@@ -755,7 +804,6 @@ def _init_cls(self: DataclassArray) -> None:
755804
756805 This will:
757806
758- * Validate the `@dataclass(frozen=True)` is correctly applied
759807 * Extract the types annotations, detect which fields are arrays or static,
760808 and store the result in `_dca_fields_metadata`
761809 * For static `DataclassArray` (class with only static fields), it will
@@ -767,12 +815,6 @@ def _init_cls(self: DataclassArray) -> None:
767815 """
768816 cls = type (self )
769817
770- # Make sure the dataclass was registered and frozen
771- if not dataclasses .is_dataclass (cls ) or not cls .__dataclass_params__ .frozen : # pytype: disable=attribute-error
772- raise ValueError (
773- '`dca.DataclassArray` need to be @dataclasses.dataclass(frozen=True)'
774- )
775-
776818 # The first time, compute typing annotations & metadata
777819 # At this point, `ForwardRef` should have been resolved.
778820 try :
@@ -797,7 +839,7 @@ def _init_cls(self: DataclassArray) -> None:
797839 )
798840
799841 dca_fields_metadata = {
800- f .name : _make_field_metadata (f , hints ) for f in dataclasses .fields (cls )
842+ f .name : _make_field_metadata (f , hints ) for f in dataclasses .fields (cls ) # pytype: disable=wrong-arg-types
801843 }
802844 dca_fields_metadata = { # Filter `None` values (static fields)
803845 k : v for k , v in dca_fields_metadata .items () if v is not None
@@ -909,31 +951,6 @@ class _TreeMetadata:
909951 non_array_field_kwargs : dict [str , Any ]
910952
911953
912- def array_field (
913- shape : Shape ,
914- dtype : DTypeArg = float ,
915- ** field_kwargs ,
916- ) -> dataclasses .Field [DcOrArray ]:
917- """Dataclass array field.
918-
919- See `dca.DataclassArray` for example.
920-
921- Args:
922- shape: Inner shape of the field
923- dtype: Type of the field
924- **field_kwargs: Args forwarded to `dataclasses.field`
925-
926- Returns:
927- The dataclass field.
928- """
929- # TODO(epot): Validate shape, dtype
930- dca_field = _ArrayFieldMetadata (
931- inner_shape_non_static = shape ,
932- dtype = dtype ,
933- )
934- return dataclasses .field (** field_kwargs , metadata = {_METADATA_KEY : dca_field })
935-
936-
937954# TODO(epot): Should refactor `_ArrayField` in `_DataclassArrayField` and
938955# `_ArrayField` depending on whether dtype is `DataclassArray` or not.
939956# Alternativelly, maybe should create a `DcArrayDType` dtype instead.
0 commit comments