Custom JVP with non-differentiable PyTree #16838
-
Hi, I'm currently trying to inject some orthogonal behaviour into a custom JVP. Essentially, I would like to raise a warning when someone differentiates through the solution to an initial value problem, stating that the adjoint method has not been implemented and pointing them to another implementation which does this. The code I'm starting from is below.

```python
from typing import Callable, Generic, TypeAlias, TypeVar

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from chex import dataclass

T = TypeVar('T')


@dataclass
class ParametersIVP(Generic[T]):
    differential_operator: Callable[[T], T]
    dt: float = 1e-3
    n_steps: int = int(1e3)


def _merge_states(preceding: T, final: T) -> T:
    # Append the final state to the stacked preceding states, leaf by leaf.
    return jtu.tree_map(
        lambda preceding, final: jnp.concatenate([preceding, jnp.expand_dims(final, 0)]),
        preceding,
        final,
    )


def euler_integrator(ivp_params: ParametersIVP[T], initial_state: T) -> tuple[T, T]:
    """Forward-Euler method for integration of the initial value problem.

    Parameters:
        ivp_params: Parameters for the initial value problem.
        initial_state: State at t=0 to integrate from.

    Returns:
        final_state: Final state of the initial value problem.
        full_state: Entire solution of the initial value problem.
    """
    def _single_step(state: T, _: None) -> tuple[T, T]:
        update = ivp_params.differential_operator(state)
        next_state = jtu.tree_map(lambda x, dxdt: x + dxdt * ivp_params.dt, state, update)
        return next_state, state

    final_state, preceding_states = jax.lax.scan(_single_step, initial_state, None, ivp_params.n_steps)
    full_state = _merge_states(preceding=preceding_states, final=final_state)
    return final_state, full_state
```

I modify this code to the following in order to try and inject a print statement:

```python
@jtu.Partial(jax.custom_jvp, nondiff_argnums=(0,))
def euler_integrator(ivp_params: ParametersIVP[T], initial_state: T) -> tuple[T, T]:
    def _single_step(state: T, _: None) -> tuple[T, T]:
        update = ivp_params.differential_operator(state)
        next_state = jtu.tree_map(lambda x, dxdt: x + dxdt * ivp_params.dt, state, update)
        return next_state, state

    final_state, preceding_states = jax.lax.scan(_single_step, initial_state, None, ivp_params.n_steps)
    full_state = _merge_states(preceding=preceding_states, final=final_state)
    return final_state, full_state


@euler_integrator.defjvp
def euler_integrator_jvp(f, primals, tangents):
    print('Raise a warning here...')
    return jax.jvp(jtu.Partial(euler_integrator, f), primals, tangents, has_aux=True)
```

As a simple test case, I integrate the Lorenz63 system, as follows:

```python
@jax.jit
def lorenz63(x: jax.Array) -> jax.Array:
    return jnp.array([
        10.0 * (x[1] - x[0]),
        x[0] * (28 - x[2]) - x[1],
        x[0] * x[1] - (8 / 3) * x[2]
    ])


ivp_params = ParametersIVP(differential_operator=lorenz63)
x0 = jnp.array([1.0, 0.0, 0.0])
jac, aux = jax.jacobian(euler_integrator, argnums=1, has_aux=True)(ivp_params, x0)
```

I get an error caused by infinite recursion. Has anyone come across anything similar, or is there a simple fix for this problem?
-
The issue with the infinite recursion is that when you call `jvp(euler_integrator)`, it calls the custom JVP rule, which calls `jvp(euler_integrator)`, which calls the custom JVP rule, and so on.

It's not clear from your code what you expected to happen, but if you were hoping to return the default JVP rule for `euler_integrator`, you could do so by keeping a copy of the function without the custom JVP, something like this:

```python
def _euler_integrator(ivp_params: ParametersIVP[T], initial_state: T) -> tuple[T, T]:
    ...  # main implementation here


# Wrap the plain function; _euler_integrator itself stays undecorated.
euler_integrator = jax.custom_jvp(_euler_integrator, nondiff_argnums=(0,))


@euler_integrator.defjvp
def euler_integrator_jvp(ivp_params, primals, tangents):
    # The non-differentiable argument (ivp_params) is passed to the rule first.
    print('Raise a warning here...')
    # Differentiate the undecorated copy, so this rule is not re-entered.
    # No has_aux: a custom JVP rule must return (primals_out, tangents_out),
    # each with the same structure as the full output (final_state, full_state).
    return jax.jvp(jtu.Partial(_euler_integrator, ivp_params), primals, tangents)
```

That should avoid the infinite recursion. Does that do what you had in mind?
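For reference, here is a minimal self-contained sketch of the same pattern with the print replaced by a real warning from Python's `warnings` module. It rests on a few assumptions that are not from the thread: `IVPParams` is a hypothetical `NamedTuple` stand-in for the chex-based `ParametersIVP` (to drop the chex dependency), the state is assumed to be a single array rather than a general PyTree, and the linear system dx/dt = -x is used only because its Jacobian is easy to check by hand (each Euler step multiplies the state by 1 - dt).

```python
import functools
import warnings
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class IVPParams(NamedTuple):
    # Hypothetical stand-in for ParametersIVP; not part of the original code.
    differential_operator: Callable[[jax.Array], jax.Array]
    dt: float = 1e-3
    n_steps: int = 1000


def _euler_integrator(params: IVPParams, x0: jax.Array):
    """Plain forward-Euler integrator with no custom derivative rule."""
    def step(x, _):
        x_next = x + params.dt * params.differential_operator(x)
        return x_next, x  # carry the new state, record the old one

    final, preceding = jax.lax.scan(step, x0, None, length=params.n_steps)
    full = jnp.concatenate([preceding, final[None]])  # shape (n_steps + 1, ...)
    return final, full


# Only the wrapped version carries the custom rule; the rule below
# differentiates the undecorated _euler_integrator, so it never recurses.
euler_integrator = jax.custom_jvp(_euler_integrator, nondiff_argnums=(0,))


@euler_integrator.defjvp
def _euler_integrator_jvp(params, primals, tangents):
    warnings.warn(
        "Differentiating directly through euler_integrator; "
        "an adjoint method is not implemented here.",
        stacklevel=2,
    )
    # Returns ((final, full), (d_final, d_full)), matching the output structure.
    return jax.jvp(functools.partial(_euler_integrator, params), primals, tangents)
```

Because the rule runs while JAX traces the derivative, the warning fires once per trace (for example once per `jax.jit` compilation), not necessarily on every numerical evaluation.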