How to use jax.numpy.linalg.eigvals on GPU? #8350
Consider the following code:

```python
import jax.numpy as jnp

I = jnp.eye(4)
jnp.linalg.eigvals(I)
```

It works on CPU but fails with the following error on GPU:
The documentation doesn't provide any details about how to indicate that the argument to the … What am I doing wrong? Or is that the …
Replies: 1 comment
Indeed, `eig` and `eigvals` are not implemented on GPU, but I believe `eigh` and `eigvalsh` are. The `h` indicates that these routines are designed for Hermitian matrices, and for real-valued inputs Hermitian implies symmetric. If you are doing a symmetric eigenvalue decomposition on any backend, you should prefer `eigh` or `eigvalsh`.
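Here is a minimal sketch of both paths. It assumes a JAX install where the CPU backend is available alongside the GPU, so `jax.devices("cpu")` returns a device. Symmetric input, such as the identity matrix from the question, goes to `jnp.linalg.eigvalsh`. For non-symmetric input, the sketch commits the array to a CPU device with `jax.device_put` so that `jnp.linalg.eigvals` runs there. That CPU fallback is an assumed workaround and was not part of the reply. The helper `eigenvalues` and its `assume_symmetric` flag are hypothetical names used only for illustration.

```python
import jax
import jax.numpy as jnp


def eigenvalues(a, assume_symmetric=False):
    """Hypothetical helper: eigenvalues of the square matrix `a`."""
    if assume_symmetric:
        # eigvalsh does not verify symmetry; its result is only meaningful
        # for symmetric (real) or Hermitian (complex) input. It returns
        # real eigenvalues in ascending order and is supported on GPU.
        return jnp.linalg.eigvalsh(a)
    # Assumed workaround: the general (non-symmetric) solver is run on a CPU
    # device by committing the input there; the result stays on the CPU.
    cpu = jax.devices("cpu")[0]
    return jnp.linalg.eigvals(jax.device_put(a, cpu))


# The identity matrix from the question is symmetric, so eigvalsh applies.
print(eigenvalues(jnp.eye(4), assume_symmetric=True))

# A 2x2 rotation matrix is not symmetric; its eigenvalues are the complex pair +1j and -1j.
rotation = jnp.array([[0.0, -1.0], [1.0, 0.0]])
print(eigenvalues(rotation))
```

Routing on an explicit flag rather than testing symmetry numerically keeps the choice visible to the caller. A check such as `jnp.allclose(a, a.T.conj())` would force a device-to-host sync and needs a tolerance.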