@@ -348,7 +348,7 @@ def reshape(self: _DcT, shape: Union[Shape, str], **axes_length: int) -> _DcT:
348348 The dataclass array with the new shape
349349 """
350350 if isinstance (shape , str ): # Einops support
351- return self ._map_field (
351+ return self ._map_field ( # pylint: disable=protected-access
352352 array_fn = lambda f : einops .rearrange ( # pylint: disable=g-long-lambda
353353 f .value ,
354354 np_utils .to_absolute_einops (shape , nlastdim = len (f .inner_shape )),
@@ -365,15 +365,15 @@ def reshape(self: _DcT, shape: Union[Shape, str], **axes_length: int) -> _DcT:
365365 def _reshape (f : _ArrayField ):
366366 return f .value .reshape (shape + f .inner_shape )
367367
368- return self ._map_field (array_fn = _reshape , dc_fn = _reshape )
368+ return self ._map_field (array_fn = _reshape , dc_fn = _reshape ) # pylint: disable=protected-access
369369
370370 def flatten (self : _DcT ) -> _DcT :
371371 """Flatten the batch shape."""
372372 return self .reshape ((- 1 ,))
373373
374374 def broadcast_to (self : _DcT , shape : Shape ) -> _DcT :
375375 """Broadcast the batch shape."""
376- return self ._map_field (
376+ return self ._map_field ( # pylint: disable=protected-access
377377 array_fn = lambda f : f .broadcast_to (shape ),
378378 dc_fn = lambda f : f .broadcast_to (shape ),
379379 )
@@ -456,7 +456,7 @@ def map_field(
456456 fn : Callable [[Array ['*din' ]], Array ['*dout' ]],
457457 ) -> _DcT :
458458 """Apply a transformation on all arrays from the fields."""
459- return self ._map_field (
459+ return self ._map_field ( # pylint: disable=protected-access
460460 array_fn = lambda f : fn (f .value ),
461461 dc_fn = lambda f : f .value .map_field (fn ),
462462 )
@@ -530,7 +530,7 @@ def _as_torch(f):
530530 array_fn = lambda f : xnp .asarray (f .value )
531531
532532 # Update all childs
533- new_self = self ._map_field (
533+ new_self = self ._map_field ( # pylint: disable=protected-access
534534 array_fn = array_fn ,
535535 dc_fn = lambda f : f .value .as_xnp (xnp ),
536536 )
@@ -593,7 +593,9 @@ def _all_fields_empty(self) -> bool:
593593 # `tf.nest` sometimes replace values by dummy `.` inside
594594 # `assert_same_structure`
595595 if enp .lazy .has_tf :
596- from tensorflow .python .util import nest_util # pytype: disable=import-error # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
596+ # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
597+ from tensorflow .python .util import nest_util # pytype: disable=import-error
598+ # pylint: enable=g-direct-tensorflow-import,g-import-not-at-top
597599
598600 if any (f .value is nest_util ._DOT for f in self ._array_fields ): # pylint: disable=protected-access,not-an-iterable
599601 return True
@@ -890,7 +892,7 @@ def _init_cls(self: DataclassArray) -> None:
890892 # DataclassArray without any array fields
891893 # Hack: To support `.xnp`, `.shape`, we add a dummy empty field which
892894 # is propagated by the various ops.
893- dca_fields_metadata [_DUMMY_ARRAY_FIELD ] = _ArrayFieldMetadata (
895+ dca_fields_metadata [_DUMMY_ARRAY_FIELD ] = _ArrayFieldMetadata ( # pytype: disable=wrong-arg-types
894896 inner_shape_non_static = (),
895897 dtype = np .float32 ,
896898 )
@@ -1006,12 +1008,12 @@ class _ArrayFieldMetadata:
10061008 Attributes:
10071009 inner_shape_non_static: Inner shape. Can contain non-static dims (e.g.
10081010 `(None, 3)`)
1009- dtype: Type of the array. Can be `array_types .dtypes.DType` or
1011+ dtype: Type of the array. Can be `enp .dtypes.DType` or
10101012 `dca.DataclassArray` for nested arrays.
10111013 """
10121014
10131015 inner_shape_non_static : DynamicShape
1014- dtype : Union [array_types .dtypes .DType , Type [DataclassArray ]]
1016+ dtype : Union [enp .dtypes .DType , Type [DataclassArray ]]
10151017
10161018 def __post_init__ (self ):
10171019 """Normalizing/validating the shape/dtype."""
@@ -1020,7 +1022,7 @@ def __post_init__(self):
10201022
10211023 # Validate/normalize the dtype
10221024 if not self .is_dataclass :
1023- self .dtype = array_types .dtypes .DType .from_value (self .dtype )
1025+ self .dtype = enp .dtypes .DType .from_value (self .dtype )
10241026 # TODO(epot): Filter invalid dtypes, like `str` ?
10251027
10261028 def to_dict (self ) -> dict [str , Any ]:
0 commit comments