Integration in jax #8327
-
Hi guys, for my work I need to optimize (i.e. differentiate) a function that has an integral inside it (quite simple, as it is one-dimensional), and the integral depends on one of the variables I want to differentiate with respect to. If JAX has no built-in routine for this, I might have to implement a couple of methods myself (probably simple quadrature is enough for my needs), but I hope not to!
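If a fixed-grid rule is accurate enough, the integral can be written directly in jax.numpy, and jax.grad then differentiates through it, including with respect to an integration limit. A minimal sketch, assuming the composite trapezoidal rule and a hypothetical integrand exp(-theta * x) on [0, upper], chosen only because its derivatives are known in closed form:

```python
import jax
import jax.numpy as jnp


def trapezoid(f, a, b, n=1001):
    """Composite trapezoidal rule for the integral of f over [a, b] with n points.

    Everything is plain jax.numpy, so jax.grad can differentiate through it
    with respect to a, b, or any parameters captured by f.
    """
    # Map a fixed grid on [0, 1] onto [a, b]; n must be a Python int (static shape).
    t = jnp.linspace(0.0, 1.0, n)
    x = a + (b - a) * t
    y = f(x)
    h = (b - a) / (n - 1)
    # h * (y0/2 + y1 + ... + y_{n-2} + y_{n-1}/2)
    return h * (jnp.sum(y) - 0.5 * (y[0] + y[-1]))


def loss(theta, upper):
    # Hypothetical integrand depending on theta; the upper limit is also an input.
    return trapezoid(lambda x: jnp.exp(-theta * x), 0.0, upper)


# Gradients with respect to both the integrand parameter and the upper limit.
dtheta, dupper = jax.grad(loss, argnums=(0, 1))(1.0, 1.0)
```

Here dtheta approximates 2/e - 1 (the derivative of (1 - e^(-theta))/theta at theta = 1) and dupper approximates e^(-1), i.e. the integrand evaluated at the upper limit. The point count n fixes array shapes, so it has to stay a Python constant.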
-
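Buongiornissimo. I achieved something like that in the past by using jax.custom_vjp/jax.custom_jvp (depending on whether you need backward- or forward-mode differentiation). See here for some examples. Essentially, you can define a function like:

from functools import partial
import jax

@partial(jax.custom_vjp, nondiff_argnums=(0,))
def my_fancy_fun(integration_mode, integral_params):
    # this does not need to be differentiable
    return integrate(my_integrand, *integral_params, mode=integration_mode)

def my_fancy_fun_fwd_pass(integration_mode, integral_params):
    # this does not need to be differentiable either
    fwd_res = integrate(my_integrand, *integral_params, mode=integration_mode)
    fwd_data = integral_params  # residuals saved for the backward pass
    return fwd_res, fwd_data

def my_fancy_fun_bwd_pass(integration_mode, fwd_data, g):
    # this also does not need to be differentiable; g is the cotangent of the output.
    # my_grad is a placeholder for your own gradient rule. The backward pass must return
    # a tuple with one cotangent per differentiable argument (same structure as integral_params).
    return (my_grad(fwd_data, g),)

my_fancy_fun.defvjp(my_fancy_fun_fwd_pass, my_fancy_fun_bwd_pass)

In the body of the forward and backward passes you can use anything, even if it is not differentiable, but it must be JAX-compatible. Now, if you want to call back into standard SciPy, my usual trick is to use:

import numpy as np
import jax.numpy as jnp
import scipy.integrate

def integrate(f, *pars, mode="something"):
    # f and mode must be hashable constants; a float64 result requires jax_enable_x64
    from jax.experimental import host_callback as hcb
    return hcb.call(partial(_integrate, f, mode=mode), pars,
                    result_shape=jax.ShapeDtypeStruct((), jnp.float64))

def _integrate(f, pars, mode):
    # runs on the host; hcb.call passes pars as one tuple, assumed here to be (a, b); mode unused
    a, b = pars
    return np.float64(scipy.integrate.quad(f, a, b)[0])

This calls back into Python. You must be able to tell XLA the result_shape (i.e. the shape and dtype) of the output.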
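jax.experimental.host_callback has since been deprecated and removed from recent JAX releases in favor of jax.pure_callback. A minimal sketch of the same pattern with jax.pure_callback, assuming a hypothetical integrand exp(-theta * x) on [0, 1] (exact value (1 - e^(-theta))/theta) and computing the backward pass by differentiating under the integral sign (Leibniz rule) with a second SciPy call. The helper names are illustrative, not part of any library:

```python
import numpy as np
import scipy.integrate
import jax
import jax.numpy as jnp


def _quad_host(integrand, theta):
    # Runs on the host with NumPy values; quad returns (value, abserr).
    value, _abserr = scipy.integrate.quad(integrand, 0.0, 1.0, args=(float(theta),))
    return np.asarray(value, dtype=np.float32)


def _f(x, theta):
    return np.exp(-theta * x)


def _df_dtheta(x, theta):
    # Partial derivative of the integrand with respect to theta.
    return -x * np.exp(-theta * x)


_SCALAR = jax.ShapeDtypeStruct((), jnp.float32)  # shape/dtype JAX must know in advance


@jax.custom_vjp
def integral(theta):
    # I(theta) = integral of exp(-theta x) over [0, 1], evaluated by SciPy outside JAX.
    return jax.pure_callback(lambda t: _quad_host(_f, t), _SCALAR, theta)


def integral_fwd(theta):
    return integral(theta), theta  # keep theta as the residual


def integral_bwd(theta, g):
    # Leibniz rule: dI/dtheta is the integral of d/dtheta exp(-theta x) over [0, 1].
    d_i = jax.pure_callback(lambda t: _quad_host(_df_dtheta, t), _SCALAR, theta)
    return (g * d_i,)


integral.defvjp(integral_fwd, integral_bwd)
```

With theta = jnp.float32(1.0), integral(theta) should be close to 1 - 1/e and jax.grad(integral)(theta) close to 2/e - 1. jax.pure_callback has no derivative rule of its own, which is why the custom_vjp wrapper is what makes jax.grad work here.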
-
@PhilipVinc Do you know of a "vmappable" method? I get the following error when I wrap my function with
I am somewhat surprised that
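One vmappable option is to avoid host callbacks altogether and keep the quadrature entirely in jax.numpy, for example with a fixed Gauss-Legendre rule whose nodes and weights are precomputed once with NumPy's numpy.polynomial.legendre.leggauss. A minimal sketch, assuming a smooth integrand for which 20 nodes suffice; the order and the integrand exp(-theta * x) on [0, 1] are hypothetical choices:

```python
import numpy as np
import jax
import jax.numpy as jnp

# Gauss-Legendre nodes and weights on [-1, 1], computed once with NumPy.
# They are constants, so the quadrature below is pure jax.numpy.
_NODES, _WEIGHTS = np.polynomial.legendre.leggauss(20)


def gauss_legendre(f, a, b):
    """Approximate the integral of f over [a, b] with a fixed 20-point rule."""
    half = 0.5 * (b - a)
    mid = 0.5 * (a + b)
    # Affine map of the reference nodes from [-1, 1] onto [a, b].
    x = mid + half * jnp.asarray(_NODES)
    return half * jnp.sum(jnp.asarray(_WEIGHTS) * f(x))


def integral(theta):
    # Hypothetical integrand: exp(-theta * x) on [0, 1].
    return gauss_legendre(lambda x: jnp.exp(-theta * x), 0.0, 1.0)


thetas = jnp.array([0.5, 1.0, 2.0])
values = jax.vmap(integral)(thetas)           # one integral per theta
grads = jax.vmap(jax.grad(integral))(thetas)  # one derivative per theta
```

Both jax.vmap(integral) and jax.vmap(jax.grad(integral)) work because every step is an ordinary JAX array operation; a callback-based integrator, by contrast, needs an explicit batching strategy for the callback.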