Skip to content

Commit c9afc89

Browse files
committed
Always use the same code for array avals
1 parent 05ad393 commit c9afc89

File tree

2 files changed

+5
-11
lines changed

2 files changed

+5
-11
lines changed

jax/_src/abstract_arrays.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,6 @@
4343

4444
array_types: set[type] = {np.ndarray} | numpy_scalar_types # pylint: disable=g-bare-generic
4545

46-
def canonical_concrete_aval(val, weak_type=None):
47-
weak_type = dtypes.is_weakly_typed(val) if weak_type is None else weak_type
48-
dtype = dtypes.canonicalize_dtype(np.result_type(val))
49-
dtypes.check_valid_dtype(dtype)
50-
sharding = core._get_abstract_sharding(val)
51-
return ShapedArray(np.shape(val), dtype, weak_type=weak_type, sharding=sharding)
52-
5346

5447
def masked_array_error(*args, **kwargs):
5548
raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "

jax/_src/array.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import operator as op
2323
from typing import Any, TYPE_CHECKING, cast
2424

25-
from jax._src import abstract_arrays
2625
from jax._src import api
2726
from jax._src import api_util
2827
from jax._src import basearray
@@ -1027,18 +1026,20 @@ def make_array_from_single_device_arrays(
10271026
return ArrayImpl(aval, sharding, cast(Sequence[ArrayImpl], arrays),
10281027
committed=True)
10291028

1030-
1031-
core.pytype_aval_mappings[ArrayImpl] = abstract_arrays.canonical_concrete_aval
1032-
core.xla_pytype_aval_mappings[ArrayImpl] = op.attrgetter('aval')
10331029
xla.canonicalize_dtype_handlers[ArrayImpl] = pxla.identity
1030+
10341031
def _get_aval_array(self):
10351032
if config.sharding_in_types.value and isinstance(self.sharding, NamedSharding):
10361033
return self.aval.update(sharding=NamedSharding(
10371034
self.sharding.mesh.abstract_mesh,
10381035
self.sharding.spec._normalized_spec(self.ndim)))
10391036
else:
10401037
return self.aval
1038+
10411039
api_util._shaped_abstractify_handlers[ArrayImpl] = _get_aval_array
1040+
core.pytype_aval_mappings[ArrayImpl] = _get_aval_array
1041+
core.xla_pytype_aval_mappings[ArrayImpl] = _get_aval_array
1042+
10421043
# TODO(jakevdp) replace this with true inheritance at the C++ level.
10431044
basearray.Array.register(ArrayImpl)
10441045

0 commit comments

Comments
 (0)