The way to make custom CUDA calls differentiable is to give them a custom derivative rule with `jax.custom_jvp`, which you can read about here: https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html

Essentially, you need to define a function that computes the derivative for your kernel call, and register it so that the JAX autodiff machinery uses that function whenever it encounters your CUDA kernel during tracing. Hopefully that will point you in the right direction!
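To make the registration step concrete, here is a minimal sketch of what that could look like. Note that `_kernel_forward` and `_kernel_derivative` are hypothetical stand-ins written in plain `jax.numpy` so the example runs anywhere; in your setup they would be whatever wrappers dispatch to your actual CUDA kernels. Only `jax.custom_jvp`, `defjvp`, `jax.jvp`, and `jax.grad` are real JAX APIs here.

```python
import jax
import jax.numpy as jnp


def _kernel_forward(x):
    # Hypothetical stand-in for the call into your custom CUDA kernel.
    return jnp.sin(x) * x


def _kernel_derivative(x):
    # Hypothetical stand-in for a second kernel (or hand-written formula)
    # that returns the elementwise derivative d/dx [sin(x) * x].
    return jnp.cos(x) * x + jnp.sin(x)


@jax.custom_jvp
def my_kernel(x):
    # The function JAX sees; its body is what runs in the forward pass.
    return _kernel_forward(x)


@my_kernel.defjvp
def my_kernel_jvp(primals, tangents):
    # JAX calls this rule instead of differentiating through my_kernel.
    (x,) = primals
    (x_dot,) = tangents
    y = _kernel_forward(x)
    # The tangent output must be linear in x_dot so that JAX can also
    # derive reverse-mode gradients from this rule.
    y_dot = _kernel_derivative(x) * x_dot
    return y, y_dot


x = jnp.array([0.5, 1.0, 2.0])

# Forward mode: push a tangent of ones through the custom rule.
y, y_dot = jax.jvp(my_kernel, (x,), (jnp.ones_like(x),))

# Reverse mode: jax.grad works too, because the rule above is linear in x_dot.
g = jax.grad(lambda v: jnp.sum(my_kernel(v)))(x)
```

If you would rather write the backward pass directly (for example, because you already have a hand-written backward kernel), `jax.custom_vjp` is the reverse-mode counterpart and is covered in the same notebook.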

Answer selected by mah-asd