We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents fd4b160 + 0140a98 commit 40122f7Copy full SHA for 40122f7
jax/_src/numpy/lax_numpy.py
@@ -624,9 +624,11 @@ def isscalar(element: Any) -> bool:
624
>>> jnp.isscalar(slice(10))
625
False
626
"""
627
- if (isinstance(element, (np.ndarray, jax.Array))
628
- or hasattr(element, '__jax_array__')
629
- or np.isscalar(element)):
+ if np.isscalar(element):
+ return True
+ elif isinstance(element, (np.ndarray, jax.Array)):
630
+ return element.ndim == 0
631
+ elif hasattr(element, '__jax_array__'):
632
return asarray(element).ndim == 0
633
return False
634
0 commit comments