Skip to content

Commit 40122f7

Browse files
Merge pull request jax-ml#25237 from jakevdp:faster-isscalar
PiperOrigin-RevId: 702517550
2 parents fd4b160 + 0140a98 commit 40122f7

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

jax/_src/numpy/lax_numpy.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -624,9 +624,11 @@ def isscalar(element: Any) -> bool:
624624
>>> jnp.isscalar(slice(10))
625625
False
626626
"""
627-
if (isinstance(element, (np.ndarray, jax.Array))
628-
or hasattr(element, '__jax_array__')
629-
or np.isscalar(element)):
627+
if np.isscalar(element):
628+
return True
629+
elif isinstance(element, (np.ndarray, jax.Array)):
630+
return element.ndim == 0
631+
elif hasattr(element, '__jax_array__'):
630632
return asarray(element).ndim == 0
631633
return False
632634

0 commit comments

Comments
 (0)