Pallas: work around slicing by shifting pointers #20408
gautierronan asked this question in Q&A
I am trying to implement a so-called Sparse DIA matrix multiplication kernel, which stores sparse square matrices using their diagonals. Here is a typical implementation in plain JAX:
which achieves the required matmul, see e.g. this example:
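Here is a minimal plain-JAX sketch of such a DIA matmul. It is not the original snippet. It assumes a hypothetical layout `diags[i, r] = A[r, r + offsets[i]]`, where entries whose column falls outside the matrix are ignored. It also assumes that `offsets` is a tuple of Python ints and that `y` is a dense `(n, m)` array. The names `dia_matmul` and `dia_to_dense` are illustrative. Because the offsets are static, every slice bound is a compile-time constant. `dia_to_dense` builds the equivalent dense matrix so the result can be checked against an ordinary matmul.

```python
import jax
import jax.numpy as jnp


def dia_matmul(diags, offsets, y):
    """Compute A @ y for a square A stored in an (assumed) DIA layout.

    Assumed layout: diags[i, r] = A[r, r + offsets[i]] for every row r with
    0 <= r + offsets[i] < n; other entries are ignored. `offsets` is a tuple
    of Python ints, so all slice bounds below are static.
    """
    n = y.shape[0]
    out = jnp.zeros_like(y)
    for i, k in enumerate(offsets):
        if k >= 0:
            # Rows 0 .. n-k-1 read y shifted up by k rows: y[r + k].
            out = out.at[: n - k].add(diags[i, : n - k, None] * y[k:])
        else:
            # Rows -k .. n-1 read y shifted down by |k| rows: y[r + k].
            out = out.at[-k:].add(diags[i, -k:, None] * y[: n + k])
    return out


def dia_to_dense(diags, offsets, n):
    """Dense reference matrix for the same assumed layout."""
    a = jnp.zeros((n, n), diags.dtype)
    for i, k in enumerate(offsets):
        # jnp.diag(v, k) places v on the k-th diagonal of an (len(v) + |k|)-square matrix.
        v = diags[i, : n - k] if k >= 0 else diags[i, -k:]
        a = a + jnp.diag(v, k)
    return a


# Usage: compare against a dense matmul on random data.
offsets = (-2, 0, 1)
n, m = 5, 3
key_d, key_y = jax.random.split(jax.random.PRNGKey(0))
diags = jax.random.normal(key_d, (len(offsets), n))
y = jax.random.normal(key_y, (n, m))
err = jnp.max(jnp.abs(dia_matmul(diags, offsets, y) - dia_to_dense(diags, offsets, n) @ y))
print(err)
```

The Python loop over `offsets` is unrolled at trace time, so each diagonal becomes one static slice plus one `.at[...].add`.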
When going to Pallas, there is no slicing, so I am trying to work around it. The workaround I found was to shift pointers by `k` so that storage in the result array would happen at the right indices. Here is my attempt at a kernel:

However, this does not seem to work. It yields a result without erroring out, but I believe something is wrong with the initialization (is the first line, `z_ref[:] = jnp.zeros_like(y_ref)`, not properly taken into account?).

Any idea of how I can achieve this? Any help on how to progress is greatly appreciated. Thanks!
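One way to avoid both ref slicing and pointer shifting is sketched below under assumptions. It is not a fix for the original kernel. The idea is to keep the diagonals zero-padded wherever `r + k` falls outside the matrix. Then every diagonal becomes a full-length elementwise product with a row-rotated copy of `y` built by `jnp.roll`, and the rows that wrap around are multiplied by zero. The following are all assumptions:

- the layout `diags[i, r] = A[r, r + offsets[i]]`;
- static Python offsets;
- the hypothetical names `dia_kernel`, `dia_matmul_pallas` and `pad_diags`.

The sketch has only been reasoned about with `interpret=True`. Whether `jnp.roll` lowers on a compiled TPU (Mosaic) or GPU (Triton) backend is not verified here.

```python
import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def dia_kernel(diags_ref, y_ref, z_ref, *, offsets):
    # Hypothetical kernel: no grid and no BlockSpecs, so each ref holds the whole array.
    diags = diags_ref[...]   # (n_diags, n), zero where r + k is out of range
    y = y_ref[...]           # (n, m)
    acc = jnp.zeros_like(y)  # accumulate in a value, then write z_ref once
    for i, k in enumerate(offsets):  # static Python loop over static offsets
        # jnp.roll(y, -k, axis=0)[r] == y[(r + k) % n]; rows that wrap around
        # are multiplied by the zero padding in diags, so they contribute nothing.
        acc = acc + diags[i][:, None] * jnp.roll(y, -k, axis=0)
    z_ref[...] = acc


def dia_matmul_pallas(diags, y, offsets):
    return pl.pallas_call(
        functools.partial(dia_kernel, offsets=tuple(offsets)),
        out_shape=jax.ShapeDtypeStruct(y.shape, y.dtype),
        interpret=True,  # run on CPU for checking; compiled lowering not verified
    )(diags, y)


def pad_diags(diags, offsets):
    """Zero every diags[i, r] whose column r + offsets[i] lies outside [0, n)."""
    n = diags.shape[1]
    rows = jnp.arange(n)
    masks = [(rows + k >= 0) & (rows + k < n) for k in offsets]
    return diags * jnp.stack(masks).astype(diags.dtype)
```

This sketch runs as a single program without a grid, so a zero accumulator built once inside the kernel is enough. If the kernel is later tiled with a `grid` that visits the same output block more than once, the common Pallas pattern is to zero that block only on the first visit, inside a `pl.when(pl.program_id(...) == 0)` branch.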
cc @chr1sj0nes @superbobry