Is there a way to `jit` sparse linear solvers (`jax.scipy.sparse.linalg`) with `A` expressed as an operator?
#9804
Hey all, I'm interested in solving sparse linear systems in JAX. The docs for the JAX sparse linear solver API (shared across all solvers in
In my application, providing a function that computes a matrix-vector product

Here's a minimal failing example:

```python
import jax
import jax.numpy as np
from jax.scipy import sparse
from jax import random

key = random.PRNGKey(seed=0)
N = 3
A = random.uniform(key=key, shape=(N, N))
b = random.uniform(key=key, shape=(N,))
A_operator = lambda v: A @ v  # Just an example; for the real application we compute this more efficiently

##### This works!
x, _ = sparse.linalg.gmres(
    A=A_operator,
    b=b
)

##### But, add a `jit` and it no longer works!
try:
    x, _ = jax.jit(sparse.linalg.gmres)(
        A=A_operator,
        b=b
    )
except TypeError as e:
    print("Didn't work!")
    print(e)
```

The resulting error has to do with the input type to The relevant part of the error traceback is here:
So, the question: Is there a way to

Miscellaneous details: I verified that this occurs on both Windows 11 (manually built, CUDA backend) and Ubuntu 20.04 (on WSL, CUDA backend); both tests were run with Python 3.9.7, CUDA 11.6, JAX 0.3.1 and jaxlib 0.3.0. I can confirm that in addition to this not working with the
Besides using …, you can pass the operator wrapped in `jax.tree_util.Partial`:

```python
from jax.tree_util import Partial

x, _ = jax.jit(sparse.linalg.gmres)(
    A=Partial(A_operator),
    b=b
)
```
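Expanding on the `Partial` reply above, here is a minimal sketch that binds the matrix as an argument of the `Partial` instead of capturing it in a lambda. `Partial` is a pytree: the wrapped function is part of the static tree structure, while its bound arguments are array leaves, so the jitted `gmres` receives the matrix as a traced input. Assumptions not in the original: `matvec` is a hypothetical stand-in for the real (cheaper) operator, the keys are split, and the test matrix is shifted by `N * I` so it is well-conditioned and GMRES converges cleanly.

```python
import jax
import jax.numpy as jnp
from jax import random
from jax.scipy.sparse.linalg import gmres
from jax.tree_util import Partial

key_A, key_b = random.split(random.PRNGKey(0))
N = 3
# Assumption: diagonal shift makes the test matrix diagonally dominant.
A = random.uniform(key_A, (N, N)) + N * jnp.eye(N)
b = random.uniform(key_b, (N,))

def matvec(M, v):
    # Hypothetical matvec helper; Partial(matvec, A)(v) calls matvec(A, v).
    return M @ v

# The function `matvec` is static (tree structure); `A` is a traced leaf.
A_op = Partial(matvec, A)

gmres_jit = jax.jit(gmres)
x, _ = gmres_jit(A_op, b)

# Same function object, same leaf shape/dtype: the compiled code is reused
# even though the matrix values differ.
A2 = 2.0 * A
x2, _ = gmres_jit(Partial(matvec, A2), b)
```

With `Partial(A_operator)` as written in the reply, the matrix is closed over by the lambda and gets embedded in the compiled function as a constant; binding it through `Partial` keeps it a regular input.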
I think that's what `static_argnums` is for in the `jax.jit` API: marking certain arguments (my rule of thumb: ones that can't be coerced to JAX arrays, like strings, callables, etc.) as constant ("static") for a JIT-compiled function call. These then need to stay the same for as long as possible to avoid cache misses, which trigger full recompiles of the function.

So, in your example, if `A` is the `n`-th argument to the `jax.scipy.sparse.linalg.gmres` function, calling `jax.jit(gmres, static_argnums=(n,))` should fix this. `A` is the first positional parameter of `gmres`, so here that means `static_argnums=(0,)`.
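A sketch of the `static_argnums` route, assuming `A` is passed positionally so that it lines up with argnum 0. The second half uses a hypothetical `solve` wrapper with a Python-side trace counter to illustrate the cache behaviour described above: reusing the same callable object hits the cache, while a new lambda (even one with identical code) is a different static value and forces a retrace. As an assumption not in the original example, the keys are split and the test matrix is shifted by `N * I` so it is well-conditioned.

```python
import jax
import jax.numpy as jnp
from jax import random
from jax.scipy.sparse.linalg import gmres

key_A, key_b = random.split(random.PRNGKey(0))
N = 3
# Assumption: diagonal shift makes the test matrix diagonally dominant.
A = random.uniform(key_A, (N, N)) + N * jnp.eye(N)
b = random.uniform(key_b, (N,))
A_operator = lambda v: A @ v

# Argument 0 of gmres is the operator; marking it static stops JAX from
# trying to convert the callable into an array. Static values must be
# hashable; plain functions and lambdas hash by identity.
gmres_jit = jax.jit(gmres, static_argnums=(0,))
x, _ = gmres_jit(A_operator, b)  # pass the operator positionally

# Hypothetical helper that records each trace (Python side effects run
# only while JAX is tracing, not when cached compiled code executes).
trace_calls = []

def solve(matvec, rhs):
    trace_calls.append(1)
    return gmres(matvec, rhs)[0]

solve_jit = jax.jit(solve, static_argnums=(0,))
solve_jit(A_operator, b)       # first call: trace and compile
solve_jit(A_operator, b)       # same callable object: cache hit
solve_jit(lambda v: A @ v, b)  # new lambda object: cache miss, retrace
print(len(trace_calls))  # 2
```

In practice this means defining the operator once and reusing the same object across calls; rebuilding the lambda on every call recompiles every time.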