The infinite recursion happens because calling jvp(EulerIntegrator) invokes the custom JVP rule, which calls jvp(EulerIntegrator) again, which invokes the custom JVP rule again, and so on.

It's not clear from your code what you expected to happen, but if you were hoping to fall back on the default JVP rule for euler_integrator, you can do so by keeping a copy of the function without the custom_jvp wrapper, something like this:

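import jax

def _euler_integrator(ivp_params: ParametersIVP[T], initial_state: T) -> tuple[T, T]:
  ... # main implementation here

# Wrap the undecorated function instead of decorating it, so _euler_integrator
# remains available as the plain (non-custom-JVP) version.
euler_integrator = jax.custom_jvp(_euler_integrator, nondiff_argnums=(0,))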

@euler_integrator.defjvp
def euler_integrator_jvp(f, primals, tangents):
    
    print('Raise a w…
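
For reference, here is a complete, runnable sketch of the same pattern. It is an illustration under stated assumptions, not the original poster's code: `IVPParams` is a hypothetical stand-in for `ParametersIVP`, the integrator solves the toy ODE dy/dt = -y with explicit Euler steps, and `warnings.warn` stands in for the truncated warning in the rule. The key point is that the JVP rule calls `jax.jvp` on the undecorated `_euler_integrator`, so JAX uses its default derivative and never re-enters the custom rule.

```python
import warnings
from functools import partial
from typing import NamedTuple

import jax
import jax.numpy as jnp


class IVPParams(NamedTuple):
    """Hypothetical stand-in for ParametersIVP (not from the original code)."""
    dt: float
    n_steps: int


def _euler_integrator(ivp_params, initial_state):
    """Plain explicit-Euler solver for the toy ODE dy/dt = -y (an assumption)."""
    def step(_, y):
        return y + ivp_params.dt * (-y)

    final_state = jax.lax.fori_loop(0, ivp_params.n_steps, step, initial_state)
    # Return the final state and the vector field evaluated at it.
    return final_state, -final_state


# Wrap the undecorated function so _euler_integrator stays available as-is.
euler_integrator = jax.custom_jvp(_euler_integrator, nondiff_argnums=(0,))


@euler_integrator.defjvp
def euler_integrator_jvp(ivp_params, primals, tangents):
    # With nondiff_argnums=(0,), ivp_params arrives first; primals and
    # tangents hold only the differentiable argument (initial_state).
    warnings.warn("differentiating through euler_integrator")
    # Delegate to JAX's default JVP of the *undecorated* function. Calling
    # jax.jvp on euler_integrator here would re-enter this rule forever.
    return jax.jvp(partial(_euler_integrator, ivp_params), primals, tangents)


params = IVPParams(dt=0.1, n_steps=10)
y0 = jnp.array(1.0)

# Forward mode: primal outputs and their tangents w.r.t. initial_state.
(y_final, _), (dy_final, _) = jax.jvp(
    lambda y: euler_integrator(params, y), (y0,), (jnp.array(1.0),)
)
# Reverse mode also works, since the rule's output is linear in the tangents.
grad_y0 = jax.grad(lambda y: euler_integrator(params, y)[0])(y0)
print(y_final, dy_final, grad_y0)
```

For this linear ODE the final state is (1 - dt)**n_steps * y0, so with these parameters the primal output, its tangent and the gradient with respect to y0 should all come out close to 0.9**10 ≈ 0.3487.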

Answer selected by danielkelshaw
Category
Q&A