Skip to content

Commit 2463bf1

Browse files
hawkinspGoogle-ML-Automation
authored andcommitted
Avoid repeatedly rebuilding a tuple in issubdtype.
PiperOrigin-RevId: 696304594
1 parent 426e13a commit 2463bf1

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

jax/_src/dtypes.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,8 @@ def _issubclass(a: Any, b: Any) -> bool:
339339
return False
340340

341341

342+
_types_for_issubdtype = (type, np.dtype, ExtendedDType)
343+
342344
# TODO(jakevdp): consider whether to disallow None here. We allow it
343345
# because np.issubdtype allows it (and treats it as equivalent to float64).
344346
def issubdtype(a: DTypeLike | ExtendedDType | None,
@@ -360,8 +362,8 @@ def issubdtype(a: DTypeLike | ExtendedDType | None,
360362
# unhashable (e.g. custom objects with a dtype attribute). The following check is
361363
# fast and covers the majority of calls to this function within JAX library code.
362364
return _issubdtype_cached(
363-
a if isinstance(a, (type, np.dtype, ExtendedDType)) else np.dtype(a), # type: ignore[arg-type]
364-
b if isinstance(b, (type, np.dtype, ExtendedDType)) else np.dtype(b), # type: ignore[arg-type]
365+
a if isinstance(a, _types_for_issubdtype) else np.dtype(a), # type: ignore[arg-type]
366+
b if isinstance(b, _types_for_issubdtype) else np.dtype(b), # type: ignore[arg-type]
365367
)
366368

367369

0 commit comments

Comments
 (0)