How best to code higher-order custom derivatives that require l'Hôpital's rule internally? #15514
The short answer is that you implement your custom derivative rule in terms of a function that has its own custom derivative rule defined! I can give more concrete advice if you can share a code snippet that isn't working as expected.
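Below is a minimal sketch of that nesting, using the notation from the question: k(z) = phi(n(z)) with n(z) = sqrt(<z, M z>). It assumes auxiliary functions g_0 = phi and g_{j+1}(r) = g_j'(r)/r, so that d/dz g_j(n(z)) = g_{j+1}(n(z)) * M z for symmetric M, and that you can supply each g_j in closed form for r > 0 together with its l'Hôpital limit at r = 0. The helper `make_radial`, the lists `gs_cos`/`g0_cos` and the toy choice phi = cos are hypothetical illustrations, not an existing JAX API. Each level's `custom_jvp` rule calls the next level, which has its own rule, so the rule itself stays differentiable:

```python
import jax
import jax.numpy as jnp


def make_radial(gs, g_at_zero, M):
    """Hypothetical helper: k(z) = phi(n(z)), n(z) = sqrt(<z, M z>), via nested custom JVPs.

    gs[j](r) is a closed form of g_j(r) for r > 0 (g_0 = phi, g_{j+1} = g_j' / r);
    g_at_zero[j] is the l'Hopital limit of g_j at r = 0.
    """
    M = 0.5 * (M + M.T)  # only the symmetric part of M enters <z, M z>
    levels = []

    def eval_level(j, z):
        s = z @ (M @ z)
        at_zero = s == 0.0
        # "Double where": sqrt never sees 0, so no inf/nan leaks into derivatives.
        r = jnp.sqrt(jnp.where(at_zero, 1.0, s))
        return jnp.where(at_zero, g_at_zero[j], gs[j](r))

    def make_level(j):
        if j == len(gs) - 1:
            # Top level: plain autodiff. Supply enough levels for the order you need.
            return lambda z: eval_level(j, z)

        @jax.custom_jvp
        def level(z):
            return eval_level(j, z)

        @level.defjvp
        def level_jvp(primals, tangents):
            (z,), (dz,) = primals, tangents
            # d/dz g_j(n(z)) = g_{j+1}(n(z)) * M z; level j + 1 has its own custom
            # JVP, so this rule can be differentiated again.
            return level(z), levels[j + 1](z) * jnp.dot(M @ z, dz)

        return level

    for j in range(len(gs)):
        levels.append(make_level(j))
    return levels[0]


# Toy phi = cos: g_1(r) = -sin(r)/r, g_2(r) = (sin r - r cos r)/r^3,
# with limits 1, -1 and 1/3 at r = 0.
gs_cos = [jnp.cos,
          lambda r: -jnp.sin(r) / r,
          lambda r: (jnp.sin(r) - r * jnp.cos(r)) / r**3]
g0_cos = [1.0, -1.0, 1.0 / 3.0]

M = jnp.array([[2.0, 0.5], [0.5, 1.0]])
k = make_radial(gs_cos, g0_cos, M)
H0 = jax.hessian(k)(jnp.zeros(2))  # equals -M
```

The exact-zero test keeps the sketch short; for z close to but not at 0, closed forms such as g_2 suffer cancellation, so a practical version would switch to a Taylor expansion below a small threshold.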
Inspired by the proposal, I implemented a solution to compute 4th-order derivatives of radial basis function kernels k(x, y) = phi(||x - y||) with ||z||^2 := <z, M z> for some positive definite matrix M. Doing all the algebra up to 4th order by hand is quite cumbersome, so I thought I'd share the solution:

```python
from functools import partial

import jax

@partial(jax.jit, static_argnums=(0,))
...

@partial(jax.custom_jvp, nondiff_argnums=(0, 1))
...

@partial(jax.custom_jvp, nondiff_argnums=(0, 1))
...

@partial(jax.custom_jvp, nondiff_argnums=(0, 1))
...

@partial(jax.custom_jvp, nondiff_argnums=(0, 1))
def eval_rbf_on_diagonal_4(phi, m, x, y, h4, h3, h2, h1):
    ...

@eval_rbf_on_diagonal_0.defjvp
...

@eval_rbf_on_diagonal_1.defjvp
...

@eval_rbf_on_diagonal_2.defjvp
...

@eval_rbf_on_diagonal_3.defjvp
...
```
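Independently of how phi is handled at z = 0, the 4th-order derivatives of k(x, y) = k_z(x - y) with respect to x and y can be assembled by nesting `jax.jacfwd` instead of doing the algebra by hand. This sketch is not a reconstruction of the `eval_rbf_on_diagonal_*` functions above; it assumes a hypothetical smooth stand-in k_z(z) = exp(-<z, M z>) and an illustrative M, and evaluates d^4 k / dy dy dx dx on the diagonal x = y:

```python
import jax
import jax.numpy as jnp

# Illustrative M and a hypothetical smooth stand-in for k(z), z = x - y
# (a Gaussian written as psi(<z, M z>)); substitute your own radial function.
M = jnp.array([[2.0, 0.5], [0.5, 1.0]])


def k_z(z):
    return jnp.exp(-z @ (M @ z))


def k_xy(x, y):
    return k_z(x - y)


# Nested forward-mode Jacobians: two with respect to y (argnums=1), then two
# with respect to x (argnums=0). T[i, j, k, l] = d^4 k / dy_i dy_j dx_k dx_l.
d2y = jax.jacfwd(jax.jacfwd(k_xy, argnums=1), argnums=1)
d2x_d2y = jax.jacfwd(jax.jacfwd(d2y, argnums=0), argnums=0)

x = jnp.array([0.1, -0.3])
T = d2x_d2y(x, x)  # evaluated on the diagonal x = y, i.e. z = 0
```

Since d/dy = -d/dz for a function of x - y, the two sign flips cancel and T is the fourth-derivative tensor of k_z at z = 0. For a phi that is not of the form psi(r^2), k_z would need a custom rule such as the nested one suggested in the first reply.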
I am implementing methods using radial basis functions in an RKHS context and want to use JAX's autodiff capabilities.
For a vector z, a (strictly) positive definite matrix M and a scalar function phi: ℝ -> ℝ, consider the functions
n(z) := Sqrt[<z, M z>]
k(z) := phi(n(z)).
For z != 0, we can compute the gradient
grad(k)(z) = phi'(n(z)) * (M z / n(z)).
Since typically phi'(0) = 0 and phi''(0) != 0, we can use l'Hôpital's rule (or a Taylor expansion of phi' at zero) to extend the gradient continuously to z = 0: phi'(n)/n -> phi''(0) as n -> 0, so grad(k)(0) = 0. I can implement this in JAX with custom_jvp.
However, my method needs higher-order derivatives. The continuation via l'Hôpital's rule is needed for every derivative at z = 0, since the chain rule keeps producing 1/n(z) terms.
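A short derivation makes the recurring 1/n(z) terms explicit. Assuming M is symmetric and introducing auxiliary functions g_j (notation not used elsewhere in this thread) by g_0 = phi and g_{j+1}(r) = g_j'(r)/r, each differentiation of g_j(n(z)) yields the next function times M z, and l'Hôpital's rule supplies its value at 0:

```latex
\begin{aligned}
\nabla_z\, g_j(n(z)) &= g_j'(n(z))\,\frac{Mz}{n(z)} = g_{j+1}(n(z))\,Mz, \\
g_{j+1}(0) &= \lim_{r \to 0} \frac{g_j'(r)}{r} = g_j''(0) \quad \text{if } g_j'(0) = 0, \\
\nabla k(z) &= g_1(n(z))\,Mz, \\
\nabla^2 k(z) &= g_1(n(z))\,M + g_2(n(z))\,(Mz)(Mz)^\top.
\end{aligned}
```

Because the gradient of M z is M and the gradient of each g_j(n(z)) is g_{j+1}(n(z)) M z, every derivative of k is a sum of products of the scalars g_j(n(z)) with tensor products of M and M z. Only the scalar functions g_j need a special value at 0, which suggests one custom rule per g_j.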
How would I best implement these higher-order derivatives in JAX while keeping them available to JAX's autodiff for further constructions?
An example for phi is Wendland's functions (https://num.math.uni-goettingen.de/picap/pdf/E392.pdf). Wendland's functions also need care when differentiating outside the unit ball, but that is not the topic of this question.
Also note: if phi(x) = psi(x^2) with psi smooth, the problem disappears, because n(z)^2 = <z, M z> is smooth (its gradient is 2 M z), so the composition psi(n(z)^2) is smooth as well. This solves the problem for, e.g., Gaussian and inverse multiquadric kernels.
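The psi(x^2) observation can be checked directly. In this sketch the inverse multiquadric kernel and the matrix M are illustrative choices: the kernel is written once as phi(n(z)) with phi(r) = 1/sqrt(1 + r^2) and once as psi(<z, M z>) with psi(s) = 1/sqrt(1 + s). Reverse-mode autodiff through `jnp.sqrt` at z = 0 yields nan, while the psi form gives finite derivatives of any order:

```python
import jax
import jax.numpy as jnp

M = jnp.array([[2.0, 0.5], [0.5, 1.0]])


def quad(z):
    # n(z)^2 = <z, M z>: a polynomial in z, smooth everywhere including z = 0.
    return z @ (M @ z)


def imq_via_n(z):
    # phi(n(z)) with phi(r) = 1 / sqrt(1 + r^2): differentiates through sqrt.
    return 1.0 / jnp.sqrt(1.0 + jnp.sqrt(quad(z)) ** 2)


def imq_via_psi(z):
    # The same kernel as psi(<z, M z>) with psi(s) = 1 / sqrt(1 + s).
    return 1.0 / jnp.sqrt(1.0 + quad(z))


z0 = jnp.zeros(2)
g_naive = jax.grad(imq_via_n)(z0)  # nan: 0 * inf from the sqrt derivative at s = 0
H = jax.hessian(imq_via_psi)(z0)   # equals -M, since psi'(0) = -1/2
# Higher orders compose the same way, e.g. the (2, 2, 2, 2) tensor of 4th derivatives:
D4 = jax.jacfwd(jax.jacfwd(jax.hessian(imq_via_psi)))(z0)
```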
Any thoughts/hints for an implementation are highly appreciated.