XLA custom call with Eigen #9668
-
Hi team, I'm trying to implement a custom operation (a sparse solver) in JAX using Eigen. The next step in my journey is to use … Here is a repository where I am trying to implement a … My questions are: …
Thanks,
-
It might help to go through https://github.com/PhilipVinc/numba4jax or https://github.com/mpi4jax/mpi4jax, where we do exactly this kind of thing. From a quick look at your code, I see these possible problems:
- You convert `in_ptr[0]` and `in_ptr[1]` to `MatrixXd`, but are you sure that is legal? A `MatrixXd` should contain a pointer to the actual memory and the shape of the matrix. I doubt that the data XLA is passing you is laid out the way Eigen expects. Or maybe you do the unpacking correctly, but it's quite tricky.
- You also convert `out_ptr[0]` to a `MatrixXd`, which, for the same reason, I'm quite sure is wrong. `out_ptr[0]` should be a pointer to the bare memory …
I'd try to create the Eigen matrices manually with …
-
I'll just point out, for anyone looking to implement custom operations with Eigen, that matrix multiplication and sparse solver primitives with Eigen backends have now been implemented in this repository. Hopefully this code is enough to demonstrate how to link Eigen with JAX. I'm happy to write up a tutorial if that would be helpful.