@@ -61,6 +61,7 @@ class DataclassParams:
6161 cast_dtype: If `True`, auto-cast inputs `dtype`
6262 cast_list: If `True`, auto-cast lists to `xnp.ndarray`
6363 """
64+
6465 # If modifying this, make sure to modify `@dataclass_array` too!
6566 broadcast : bool = False
6667 cast_dtype : bool = False
@@ -101,7 +102,8 @@ def decorator(cls):
101102 if not issubclass (cls , DataclassArray ):
102103 raise TypeError (
103104 '`@dca.dataclass_array` can only be applied on `dca.DataclassArray`. '
104- f'Got: { cls } ' )
105+ f'Got: { cls } '
106+ )
105107 cls .__dca_params__ = DataclassParams (
106108 broadcast = broadcast ,
107109 cast_dtype = cast_dtype ,
@@ -174,6 +176,7 @@ class Square(DataclassArray):
174176 field annotated with `field: np.ndarray` or similar).
175177
176178 """
179+
177180 # Child class inherit the default params by default, but can also
178181 # overwrite them.
179182 __dca_params__ : ClassVar [DataclassParams ] = DataclassParams ()
@@ -208,7 +211,8 @@ def __post_init__(self) -> None:
208211 # Make sure the dataclass was registered and frozen
209212 if not dataclasses .is_dataclass (cls ) or not cls .__dataclass_params__ .frozen : # pytype: disable=attribute-error
210213 raise ValueError (
211- '`dca.DataclassArray` need to be @dataclasses.dataclass(frozen=True)' )
214+ '`dca.DataclassArray` need to be @dataclasses.dataclass(frozen=True)'
215+ )
212216
213217 # Register the tree_map here instead of `__init_subclass__` as `jax` may
214218 # not have been registered yet during import
@@ -219,7 +223,8 @@ def __post_init__(self) -> None:
219223 if not self ._all_array_fields :
220224 raise ValueError (
221225 f'{ self .__class__ .__qualname__ } should have at least one '
222- '`dca.array_field`' )
226+ '`dca.array_field`'
227+ )
223228
224229 # Validate and normalize array fields
225230 # * Maybe cast (list, np) -> xnp
@@ -309,12 +314,10 @@ def flatten(self: _DcT) -> _DcT:
309314
310315 def broadcast_to (self : _DcT , shape : Shape ) -> _DcT :
311316 """Broadcast the batch shape."""
312- # pyformat: disable
313317 return self ._map_field (
314318 array_fn = lambda f : f .broadcast_to (shape ),
315319 dc_fn = lambda f : f .broadcast_to (shape ),
316320 )
317- # pyformat: enable
318321
319322 def __getitem__ (self : _DcT , indices : _IndicesArg ) -> _DcT :
320323 """Slice indexing."""
@@ -342,7 +345,8 @@ def __len__(self) -> int:
342345 """Length of the first array dimension."""
343346 if not self .shape :
344347 raise TypeError (
345- f'len() of unsized { self .__class__ .__name__ } (shape={ self .shape } )' )
348+ f'len() of unsized { self .__class__ .__name__ } (shape={ self .shape } )'
349+ )
346350 return self .shape [0 ]
347351
348352 def __bool__ (self ) -> Literal [True ]:
@@ -384,7 +388,8 @@ def fn(ray: Optional[dca.Ray] = None):
384388 if self .shape and not len (self ): # pylint: disable=g-explicit-length-test
385389 raise ValueError (
386390 f'The truth value of { self .__class__ .__name__ } when `len(x) == 0` '
387- 'is ambigous. Use `len(x)` or `x is not None`.' )
391+ 'is ambigous. Use `len(x)` or `x is not None`.'
392+ )
388393 return True
389394
390395 def map_field (
@@ -475,8 +480,10 @@ def _all_array_fields(self) -> dict[str, _ArrayField]:
475480 try :
476481 hints = typing_extensions .get_type_hints (cls , include_extras = True )
477482 except Exception as e : # pylint: disable=broad-except
478- msg = (f'Could not infer typing annotation of { cls .__qualname__ } '
479- f'defined in { cls .__module__ } :\n ' )
483+ msg = (
484+ f'Could not infer typing annotation of { cls .__qualname__ } '
485+ f'defined in { cls .__module__ } :\n '
486+ )
480487 lines = [f' * { k } : { v !r} ' for k , v in cls .__annotations__ .items ()]
481488 lines = '\n ' .join (lines )
482489
@@ -486,16 +493,19 @@ def _all_array_fields(self) -> dict[str, _ArrayField]:
486493 f .name : _make_field_metadata (f , hints )
487494 for f in dataclasses .fields (self )
488495 }
489- cls ._dca_fields_metadata = { # Filter `None` values # pylint: disable=protected-access
490- k : v for k , v in dca_fields_metadata .items () if v is not None
491- }
496+ cls ._dca_fields_metadata = ( # pylint: disable=protected-access
497+ { # Filter `None` values
498+ k : v for k , v in dca_fields_metadata .items () if v is not None
499+ }
500+ )
492501
493502 return { # pylint: disable=g-complex-comprehension
494503 name : _ArrayField (
495504 name = name ,
496505 host = self ,
497506 ** field_metadata .to_dict (), # pylint: disable=not-a-mapping
498- ) for name , field_metadata in cls ._dca_fields_metadata .items () # pylint: disable=protected-access
507+ )
508+ for name , field_metadata in cls ._dca_fields_metadata .items () # pylint: disable=protected-access
499509 }
500510
501511 @epy .cached_property
@@ -573,18 +583,17 @@ def _broadcast_shape_inplace(self) -> Optional[Shape]:
573583 # Currently, we restrict broadcasting to either scalar or fixed length.
574584 # This is to avoid confusion broadcasting vs vectorization rules.
575585 # This restriction could be lifted if we encounter a use-case.
576- # pyformat: disable
577586 if (
578587 final_shape is None
579588 or len (shape_lengths ) > 2
580589 or (len (shape_lengths ) == 2 and 0 not in shape_lengths )
581590 ):
582- # pyformat: enable
583591 raise ValueError (
584592 f'Conflicting batch shapes: { shape_to_names } . '
585593 f'Currently { type (self ).__qualname__ } .__init__ broadcasting is '
586594 'restricted to scalar or dim=1 . '
587- 'Please open an issue if you need more fine-grained broadcasting.' )
595+ 'Please open an issue if you need more fine-grained broadcasting.'
596+ )
588597
589598 def _broadcast_field (f : _ArrayField ) -> None :
590599 if f .host_shape == final_shape : # Already broadcasted
@@ -593,7 +602,8 @@ def _broadcast_field(f: _ArrayField) -> None:
593602 raise ValueError (
594603 f'{ type (self ).__qualname__ } has `broadcast=False`. '
595604 f'Cannot broadcast { f .name } from { f .full_shape } to { final_shape } . '
596- f'To enable broadcast, use `@dca.dataclass_array(broadcast=True)`.' )
605+ 'To enable broadcast, use `@dca.dataclass_array(broadcast=True)`.'
606+ )
597607 self ._setattr (f .name , f .broadcast_to (final_shape ))
598608
599609 self ._map_field (
@@ -670,7 +680,8 @@ def tree_unflatten(
670680 zip (
671681 metadata .array_field_names ,
672682 array_field_values ,
673- ))
683+ )
684+ )
674685 init_fields = {}
675686 non_init_fields = {}
676687 fields = {f .name : f for f in dataclasses .fields (cls )}
@@ -687,7 +698,8 @@ def tree_unflatten(
687698 raise ValueError (
688699 '`dca.DataclassArray` field with init=False should be explicitly '
689700 'specified in `__dca_non_init_fields__` for them to be '
690- 'propagated by `tree_map`.' )
701+ 'propagated by `tree_map`.'
702+ )
691703 # TODO(py310): Delete once dataclass supports `kw_only=True`
692704 for k , v in non_init_fields .items ():
693705 self ._setattr (k , v ) # pylint: disable=protected-access
@@ -703,7 +715,8 @@ def assert_same_xnp(self, x: Union[Array[...], DataclassArray]) -> None:
703715 if xnp is not self .xnp :
704716 raise ValueError (
705717 f'{ self .__class__ .__name__ } is { self .xnp .__name__ } but got input '
706- f'{ xnp .__name__ } . Please cast input first.' )
718+ f'{ xnp .__name__ } . Please cast input first.'
719+ )
707720
708721
709722def _infer_xnp (xnps : dict [enp .NpModule , list [str ]]) -> enp .NpModule :
@@ -741,20 +754,23 @@ def _to_absolute_indices(indices: _Indices, *, shape: Shape) -> _Indices:
741754 raise IndexError ("an index can only have a single ellipsis ('...')" )
742755 valid_count = _count_not_none (indices )
743756 if valid_count > len (shape ):
744- raise IndexError (f'too many indices for array. Batch shape is { shape } , but '
745- f'rank-{ valid_count } was provided.' )
757+ raise IndexError (
758+ f'too many indices for array. Batch shape is { shape } , but '
759+ f'rank-{ valid_count } was provided.'
760+ )
746761 if not ellipsis_count :
747762 return indices
748763 ellipsis_index = indices .index (Ellipsis )
749764 start_elems = indices [:ellipsis_index ]
750- end_elems = indices [ellipsis_index + 1 :]
765+ end_elems = indices [ellipsis_index + 1 :]
751766 ellipsis_replacement = [slice (None )] * (len (shape ) - valid_count )
752767 return (* start_elems , * ellipsis_replacement , * end_elems )
753768
754769
755770@dataclasses .dataclass (frozen = True )
756771class _TreeMetadata :
757772 """Metadata forwarded in ``."""
773+
758774 array_field_names : list [str ]
759775 non_array_field_kwargs : dict [str , Any ]
760776
@@ -800,6 +816,7 @@ class _ArrayFieldMetadata:
800816 dtype: Type of the array. Can be `array_types.dtypes.DType` or
801817 `dca.DataclassArray` for nested arrays.
802818 """
819+
803820 inner_shape_non_static : DynamicShape
804821 dtype : Union [array_types .dtypes .DType , type [DataclassArray ]]
805822
@@ -832,6 +849,7 @@ class _ArrayField(_ArrayFieldMetadata, Generic[DcOrArrayT]):
832849 name: Instance of the attribute
833850 host: Dataclass instance who this field is attached too
834851 """
852+
835853 name : str
836854 host : DataclassArray = dataclasses .field (repr = False )
837855
@@ -862,12 +880,13 @@ def inner_shape(self) -> Shape:
862880 """Returns the the static shape resolved for the current value."""
863881 if not self .inner_shape_non_static :
864882 return ()
865- static_shape = self .full_shape [- len (self .inner_shape_non_static ):]
883+ static_shape = self .full_shape [- len (self .inner_shape_non_static ) :]
866884
867885 def err_msg () -> ValueError :
868886 return ValueError (
869887 f'Shape do not match. Expected: { self .inner_shape_non_static } . '
870- f'Got { static_shape } ' )
888+ f'Got { static_shape } '
889+ )
871890
872891 if len (static_shape ) != len (self .inner_shape_non_static ):
873892 raise err_msg ()
@@ -889,9 +908,9 @@ def is_value_missing(self) -> bool:
889908 # In `jax/_src/api_util.py` for `flatten_axes`, jax set all values to a
890909 # dummy sentinel `object()` value.
891910 return True
892- elif (isinstance ( self . value , DataclassArray ) and
893- not self .value ._array_fields # pylint: disable=protected-access
894- ):
911+ elif (
912+ isinstance ( self . value , DataclassArray ) and not self .value ._array_fields # pylint: disable=protected-access
913+ ):
895914 # Nested dataclass case (if all attributes are `None`, so no active
896915 # array fields)
897916 return True
@@ -903,14 +922,15 @@ def host_shape(self) -> Shape:
903922 if not self .inner_shape_non_static :
904923 shape = self .full_shape
905924 else :
906- shape = self .full_shape [:- len (self .inner_shape_non_static )]
925+ shape = self .full_shape [: - len (self .inner_shape_non_static )]
907926 return shape
908927
909928 def assert_shape (self ) -> None :
910929 if self .host_shape + self .inner_shape != self .full_shape :
911930 raise ValueError (
912- f'Shape should be '
913- f'{ (py_utils .Ellipsis , * self .inner_shape )} . Got: { self .full_shape } ' )
931+ 'Shape should be '
932+ f'{ (py_utils .Ellipsis , * self .inner_shape )} . Got: { self .full_shape } '
933+ )
914934
915935 def broadcast_to (self , shape : Shape ) -> DcOrArrayT :
916936 """Broadcast the host_shape."""
0 commit comments