8181 TypeAlias ,
8282 TypeVar ,
8383 cast ,
84+ get_origin ,
8485 overload ,
8586)
8687
9495if TYPE_CHECKING :
9596 from numpy .typing import DTypeLike
9697
98+ from pymbolic .typing import Integer
99+
97100
98101# deprecated, use ScalarLike instead
99102Scalar : TypeAlias = _Scalar
@@ -237,7 +240,7 @@ def __rpow__(self, other: ArrayOrScalar | Self) -> Self: ...
237240
238241
239242ArrayT = TypeVar ("ArrayT" , bound = Array )
240- ArrayOrScalar : TypeAlias = Array | _Scalar
243+ ArrayOrScalar : TypeAlias = Array | ScalarLike
241244ArrayOrScalarT = TypeVar ("ArrayOrScalarT" , bound = ArrayOrScalar )
242245ArrayOrContainer : TypeAlias = Array | ArrayContainer
243246ArrayOrArithContainer : TypeAlias = Array | ArithArrayContainer
@@ -261,6 +264,22 @@ def __rpow__(self, other: ArrayOrScalar | Self) -> Self: ...
261264NumpyOrContainerOrScalar : TypeAlias = "np.ndarray | ArrayContainer | ScalarLike"
262265
263266
267+ def is_scalar_type (tp : object , / ) -> bool :
268+ if not isinstance (tp , type ):
269+ tp = get_origin (tp )
270+ if not isinstance (tp , type ):
271+ return False
272+ if tp is int or tp is bool :
273+ # int has loads of undesirable subclasses: enums, ...
274+ # We're not going to tolerate them.
275+ #
276+ # bool has to be OK because arraycontext is expected to handle
277+ # arrays of bools.
278+ return True
279+
280+ return issubclass (tp , (np .generic , float , complex ))
281+
282+
264283def is_scalar_like (x : object , / ) -> TypeIs [Scalar ]:
265284 return np .isscalar (x )
266285
0 commit comments