Skip to content

Commit f808582

Browse files
ConchylicultorThe dataclass_array Authors
authored andcommitted
Remove deprecated aliases
PiperOrigin-RevId: 550854645
1 parent c961b2c commit f808582

File tree

2 files changed

+16
-15
lines changed

2 files changed

+16
-15
lines changed

dataclass_array/array_dataclass.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,7 @@ def reshape(self: _DcT, shape: Union[Shape, str], **axes_length: int) -> _DcT:
348348
The dataclass array with the new shape
349349
"""
350350
if isinstance(shape, str): # Einops support
351-
return self._map_field(
351+
return self._map_field( # pylint: disable=protected-access
352352
array_fn=lambda f: einops.rearrange( # pylint: disable=g-long-lambda
353353
f.value,
354354
np_utils.to_absolute_einops(shape, nlastdim=len(f.inner_shape)),
@@ -365,15 +365,15 @@ def reshape(self: _DcT, shape: Union[Shape, str], **axes_length: int) -> _DcT:
365365
def _reshape(f: _ArrayField):
366366
return f.value.reshape(shape + f.inner_shape)
367367

368-
return self._map_field(array_fn=_reshape, dc_fn=_reshape)
368+
return self._map_field(array_fn=_reshape, dc_fn=_reshape) # pylint: disable=protected-access
369369

370370
def flatten(self: _DcT) -> _DcT:
371371
"""Flatten the batch shape."""
372372
return self.reshape((-1,))
373373

374374
def broadcast_to(self: _DcT, shape: Shape) -> _DcT:
375375
"""Broadcast the batch shape."""
376-
return self._map_field(
376+
return self._map_field( # pylint: disable=protected-access
377377
array_fn=lambda f: f.broadcast_to(shape),
378378
dc_fn=lambda f: f.broadcast_to(shape),
379379
)
@@ -456,7 +456,7 @@ def map_field(
456456
fn: Callable[[Array['*din']], Array['*dout']],
457457
) -> _DcT:
458458
"""Apply a transformation on all arrays from the fields."""
459-
return self._map_field(
459+
return self._map_field( # pylint: disable=protected-access
460460
array_fn=lambda f: fn(f.value),
461461
dc_fn=lambda f: f.value.map_field(fn),
462462
)
@@ -530,7 +530,7 @@ def _as_torch(f):
530530
array_fn = lambda f: xnp.asarray(f.value)
531531

532532
# Update all childs
533-
new_self = self._map_field(
533+
new_self = self._map_field( # pylint: disable=protected-access
534534
array_fn=array_fn,
535535
dc_fn=lambda f: f.value.as_xnp(xnp),
536536
)
@@ -593,7 +593,9 @@ def _all_fields_empty(self) -> bool:
593593
# `tf.nest` sometimes replace values by dummy `.` inside
594594
# `assert_same_structure`
595595
if enp.lazy.has_tf:
596-
from tensorflow.python.util import nest_util # pytype: disable=import-error # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
596+
# pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
597+
from tensorflow.python.util import nest_util # pytype: disable=import-error
598+
# pylint: enable=g-direct-tensorflow-import,g-import-not-at-top
597599

598600
if any(f.value is nest_util._DOT for f in self._array_fields): # pylint: disable=protected-access,not-an-iterable
599601
return True
@@ -890,7 +892,7 @@ def _init_cls(self: DataclassArray) -> None:
890892
# DataclassArray without any array fields
891893
# Hack: To support `.xnp`, `.shape`, we add a dummy empty field which
892894
# is propagated by the various ops.
893-
dca_fields_metadata[_DUMMY_ARRAY_FIELD] = _ArrayFieldMetadata(
895+
dca_fields_metadata[_DUMMY_ARRAY_FIELD] = _ArrayFieldMetadata( # pytype: disable=wrong-arg-types
894896
inner_shape_non_static=(),
895897
dtype=np.float32,
896898
)
@@ -1006,12 +1008,12 @@ class _ArrayFieldMetadata:
10061008
Attributes:
10071009
inner_shape_non_static: Inner shape. Can contain non-static dims (e.g.
10081010
`(None, 3)`)
1009-
dtype: Type of the array. Can be `array_types.dtypes.DType` or
1011+
dtype: Type of the array. Can be `enp.dtypes.DType` or
10101012
`dca.DataclassArray` for nested arrays.
10111013
"""
10121014

10131015
inner_shape_non_static: DynamicShape
1014-
dtype: Union[array_types.dtypes.DType, Type[DataclassArray]]
1016+
dtype: Union[enp.dtypes.DType, Type[DataclassArray]]
10151017

10161018
def __post_init__(self):
10171019
"""Normalizing/validating the shape/dtype."""
@@ -1020,7 +1022,7 @@ def __post_init__(self):
10201022

10211023
# Validate/normalize the dtype
10221024
if not self.is_dataclass:
1023-
self.dtype = array_types.dtypes.DType.from_value(self.dtype)
1025+
self.dtype = enp.dtypes.DType.from_value(self.dtype)
10241026
# TODO(epot): Filter invalid dtypes, like `str` ?
10251027

10261028
def to_dict(self) -> dict[str, Any]:

dataclass_array/utils/np_utils.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,12 @@
1919

2020
from __future__ import annotations
2121

22-
from typing import Any, Union, Optional
22+
from typing import Any, Optional, Union
2323

2424
from dataclass_array import array_dataclass
25-
from dataclass_array.typing import Axes, DcOrArrayT, DTypeArg, Shape # pylint: disable=g-multiple-import
26-
from etils import array_types
25+
from dataclass_array.typing import Axes, DTypeArg, DcOrArrayT, Shape # pylint: disable=g-multiple-import,g-importing-member
2726
from etils import enp
28-
from etils.array_types import Array, ArrayLike # pylint: disable=g-multiple-import
27+
from etils.array_types import Array, ArrayLike # pylint: disable=g-multiple-import,g-importing-member
2928

3029
# Maybe some of those could live in `enp` ?
3130

@@ -124,7 +123,7 @@ def asarray(
124123
return x.as_xnp(xnp)
125124

126125
# Handle ndarray
127-
dtype = array_types.dtypes.DType.from_value(dtype)
126+
dtype = enp.dtypes.DType.from_value(dtype)
128127
return dtype.asarray(x, xnp=xnp, casting='all' if cast_dtype else 'none')
129128

130129

0 commit comments

Comments
 (0)