scipy.sparse.linalg.onenormest does not translate to idiomatic JAX #12102
Unanswered
williamberman asked this question in Q&A
Replies: 0 comments
Cross-posting from the issue for implementing scipy.sparse.linalg.onenormest.
The sparse-matrix 1-norm estimation algorithm is control-flow heavy, with many conditional variable updates and early loop exits.
To get an implementation to jit-compile, I had to resort to manual continuation passing and tricks that keep matrices at constant dimensions in the main loop body.
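To make the pattern concrete, here is a minimal sketch of how a loop with an early exit and a conditional update can be expressed with `jax.lax.while_loop`. It is not the implementation discussed in this post: it assumes a much simpler single-vector (Hager-style) estimator as a stand-in for the block algorithm, and the name `onenorm_lower_bound` and the `max_iter` parameter are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax


def onenorm_lower_bound(A, max_iter=5):
    """Simplified single-vector (Hager-style) lower bound on the 1-norm of A.

    Illustrative stand-in only; not the block algorithm behind onenormest.
    """
    n = A.shape[1]

    def cond(state):
        k, _, _, done = state
        # The loop "break" is encoded as a boolean flag in the carried state.
        return jnp.logical_and(k < max_iter, jnp.logical_not(done))

    def body(state):
        k, x, est, _ = state
        y = A @ x
        new_est = jnp.sum(jnp.abs(y))  # ||A x||_1 with ||x||_1 == 1
        xi = jnp.where(y >= 0, 1, -1).astype(A.dtype)
        z = A.T @ xi
        j = jnp.argmax(jnp.abs(z))
        converged = jnp.max(jnp.abs(z)) <= z @ x
        # Conditional update that keeps x at a fixed shape: select between
        # the old vector and the unit vector e_j instead of branching.
        e_j = jnp.zeros(n, dtype=A.dtype).at[j].set(1.0)
        x_next = jnp.where(converged, x, e_j)
        return k + 1, x_next, jnp.maximum(est, new_est), converged

    x0 = jnp.full((n,), 1.0 / n, dtype=A.dtype)
    init = (
        jnp.array(0, dtype=jnp.int32),
        x0,
        jnp.zeros((), dtype=A.dtype),
        jnp.array(False),
    )
    _, _, est, _ = lax.while_loop(cond, body, init)
    return est


A = jnp.array([[1.0, -2.0], [3.0, 4.0]])
estimate = jax.jit(onenorm_lower_bound)(A)
```

Every value that would otherwise be updated inside an `if` becomes a `jnp.where` select, and every `break` becomes a flag checked in `cond`. That is what lets the whole function trace, and it is also why the result reads less like the reference code.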
The function now jit-compiles and outperforms the SciPy implementation on GPU: roughly 8x faster for 4096x4096 matrices in some basic benchmarks.
Additionally, while the implementation isn't pretty, it is relatively straightforward to see how it maps to the original paper and existing implementations.
The downside is that the implementation is very clearly non-idiomatic JAX.
It is likely possible to rewrite parts of the algorithm in a more functional, more idiomatic JAX style. However, the more the code deviates from the original specification, the harder it becomes to implement correctly.
In summary, the algorithm does not obviously translate to idiomatic JAX, and it is non-trivial to balance a fully jitted, verifiably correct implementation against rewriting portions of it more idiomatically.
I see a few options:
1. Stay as close to the SciPy implementation as possible, with jitting as an afterthought.
2. Stay close to the SciPy implementation, with minor modifications to jit portions of the function.
3. Jit the full function while staying close to the SciPy implementation. This results in non-idiomatic JAX but is easier to verify as correct than option 4.
4. Jit the full function by refactoring the original algorithm into more idiomatic JAX.
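For contrast, option 2 might look roughly like the sketch below: the numerical kernel is jitted while the control flow stays in ordinary Python, so the loop and its `break` still read like the reference code. This again assumes a simplified single-vector (Hager-style) estimator rather than the real block algorithm, and the names `step` and `onenorm_lower_bound_partial_jit` are hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
def step(A, x):
    """One jitted iteration: returns the estimate, next index, and stop test."""
    y = A @ x
    est = jnp.sum(jnp.abs(y))
    xi = jnp.where(y >= 0, 1, -1).astype(A.dtype)
    z = A.T @ xi
    j = jnp.argmax(jnp.abs(z))
    converged = jnp.max(jnp.abs(z)) <= z @ x
    return est, j, converged


def onenorm_lower_bound_partial_jit(A, max_iter=5):
    """Python-level loop around a jitted step (illustrative stand-in only)."""
    n = A.shape[1]
    x = jnp.full((n,), 1.0 / n, dtype=A.dtype)
    best = 0.0
    for _ in range(max_iter):
        est, j, converged = step(A, x)
        best = max(best, float(est))
        # Ordinary Python branching; each check pulls a scalar back to the host.
        if bool(converged):
            break
        x = jnp.zeros(n, dtype=A.dtype).at[int(j)].set(1.0)
    return best


A = jnp.array([[1.0, -2.0], [3.0, 4.0]])
partial_estimate = onenorm_lower_bound_partial_jit(A)
```

The trade-off is that each iteration synchronizes with the host to evaluate the stopping test, which is the cost that the fully jitted options 3 and 4 avoid.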
Has anyone run into a similar issue and have thoughts on which option makes the most sense?
Octave implementation
Scipy implementation
Paper describing algorithm
My implementation
Benchmarks vs scipy