@@ -56,14 +56,15 @@ def masked_array_error(*args, **kwargs):
5656 "Use arr.filled() to convert the value to a standard numpy array." )
5757
5858core .pytype_aval_mappings [np .ma .MaskedArray ] = masked_array_error
59+ core .xla_pytype_aval_mappings [np .ma .MaskedArray ] = masked_array_error
5960
6061
6162def _make_shaped_array_for_numpy_array (x : np .ndarray ) -> ShapedArray :
6263 dtype = x .dtype
6364 dtypes .check_valid_dtype (dtype )
6465 return ShapedArray (x .shape , dtypes .canonicalize_dtype (dtype ))
6566
66- core .pytype_aval_mappings [np .ndarray ] = canonical_concrete_aval
67+ core .pytype_aval_mappings [np .ndarray ] = _make_shaped_array_for_numpy_array
6768core .xla_pytype_aval_mappings [np .ndarray ] = _make_shaped_array_for_numpy_array
6869
6970
@@ -73,26 +74,20 @@ def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
7374 return ShapedArray (np .shape (x ), dtypes .canonicalize_dtype (dtype ))
7475
7576for t in numpy_scalar_types :
76- core .pytype_aval_mappings [t ] = canonical_concrete_aval
77+ core .pytype_aval_mappings [t ] = _make_shaped_array_for_numpy_scalar
7778 core .xla_pytype_aval_mappings [t ] = _make_shaped_array_for_numpy_scalar
7879
7980core .literalable_types .update (array_types )
8081
8182
82- def _make_concrete_python_scalar (t , x ):
83- dtype = dtypes ._scalar_type_to_dtype (t , x )
84- weak_type = dtypes .is_weakly_typed (x )
85- return canonical_concrete_aval (np .array (x , dtype = dtype ), weak_type = weak_type )
86-
87-
8883def _make_abstract_python_scalar (typ , val ):
8984 # Note: all python scalar types are weak except bool, because bool only
9085 # comes in a single width.
9186 return ShapedArray ((), dtypes ._scalar_type_to_dtype (typ , val ),
9287 weak_type = typ is not bool )
9388
9489for t in dtypes .python_scalar_dtypes :
95- core .pytype_aval_mappings [t ] = partial (_make_concrete_python_scalar , t )
90+ core .pytype_aval_mappings [t ] = partial (_make_abstract_python_scalar , t )
9691 core .xla_pytype_aval_mappings [t ] = partial (_make_abstract_python_scalar , t )
9792
9893core .literalable_types .update (dtypes .python_scalar_dtypes .keys ())
0 commit comments