Hi all, I want to see whether a use-case like the following is possible: I have a function defined in JAX (Python), along with its gradient computed through autodiff. I want to implement a fast operation involving this function (integration is one example) that needs to use the JAX function and its gradient as black boxes.
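For reference, here is a minimal sketch of such an operation written entirely in JAX, before any custom op is involved. It assumes "integration" means integrating the gradient-flow ODE dx/dt = -∇f(x) with explicit Euler steps; the objective `f` and the helper `integrate_gradient_flow` are hypothetical stand-ins, not the actual function from this question.

```python
from functools import partial

import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical stand-in for the user's JAX function.
    return 0.5 * jnp.sum(x ** 2)


# value_and_grad returns (f(x), grad f(x)) from a single call.
f_and_grad = jax.value_and_grad(f)


@partial(jax.jit, static_argnames="num_steps")
def integrate_gradient_flow(x0, dt, num_steps):
    # Explicit Euler integration of dx/dt = -grad f(x), treating f and its
    # gradient as a black box. Also records f(x) at the start of each step.
    def step(x, _):
        value, g = f_and_grad(x)
        return x - dt * g, value

    x_final, values = jax.lax.scan(step, x0, xs=None, length=num_steps)
    return x_final, values


x_final, values = integrate_gradient_flow(jnp.array([1.0, 2.0]), 0.1, 10)
```

Because `jax.jit` traces the whole loop, XLA compiles `f`, its gradient, and the integrator together for the GPU without a hand-written CUDA kernel; whether that is fast enough compared with a custom op would have to be benchmarked.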
I would like this operation to be as fast as possible, so I would like to implement it as a custom op in CUDA/C++, but I would still need to call the JAX function from within that op.
Is this possible at all? I am willing to do some digging and plumbing to get it to work if needed, but could someone point me in the right direction?
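One possible starting point, offered as an assumption rather than a confirmed answer: JAX can lower a jitted function (including its autodiff gradient) to a StableHLO/XLA program that exists independently of the Python interpreter. The sketch below only shows the Python side; `f` is a hypothetical example, and loading the resulting program from C++ (for example through XLA/PJRT) is not shown and is not verified here.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical example function; its gradient comes from autodiff.
    return jnp.sin(x) * x


df = jax.grad(f)

# Lower for a scalar float32 input without needing a concrete value.
lowered = jax.jit(df).lower(jax.ShapeDtypeStruct((), jnp.float32))

# Textual StableHLO (MLIR) module describing the gradient computation.
hlo_text = lowered.as_text()

# The same lowered program can also be compiled and called from Python.
compiled = lowered.compile()
print(compiled(jnp.float32(1.0)))  # d/dx [x sin x] at x = 1
```

The opposite direction, calling CUDA/C++ code from inside a JAX program, is what XLA custom calls (exposed in JAX through its FFI support) are meant for.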
Thank you!
Deniz