Hello,
The way to support custom CUDA calls would be via a `custom_jvp`, which you can read about here: https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html. Essentially, you need to define a function that computes the gradient for your kernel call, and register it so that the JAX autodiff machinery knows to use that function when it encounters your CUDA kernel during tracing. Hopefully that will point you in the right direction!
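
For a concrete picture, here is a minimal sketch of the registration pattern. The wrapper `my_kernel` and its body are hypothetical stand-ins (a plain `jnp` expression substitutes for the actual CUDA kernel dispatch so the snippet runs as-is); the point is the `custom_jvp` / `defjvp` wiring.

```python
import jax
import jax.numpy as jnp

# Hypothetical wrapper around a custom CUDA kernel. In a real setup this
# would dispatch to the compiled kernel; here a plain jnp expression
# stands in for the kernel's forward computation.
@jax.custom_jvp
def my_kernel(x):
    return jnp.sin(x) * x  # placeholder for the kernel output

# Register the derivative rule. JAX's autodiff machinery calls this
# instead of trying to trace through the opaque kernel.
@my_kernel.defjvp
def my_kernel_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    primal_out = my_kernel(x)
    # Hand-written derivative of the placeholder: d/dx [x*sin(x)] = sin(x) + x*cos(x)
    tangent_out = (jnp.sin(x) + x * jnp.cos(x)) * x_dot
    return primal_out, tangent_out

print(jax.grad(my_kernel)(2.0))  # uses the registered rule during tracing
```

Because the tangent output is linear in `x_dot`, reverse-mode (`jax.grad`) works through the same rule; for kernels where that isn't convenient, `jax.custom_vjp` is the analogous reverse-mode registration.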