Efficient Lax Triangular Solve For A Large Matrix #6988
adam-hartshorne asked this question in Q&A
Replies: 1 comment 2 replies
-
If I understand your question correctly, I think you can do what you want with `vmap`:

```python
import jax.numpy as jnp
from jax import random, vmap
from jax.scipy.linalg import solve_triangular

N = 10
M = 100
D = 5

key = random.PRNGKey(1701)
key1, key2 = random.split(key)
# solve_triangular defaults to lower=False, so L is built upper-triangular
L = jnp.triu(random.uniform(key1, (N, N)))
b = random.uniform(key2, (M, N, D))

# slow version: call solve_triangular once per entry in b & stack the result
out_slow = jnp.stack([
    solve_triangular(L, b[i]) for i in range(M)
])

# fast version: vmap with appropriate in_axes
# (None, 0): L is shared (not mapped), b is mapped over its leading M axis
out = vmap(solve_triangular, in_axes=(None, 0))(L, b)

print(jnp.allclose(out, out_slow))
# True
```
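An alternative that avoids `vmap` altogether is sketched below, assuming the goal is to solve `L @ x = b[i]` for every `i`. A triangular solve treats each column of the right-hand side independently, so the M axis can be folded into the columns and handled by a single `solve_triangular` call on an `[N, M*D]` array. The helper name `solve_folded` is hypothetical, and `N * I` is added to `L` only to keep the example well conditioned.

```python
import jax.numpy as jnp
from jax import random
from jax.scipy.linalg import solve_triangular

N, M, D = 10, 100, 5
key1, key2 = random.split(random.PRNGKey(1701))
# Shared upper-triangular matrix; N * I keeps the diagonal away from zero (example only).
L = jnp.triu(random.uniform(key1, (N, N))) + N * jnp.eye(N)
b = random.uniform(key2, (M, N, D))

def solve_folded(L, b):
    """Hypothetical helper: solve L @ x[i] = b[i] for all i with one call."""
    m, n, d = b.shape
    # (M, N, D) -> (N, M, D) -> (N, M*D): column i*D + k holds b[i, :, k].
    rhs = jnp.moveaxis(b, 0, 1).reshape(n, m * d)
    x = solve_triangular(L, rhs)  # one solve with M*D right-hand-side columns
    # Undo the folding: (N, M*D) -> (N, M, D) -> (M, N, D).
    return jnp.moveaxis(x.reshape(n, m, d), 1, 0)

out_folded = solve_folded(L, b)
out_loop = jnp.stack([solve_triangular(L, b[i]) for i in range(M)])

print(out_folded.shape)  # (100, 10, 5)
print(jnp.allclose(out_folded, out_loop, atol=1e-5))
```

`L` is never tiled here; the cost moves to one transpose-and-reshape of `b`, which does produce a rearranged copy of `b` (not of `L`).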
-
I am looking for an efficient way to calculate `jax.lax.linalg.triangular_solve(L, b)`, where `L` has shape [N, N] and `b` has shape [M, N, D]. M is large, so I don't want to copy `L` into an [M, N, N] array using a repeat.
Any help much appreciated.
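For reference, a minimal sketch of the lax-level call named in the question, batched with `vmap` in the same way as the answer. It assumes the intended system is `L @ x = b[i]`; since `triangular_solve` defaults to `left_side=False` (which solves `x @ a = b`), `left_side=True` is passed explicitly. The names `A` and `solve_one` are illustrative, and `N * I` is added only for conditioning.

```python
import jax
import jax.numpy as jnp
from jax import random
from jax.lax import linalg as lax_linalg

N, M, D = 10, 100, 5
key1, key2 = random.split(random.PRNGKey(0))
# Shared upper-triangular matrix (illustrative name A); N * I keeps it well conditioned.
A = jnp.triu(random.uniform(key1, (N, N))) + N * jnp.eye(N)
b = random.uniform(key2, (M, N, D))

def solve_one(a, rhs):
    # left_side=True solves a @ x = rhs; lower=False reads only the upper triangle of a.
    return lax_linalg.triangular_solve(a, rhs, left_side=True, lower=False)

# in_axes=(None, 0): A is unmapped, so user code never builds an (M, N, N) copy of it.
x = jax.vmap(solve_one, in_axes=(None, 0))(A, b)

# Check A @ x[m] == b[m] for every m.
residual = jnp.einsum("ij,mjk->mik", A, x) - b
print(x.shape)  # (100, 10, 5)
print(float(jnp.max(jnp.abs(residual))))
```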