Skip to content

Commit 05ad393

Browse files
Merge pull request #25534 from jakevdp:faster-avals
PiperOrigin-RevId: 707228671
2 parents 71b23ea + a2ac234 commit 05ad393

File tree

1 file changed

+4
-9
lines changed

1 file changed

+4
-9
lines changed

jax/_src/abstract_arrays.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,15 @@ def masked_array_error(*args, **kwargs):
5656
"Use arr.filled() to convert the value to a standard numpy array.")
5757

5858
core.pytype_aval_mappings[np.ma.MaskedArray] = masked_array_error
59+
core.xla_pytype_aval_mappings[np.ma.MaskedArray] = masked_array_error
5960

6061

6162
def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
6263
dtype = x.dtype
6364
dtypes.check_valid_dtype(dtype)
6465
return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype))
6566

66-
core.pytype_aval_mappings[np.ndarray] = canonical_concrete_aval
67+
core.pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array
6768
core.xla_pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array
6869

6970

@@ -73,26 +74,20 @@ def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
7374
return ShapedArray(np.shape(x), dtypes.canonicalize_dtype(dtype))
7475

7576
for t in numpy_scalar_types:
76-
core.pytype_aval_mappings[t] = canonical_concrete_aval
77+
core.pytype_aval_mappings[t] = _make_shaped_array_for_numpy_scalar
7778
core.xla_pytype_aval_mappings[t] = _make_shaped_array_for_numpy_scalar
7879

7980
core.literalable_types.update(array_types)
8081

8182

82-
def _make_concrete_python_scalar(t, x):
83-
dtype = dtypes._scalar_type_to_dtype(t, x)
84-
weak_type = dtypes.is_weakly_typed(x)
85-
return canonical_concrete_aval(np.array(x, dtype=dtype), weak_type=weak_type)
86-
87-
8883
def _make_abstract_python_scalar(typ, val):
8984
# Note: all python scalar types are weak except bool, because bool only
9085
# comes in a single width.
9186
return ShapedArray((), dtypes._scalar_type_to_dtype(typ, val),
9287
weak_type=typ is not bool)
9388

9489
for t in dtypes.python_scalar_dtypes:
95-
core.pytype_aval_mappings[t] = partial(_make_concrete_python_scalar, t)
90+
core.pytype_aval_mappings[t] = partial(_make_abstract_python_scalar, t)
9691
core.xla_pytype_aval_mappings[t] = partial(_make_abstract_python_scalar, t)
9792

9893
core.literalable_types.update(dtypes.python_scalar_dtypes.keys())

0 commit comments

Comments
 (0)