Skip to content

Commit 0140a98

Browse files
committed
Improve trace-time performance of jnp.isscalar
1 parent 9e5edb7 commit 0140a98

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

jax/_src/numpy/lax_numpy.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -624,9 +624,11 @@ def isscalar(element: Any) -> bool:
624624
>>> jnp.isscalar(slice(10))
625625
False
626626
"""
627-
if (isinstance(element, (np.ndarray, jax.Array))
628-
or hasattr(element, '__jax_array__')
629-
or np.isscalar(element)):
627+
if np.isscalar(element):
628+
return True
629+
elif isinstance(element, (np.ndarray, jax.Array)):
630+
return element.ndim == 0
631+
elif hasattr(element, '__jax_array__'):
630632
return asarray(element).ndim == 0
631633
return False
632634

0 commit comments

Comments
 (0)