A very good morning!

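I have achieved something like that in the past by using jax.custom_vjp or jax.custom_jvp, depending on whether you need reverse-mode (backward) or forward-mode differentiation. jax.custom_vjp only supports reverse mode, while jax.custom_jvp supports forward mode and also reverse mode, as long as the tangent rule is linear in the tangents. See here for some examples.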
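For the forward-mode route, here is a minimal, self-contained sketch with jax.custom_jvp. It assumes a hypothetical integrate(integrand, a, mode) backed by a fixed-grid trapezoid rule on [0, 1] and a hypothetical integrand f(x; a) = exp(-a x^2). In practice integrate would be whatever non-differentiable routine you are wrapping; here it is plain jax.numpy code only so the custom rule can be checked against ordinary autodiff. The tangent comes from the Leibniz rule, d/da of the integral of f(x; a) equals the integral of df/da.

```python
from functools import partial

import jax
import jax.numpy as jnp


def trapezoid(y, x):
    # Composite trapezoid rule on a uniform grid.
    dx = x[1] - x[0]
    return dx * (jnp.sum(y) - 0.5 * (y[0] + y[-1]))


def integrate(integrand, a, mode):
    # Hypothetical stand-in for an integrator JAX cannot differentiate through.
    # `mode` is a plain Python string, so it must be a non-differentiable arg.
    if mode != "trapezoid":
        raise ValueError(f"unknown integration mode: {mode!r}")
    x = jnp.linspace(0.0, 1.0, 1001)
    return trapezoid(integrand(x, a), x)


def f(x, a):
    # Hypothetical integrand f(x; a) = exp(-a x^2).
    return jnp.exp(-a * x**2)


def df_da(x, a):
    # Partial derivative of the integrand with respect to a.
    return -x**2 * jnp.exp(-a * x**2)


@partial(jax.custom_jvp, nondiff_argnums=(0,))
def my_fancy_fun(integration_mode, a):
    return integrate(f, a, integration_mode)


@my_fancy_fun.defjvp
def my_fancy_fun_jvp(integration_mode, primals, tangents):
    # With nondiff_argnums, the non-differentiable args are passed first.
    (a,), (a_dot,) = primals, tangents
    out = integrate(f, a, integration_mode)
    # Leibniz rule; the result is linear in a_dot, so reverse mode works too.
    out_dot = integrate(df_da, a, integration_mode) * a_dot
    return out, out_dot
```

Because the tangent output is linear in a_dot, JAX can transpose it, so jax.grad(my_fancy_fun, argnums=1)("trapezoid", 1.0) works as well as jax.jvp.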

Essentially, you can define a function like this:

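```python
from functools import partial
import jax

@partial(jax.custom_vjp, nondiff_argnums=(0,))
def my_fancy_fun(integration_mode, integral_params):
    # this does not need to be differentiable
    return integrate(my_integrand, *integral_params, mode=integration_mode)

def my_fancy_fun_fwd_pass(integration_mode, integral_params):
    # this does not need to be differentiable either
    fwd_res = integrate(my_integrand, *integral_params, mode=integration_mode)
    fwd_data = None  # residuals handed to the backward pass
    return fwd_res, fwd_data

def my_fancy_fun_bwd_pass(integration_mode, fwd_data, g):
    # receives the non-differentiable args, the residuals and the output cotangent g;
    # must return a tuple with one cotangent per differentiable arg (integral_params)
    ...

my_fancy_fun.defvjp(my_fancy_fun_fwd_pass, my_fancy_fun_bwd_pass)
```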
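A complete, runnable version of the same custom_vjp pattern is sketched below. It relies on the same kind of assumptions: integrate, my_integrand and my_integrand_da are hypothetical stand-ins (a trapezoid rule on [0, 1] and f(x; a) = exp(-a x^2)), and the differentiable parameter is a single scalar a rather than a tuple of params. Unlike the skeleton, the forward pass saves a as its residual so that the backward pass can apply the Leibniz rule.

```python
from functools import partial

import jax
import jax.numpy as jnp


def trapezoid(y, x):
    # Composite trapezoid rule on a uniform grid.
    dx = x[1] - x[0]
    return dx * (jnp.sum(y) - 0.5 * (y[0] + y[-1]))


def integrate(integrand, a, mode):
    # Hypothetical stand-in for an integrator JAX cannot trace through;
    # `mode` is a plain Python string selecting the method.
    if mode != "trapezoid":
        raise ValueError(f"unknown integration mode: {mode!r}")
    x = jnp.linspace(0.0, 1.0, 1001)
    return trapezoid(integrand(x, a), x)


def my_integrand(x, a):
    # Hypothetical integrand f(x; a) = exp(-a x^2).
    return jnp.exp(-a * x**2)


def my_integrand_da(x, a):
    # Partial derivative of the integrand with respect to a.
    return -x**2 * jnp.exp(-a * x**2)


@partial(jax.custom_vjp, nondiff_argnums=(0,))
def my_fancy_fun(integration_mode, a):
    return integrate(my_integrand, a, integration_mode)


def my_fancy_fun_fwd_pass(integration_mode, a):
    # Same signature as my_fancy_fun; returns (output, residuals).
    fwd_res = integrate(my_integrand, a, integration_mode)
    fwd_data = a  # what the backward pass needs
    return fwd_res, fwd_data


def my_fancy_fun_bwd_pass(integration_mode, fwd_data, g):
    # Non-differentiable args first, then the residuals, then the cotangent g.
    a = fwd_data
    # Leibniz rule: d/da of the integral is the integral of d(integrand)/da.
    da = g * integrate(my_integrand_da, a, integration_mode)
    return (da,)  # one cotangent per differentiable argument


my_fancy_fun.defvjp(my_fancy_fun_fwd_pass, my_fancy_fun_bwd_pass)
```

Usage: jax.grad(my_fancy_fun, argnums=1)("trapezoid", 1.0). Under jax.jit the string has to be marked static, e.g. jax.jit(jax.grad(my_fancy_fun, argnums=1), static_argnums=0), since a Python str is not a valid traced argument.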

Answer selected by jschiavon