-
You could use `jax.pure_callback`:

```python
import jax
import jax.numpy as jnp
import scipy.linalg

def schur(x):
    # Run the host-side SciPy implementation; the second argument tells
    # JAX the callback returns two arrays (T, Z) shaped and typed like x.
    return jax.pure_callback(scipy.linalg.schur, (x, x), x)

@jax.jit
def f(x):
    return schur(x)

print(f(jnp.array([[0, 2, 2], [0, 1, 2], [1, 0, 1]], jnp.float32)))
```
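If the callback's outputs didn't share the input's shape and dtype, you could spell out the result structure with `jax.ShapeDtypeStruct` instead of reusing `x`. A minimal sketch of that variant (the name `schur_explicit` is just illustrative):

```python
import jax
import scipy.linalg

def schur_explicit(x):
    # scipy.linalg.schur returns (T, Z); declare both outputs explicitly
    # as square matrices with the same shape and dtype as x.
    out_spec = (
        jax.ShapeDtypeStruct(x.shape, x.dtype),
        jax.ShapeDtypeStruct(x.shape, x.dtype),
    )
    return jax.pure_callback(scipy.linalg.schur, out_spec, x)
```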
-
It is known that when using a GPU or TPU backend, calling `jax.scipy.linalg.schur` raises `NotImplementedError: Schur decomposition is only implemented on the CPU backend`. I wonder if it is possible to bypass the issue by falling back to the CPU implementation in such cases. I tried to transfer the array back to the CPU with `jax.device_get(X)`, hoping that would trigger the LAPACK backend, but it still gives the same `NotImplementedError`.

To clarify, I only want this particular function to be executed on the CPU, since there isn't a GPU/TPU implementation. The idea is that I will do a lot of computation on a GPU/TPU, transfer the array to the CPU, perform the Schur decomposition, transfer the result back to the GPU/TPU, and continue execution there.
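For reference, a sketch of the round trip I have in mind. I'm assuming that `jax.device_put`, unlike `jax.device_get`, commits the array to the CPU device so that eager (non-jitted) ops on it run on the CPU backend; the helper name `schur_on_cpu` is just illustrative.

```python
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]  # a CPU device exists even on GPU/TPU backends

def schur_on_cpu(x, result_device):
    # Commit the array to the CPU device; jax.device_get would only
    # copy the data out as a NumPy array without changing placement.
    x_cpu = jax.device_put(x, cpu)
    t, z = jax.scipy.linalg.schur(x_cpu)  # should take the LAPACK-backed CPU path
    # Move both results back to the accelerator and continue there.
    return jax.device_put((t, z), result_device)

x = jnp.array([[0, 2, 2], [0, 1, 2], [1, 0, 1]], jnp.float32)
t, z = schur_on_cpu(x, jax.devices()[0])
```

This would only help in eager code; under `jit` the whole function is staged out for one backend, which is presumably why the `pure_callback` approach above is needed.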