Skip to content

Commit 4052fc0

Browse files
ConchylicultorThe dataclass_array Authors
authored andcommitted
Update format to black/midnight
PiperOrigin-RevId: 478719306
1 parent 4b1d843 commit 4052fc0

18 files changed

+189
-140
lines changed

dataclass_array/array_dataclass.py

Lines changed: 51 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ class DataclassParams:
6161
cast_dtype: If `True`, auto-cast inputs `dtype`
6262
cast_list: If `True`, auto-cast lists to `xnp.ndarray`
6363
"""
64+
6465
# If modifying this, make sure to modify `@dataclass_array` too!
6566
broadcast: bool = False
6667
cast_dtype: bool = False
@@ -101,7 +102,8 @@ def decorator(cls):
101102
if not issubclass(cls, DataclassArray):
102103
raise TypeError(
103104
'`@dca.dataclass_array` can only be applied on `dca.DataclassArray`. '
104-
f'Got: {cls}')
105+
f'Got: {cls}'
106+
)
105107
cls.__dca_params__ = DataclassParams(
106108
broadcast=broadcast,
107109
cast_dtype=cast_dtype,
@@ -174,6 +176,7 @@ class Square(DataclassArray):
174176
field annotated with `field: np.ndarray` or similar).
175177
176178
"""
179+
177180
# Child class inherit the default params by default, but can also
178181
# overwrite them.
179182
__dca_params__: ClassVar[DataclassParams] = DataclassParams()
@@ -208,7 +211,8 @@ def __post_init__(self) -> None:
208211
# Make sure the dataclass was registered and frozen
209212
if not dataclasses.is_dataclass(cls) or not cls.__dataclass_params__.frozen: # pytype: disable=attribute-error
210213
raise ValueError(
211-
'`dca.DataclassArray` need to be @dataclasses.dataclass(frozen=True)')
214+
'`dca.DataclassArray` need to be @dataclasses.dataclass(frozen=True)'
215+
)
212216

213217
# Register the tree_map here instead of `__init_subclass__` as `jax` may
214218
# not have been registered yet during import
@@ -219,7 +223,8 @@ def __post_init__(self) -> None:
219223
if not self._all_array_fields:
220224
raise ValueError(
221225
f'{self.__class__.__qualname__} should have at least one '
222-
'`dca.array_field`')
226+
'`dca.array_field`'
227+
)
223228

224229
# Validate and normalize array fields
225230
# * Maybe cast (list, np) -> xnp
@@ -309,12 +314,10 @@ def flatten(self: _DcT) -> _DcT:
309314

310315
def broadcast_to(self: _DcT, shape: Shape) -> _DcT:
311316
"""Broadcast the batch shape."""
312-
# pyformat: disable
313317
return self._map_field(
314318
array_fn=lambda f: f.broadcast_to(shape),
315319
dc_fn=lambda f: f.broadcast_to(shape),
316320
)
317-
# pyformat: enable
318321

319322
def __getitem__(self: _DcT, indices: _IndicesArg) -> _DcT:
320323
"""Slice indexing."""
@@ -342,7 +345,8 @@ def __len__(self) -> int:
342345
"""Length of the first array dimension."""
343346
if not self.shape:
344347
raise TypeError(
345-
f'len() of unsized {self.__class__.__name__} (shape={self.shape})')
348+
f'len() of unsized {self.__class__.__name__} (shape={self.shape})'
349+
)
346350
return self.shape[0]
347351

348352
def __bool__(self) -> Literal[True]:
@@ -384,7 +388,8 @@ def fn(ray: Optional[dca.Ray] = None):
384388
if self.shape and not len(self): # pylint: disable=g-explicit-length-test
385389
raise ValueError(
386390
f'The truth value of {self.__class__.__name__} when `len(x) == 0` '
387-
'is ambigous. Use `len(x)` or `x is not None`.')
391+
'is ambigous. Use `len(x)` or `x is not None`.'
392+
)
388393
return True
389394

390395
def map_field(
@@ -475,8 +480,10 @@ def _all_array_fields(self) -> dict[str, _ArrayField]:
475480
try:
476481
hints = typing_extensions.get_type_hints(cls, include_extras=True)
477482
except Exception as e: # pylint: disable=broad-except
478-
msg = (f'Could not infer typing annotation of {cls.__qualname__} '
479-
f'defined in {cls.__module__}:\n')
483+
msg = (
484+
f'Could not infer typing annotation of {cls.__qualname__} '
485+
f'defined in {cls.__module__}:\n'
486+
)
480487
lines = [f' * {k}: {v!r}' for k, v in cls.__annotations__.items()]
481488
lines = '\n'.join(lines)
482489

@@ -486,16 +493,19 @@ def _all_array_fields(self) -> dict[str, _ArrayField]:
486493
f.name: _make_field_metadata(f, hints)
487494
for f in dataclasses.fields(self)
488495
}
489-
cls._dca_fields_metadata = { # Filter `None` values # pylint: disable=protected-access
490-
k: v for k, v in dca_fields_metadata.items() if v is not None
491-
}
496+
cls._dca_fields_metadata = ( # pylint: disable=protected-access
497+
{ # Filter `None` values
498+
k: v for k, v in dca_fields_metadata.items() if v is not None
499+
}
500+
)
492501

493502
return { # pylint: disable=g-complex-comprehension
494503
name: _ArrayField(
495504
name=name,
496505
host=self,
497506
**field_metadata.to_dict(), # pylint: disable=not-a-mapping
498-
) for name, field_metadata in cls._dca_fields_metadata.items() # pylint: disable=protected-access
507+
)
508+
for name, field_metadata in cls._dca_fields_metadata.items() # pylint: disable=protected-access
499509
}
500510

501511
@epy.cached_property
@@ -573,18 +583,17 @@ def _broadcast_shape_inplace(self) -> Optional[Shape]:
573583
# Currently, we restrict broadcasting to either scalar or fixed length.
574584
# This is to avoid confusion broadcasting vs vectorization rules.
575585
# This restriction could be lifted if we encounter a use-case.
576-
# pyformat: disable
577586
if (
578587
final_shape is None
579588
or len(shape_lengths) > 2
580589
or (len(shape_lengths) == 2 and 0 not in shape_lengths)
581590
):
582-
# pyformat: enable
583591
raise ValueError(
584592
f'Conflicting batch shapes: {shape_to_names}. '
585593
f'Currently {type(self).__qualname__}.__init__ broadcasting is '
586594
'restricted to scalar or dim=1 . '
587-
'Please open an issue if you need more fine-grained broadcasting.')
595+
'Please open an issue if you need more fine-grained broadcasting.'
596+
)
588597

589598
def _broadcast_field(f: _ArrayField) -> None:
590599
if f.host_shape == final_shape: # Already broadcasted
@@ -593,7 +602,8 @@ def _broadcast_field(f: _ArrayField) -> None:
593602
raise ValueError(
594603
f'{type(self).__qualname__} has `broadcast=False`. '
595604
f'Cannot broadcast {f.name} from {f.full_shape} to {final_shape}. '
596-
f'To enable broadcast, use `@dca.dataclass_array(broadcast=True)`.')
605+
'To enable broadcast, use `@dca.dataclass_array(broadcast=True)`.'
606+
)
597607
self._setattr(f.name, f.broadcast_to(final_shape))
598608

599609
self._map_field(
@@ -670,7 +680,8 @@ def tree_unflatten(
670680
zip(
671681
metadata.array_field_names,
672682
array_field_values,
673-
))
683+
)
684+
)
674685
init_fields = {}
675686
non_init_fields = {}
676687
fields = {f.name: f for f in dataclasses.fields(cls)}
@@ -687,7 +698,8 @@ def tree_unflatten(
687698
raise ValueError(
688699
'`dca.DataclassArray` field with init=False should be explicitly '
689700
'specified in `__dca_non_init_fields__` for them to be '
690-
'propagated by `tree_map`.')
701+
'propagated by `tree_map`.'
702+
)
691703
# TODO(py310): Delete once dataclass supports `kw_only=True`
692704
for k, v in non_init_fields.items():
693705
self._setattr(k, v) # pylint: disable=protected-access
@@ -703,7 +715,8 @@ def assert_same_xnp(self, x: Union[Array[...], DataclassArray]) -> None:
703715
if xnp is not self.xnp:
704716
raise ValueError(
705717
f'{self.__class__.__name__} is {self.xnp.__name__} but got input '
706-
f'{xnp.__name__}. Please cast input first.')
718+
f'{xnp.__name__}. Please cast input first.'
719+
)
707720

708721

709722
def _infer_xnp(xnps: dict[enp.NpModule, list[str]]) -> enp.NpModule:
@@ -741,20 +754,23 @@ def _to_absolute_indices(indices: _Indices, *, shape: Shape) -> _Indices:
741754
raise IndexError("an index can only have a single ellipsis ('...')")
742755
valid_count = _count_not_none(indices)
743756
if valid_count > len(shape):
744-
raise IndexError(f'too many indices for array. Batch shape is {shape}, but '
745-
f'rank-{valid_count} was provided.')
757+
raise IndexError(
758+
f'too many indices for array. Batch shape is {shape}, but '
759+
f'rank-{valid_count} was provided.'
760+
)
746761
if not ellipsis_count:
747762
return indices
748763
ellipsis_index = indices.index(Ellipsis)
749764
start_elems = indices[:ellipsis_index]
750-
end_elems = indices[ellipsis_index + 1:]
765+
end_elems = indices[ellipsis_index + 1 :]
751766
ellipsis_replacement = [slice(None)] * (len(shape) - valid_count)
752767
return (*start_elems, *ellipsis_replacement, *end_elems)
753768

754769

755770
@dataclasses.dataclass(frozen=True)
756771
class _TreeMetadata:
757772
"""Metadata forwarded in ``."""
773+
758774
array_field_names: list[str]
759775
non_array_field_kwargs: dict[str, Any]
760776

@@ -800,6 +816,7 @@ class _ArrayFieldMetadata:
800816
dtype: Type of the array. Can be `array_types.dtypes.DType` or
801817
`dca.DataclassArray` for nested arrays.
802818
"""
819+
803820
inner_shape_non_static: DynamicShape
804821
dtype: Union[array_types.dtypes.DType, type[DataclassArray]]
805822

@@ -832,6 +849,7 @@ class _ArrayField(_ArrayFieldMetadata, Generic[DcOrArrayT]):
832849
name: Instance of the attribute
833850
host: Dataclass instance who this field is attached too
834851
"""
852+
835853
name: str
836854
host: DataclassArray = dataclasses.field(repr=False)
837855

@@ -862,12 +880,13 @@ def inner_shape(self) -> Shape:
862880
"""Returns the the static shape resolved for the current value."""
863881
if not self.inner_shape_non_static:
864882
return ()
865-
static_shape = self.full_shape[-len(self.inner_shape_non_static):]
883+
static_shape = self.full_shape[-len(self.inner_shape_non_static) :]
866884

867885
def err_msg() -> ValueError:
868886
return ValueError(
869887
f'Shape do not match. Expected: {self.inner_shape_non_static}. '
870-
f'Got {static_shape}')
888+
f'Got {static_shape}'
889+
)
871890

872891
if len(static_shape) != len(self.inner_shape_non_static):
873892
raise err_msg()
@@ -889,9 +908,9 @@ def is_value_missing(self) -> bool:
889908
# In `jax/_src/api_util.py` for `flatten_axes`, jax set all values to a
890909
# dummy sentinel `object()` value.
891910
return True
892-
elif (isinstance(self.value, DataclassArray) and
893-
not self.value._array_fields # pylint: disable=protected-access
894-
):
911+
elif (
912+
isinstance(self.value, DataclassArray) and not self.value._array_fields # pylint: disable=protected-access
913+
):
895914
# Nested dataclass case (if all attributes are `None`, so no active
896915
# array fields)
897916
return True
@@ -903,14 +922,15 @@ def host_shape(self) -> Shape:
903922
if not self.inner_shape_non_static:
904923
shape = self.full_shape
905924
else:
906-
shape = self.full_shape[:-len(self.inner_shape_non_static)]
925+
shape = self.full_shape[: -len(self.inner_shape_non_static)]
907926
return shape
908927

909928
def assert_shape(self) -> None:
910929
if self.host_shape + self.inner_shape != self.full_shape:
911930
raise ValueError(
912-
f'Shape should be '
913-
f'{(py_utils.Ellipsis, *self.inner_shape)}. Got: {self.full_shape}')
931+
'Shape should be '
932+
f'{(py_utils.Ellipsis, *self.inner_shape)}. Got: {self.full_shape}'
933+
)
914934

915935
def broadcast_to(self, shape: Shape) -> DcOrArrayT:
916936
"""Broadcast the host_shape."""

0 commit comments

Comments
 (0)