Skip to content

Commit c51497d

Browse files
committed
is_scalar_type: rule out int subclasses
1 parent aaf7206 commit c51497d

File tree

2 files changed

+21
-10
lines changed

2 files changed

+21
-10
lines changed

arraycontext/container/dataclass.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
ArrayContainer,
6262
ArrayOrContainer,
6363
ArrayOrContainerOrScalar,
64+
is_scalar_type,
6465
)
6566

6667

@@ -95,15 +96,6 @@ def _is_array_or_container_type(tp: type | GenericAlias | UnionType, /) -> bool:
9596
return tp is Array or is_array_container_type(tp)
9697

9798

98-
def is_scalar_type(tp: object, /) -> bool:
99-
if not isinstance(tp, type):
100-
tp = get_origin(tp)
101-
if not isinstance(tp, type):
102-
return False
103-
104-
return issubclass(tp, (np.generic, int, float, complex))
105-
106-
10799
def dataclass_array_container(cls: type[T]) -> type[T]:
108100
"""A class decorator that makes the class to which it is applied an
109101
:class:`ArrayContainer` by registering appropriate implementations of

arraycontext/typing.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@
8181
TypeAlias,
8282
TypeVar,
8383
cast,
84+
get_origin,
8485
overload,
8586
)
8687

@@ -94,6 +95,8 @@
9495
if TYPE_CHECKING:
9596
from numpy.typing import DTypeLike
9697

98+
from pymbolic.typing import Integer
99+
97100

98101
# deprecated, use ScalarLike instead
99102
Scalar: TypeAlias = _Scalar
@@ -237,7 +240,7 @@ def __rpow__(self, other: ArrayOrScalar | Self) -> Self: ...
237240

238241

239242
ArrayT = TypeVar("ArrayT", bound=Array)
240-
ArrayOrScalar: TypeAlias = Array | _Scalar
243+
ArrayOrScalar: TypeAlias = Array | ScalarLike
241244
ArrayOrScalarT = TypeVar("ArrayOrScalarT", bound=ArrayOrScalar)
242245
ArrayOrContainer: TypeAlias = Array | ArrayContainer
243246
ArrayOrArithContainer: TypeAlias = Array | ArithArrayContainer
@@ -261,6 +264,22 @@ def __rpow__(self, other: ArrayOrScalar | Self) -> Self: ...
261264
NumpyOrContainerOrScalar: TypeAlias = "np.ndarray | ArrayContainer | ScalarLike"
262265

263266

267+
def is_scalar_type(tp: object, /) -> bool:
268+
if not isinstance(tp, type):
269+
tp = get_origin(tp)
270+
if not isinstance(tp, type):
271+
return False
272+
if tp is int or tp is bool:
273+
# int has loads of undesirable subclasses: enums, ...
274+
# We're not going to tolerate them.
275+
#
276+
# bool has to be OK because arraycontext is expected to handle
277+
# arrays of bools.
278+
return True
279+
280+
return issubclass(tp, (np.generic, float, complex))
281+
282+
264283
def is_scalar_like(x: object, /) -> TypeIs[Scalar]:
265284
return np.isscalar(x)
266285

0 commit comments

Comments
 (0)