How to expose C++ class methods as jax primitives #15583
Hi!
The goal
I have been following these two tutorials [1] [2] for exposing my own C++ class methods to JAX. The ultimate goal is to create an operator that can be used inside numpyro's NUTS / HMC samplers. A lot of development has already been done on the C++ side, so I just want to wrap some of the C++ objects, expose them to Python through pybind, and then create a JAX primitive out of them.
The problem
I have run into a problem: I need to pass a pointer to a class object from the Python side to the C++ side, so that my XLA `custom_call` can call the corresponding class method on the correct instance of the class. The code for this custom call looks like this:

The reason I need to pass the class object (a pointer of type `pyhelperWrapper`) is that some of its member variables are already properly initialized on the Python side, and I just want to hand that same instance to C++ so that the line `*result = phW->compute_like(*pyshat);` is evaluated correctly.

So I am wondering: is it even possible to pass pointers to class objects through a JAX primitive's arguments?
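A minimal sketch of one way this can work, under stated assumptions: the object's address travels through JAX as an ordinary unsigned 64-bit integer operand, and the custom call reinterprets that integer as a pointer. To stay runnable without a compiler, it imitates the C++ side with Python's `ctypes`; `HelperState`, `compute_like` and `custom_call` are hypothetical stand-ins, not the code from this post. The `(out, ins)` convention mirrors XLA's legacy CPU custom-call signature `void f(void* out, const void** in)`. In real C++ the "reinterpret" step would be a `reinterpret_cast<pyhelperWrapper*>(addr)` inside a static (or free) function, which is also compatible with the `static` constraint described in the P.S.

```python
import ctypes


class HelperState(ctypes.Structure):
    """Hypothetical stand-in for a pyhelperWrapper instance initialized in Python."""
    _fields_ = [("scale", ctypes.c_double)]


def compute_like(obj, shat):
    """Hypothetical stand-in for pyhelperWrapper::compute_like."""
    return obj.scale * shat


def custom_call(out_addr, ins):
    """Imitates XLA's legacy CPU custom-call signature void f(void* out, const void** in)."""
    # Operand 0 holds the object's address as an unsigned 64-bit integer.
    addr = ctypes.c_uint64.from_address(ins[0]).value
    # Reinterpret the integer as the object (C++: reinterpret_cast<pyhelperWrapper*>(addr)).
    obj = HelperState.from_address(addr)
    # Operand 1 holds the scalar input (the *pyshat of the original line).
    shat = ctypes.c_double.from_address(ins[1]).value
    # Write the result into the output buffer.
    ctypes.c_double.from_address(out_addr).value = compute_like(obj, shat)


# Python side: create and initialize the object once...
state = HelperState(scale=2.0)
# ...then pass only its address, as a uint64, alongside the real inputs.
addr_operand = ctypes.c_uint64(ctypes.addressof(state))
shat_operand = ctypes.c_double(3.5)
out = ctypes.c_double(0.0)
ins = (ctypes.c_void_p * 2)(ctypes.addressof(addr_operand), ctypes.addressof(shat_operand))

custom_call(ctypes.addressof(out), ins)
print(out.value)  # 7.0
```

Because the custom call only ever receives a number, the Python side has to keep the object alive for as long as any compiled function that captured its address may run; JAX does not track that lifetime for you.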
P.S. Note that I needed to make the function `static` here, because a pointer to it is passed to the `pybind::capsule`, which raises an error for a non-static class member function, since such a function implicitly takes a `this` pointer as well.

More details on the problem
For some more context: when creating the JAX primitive on the Python side, the first step, as far as I understand, is to tell JAX how to bind the inputs, i.e.:

Here `pyleftyobj` is the pointer to the instantiated object. Of course, JAX doesn't know how to handle this class object pointer and hence gives the following error:

Here is also an example of how I am testing my implementation so far, and which line causes the above error (see `#FIXME`):

Follow up problem
Another problem that will probably arise is how to handle the MLIR lowering step when calling `jaxlib.hlo_helpers.custom_call`, since as far as I understand I would need to do something like:

Here I am pretty sure that `sys.getsizeof(pyleftyobj)` is not the correct thing to pass, given how `hlo.CustomCallOp` is structured; I just guessed this would be the structure of the call. However, I first need to solve the problem above of passing a class object pointer before dealing with this one.

Possible solutions I am not satisfied with
`XLACPU_custom_call` is called on the C++ side. In other words, pass these arguments to `pylefty_ll`. But this doesn't seem like a practical solution.
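On the lowering question from the follow-up problem: `sys.getsizeof(pyleftyobj)` returns the size in bytes of the Python wrapper object, not the address of the underlying C++ instance, so it cannot identify that instance on the C++ side. The sketch below, again a `ctypes` stand-in with a hypothetical `HelperState` class, contrasts the two and shows the address encoded as 8 bytes in native byte order, the form it would take as a uint64 constant operand or inside an opaque byte string. It also illustrates a pitfall worth checking: with `jax_enable_x64` off (the default), JAX canonicalizes 64-bit integer dtypes to 32-bit ones, which would keep only the low 32 bits of the address. One pattern that avoids pushing the pointer through abstract evaluation, offered as an assumption rather than a verified recipe, is to pass the address as a keyword parameter to `bind` and only turn it into a constant operand inside the lowering rule. For a pybind object, the integer address would have to be exposed from C++, for example by a hypothetical method returning `reinterpret_cast<std::uintptr_t>(this)`.

```python
import ctypes
import struct
import sys


class HelperState(ctypes.Structure):
    """Hypothetical stand-in for the C++ object behind pyleftyobj."""
    _fields_ = [("scale", ctypes.c_double)]


state = HelperState(scale=2.0)

# Size of the Python wrapper object in bytes: not an address, useless to C++.
wrapper_size = sys.getsizeof(state)

# The actual address of the underlying C data, as a Python int.
addr = ctypes.addressof(state)

# Pointer width of this build (8 bytes on a 64-bit interpreter).
ptr_size = struct.calcsize("P")

# The address packed as a uint64 in native byte order (8 bytes), and decoded back.
payload = struct.pack("=Q", addr)
(decoded,) = struct.unpack("=Q", payload)

# What survives if the value is squeezed into uint32: ctypes does no overflow
# checking, so only the low 32 bits remain (mimicking x64-disabled JAX).
as_uint32 = ctypes.c_uint32(addr).value

print(wrapper_size, hex(addr), ptr_size, decoded == addr, as_uint32 == addr)
```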