-
Hi everyone, I am writing a custom VJP rule and it is raising an error that I do not understand. Let me explain:
The code to compute import jax
import jax.numpy as jnp
from functools import partial
from collections import namedtuple
from jax import lax
def _clip_to_end(tnext, t1):
# adapted from the Diffrax library
if tnext.dtype is jnp.dtype("float64"):
tol = 1e-10
else:
tol = 1e-5
clip = tnext > t1 - tol
return jnp.where(clip, t1, tnext)
def _get_number_steps(t0, t1, step_size):
# adapted from the Diffrax library
"""
A function that computes the number of steps between t0 and t1
with ah given step size.
"""
carry = namedtuple("carry", "t t1 step_size nb_step")
def _cond_fn(C):
t, t1, step_size, nb_step = C
return t < t1
def _body_fn(C):
_t, t1, step_size, nb_step = C
return carry(_clip_to_end(_t + step_size, t1), t1, step_size, nb_step + 1)
t0, t1 = lax.cond(t0 <= t1, lambda _: (t0, t1), lambda _: (t1, t0), None)
c = carry(t0, t1, step_size, 1)
last_t = lax.while_loop(_cond_fn, _body_fn, c)
return last_t.nb_step Then the function we are interested in looks like this: def _f_test(x, t0,t1,step_size):
with jax.ensure_compile_time_eval():
N = _get_number_steps(
t0, t1, step_size
)
l = jnp.arange(N)
s = jnp.sum(l)
return x**s
@partial(jax.jit, static_argnums = (1,2,3))
def _f(x,t0,t1,step_size):
jax.debug.print("t0 = {x}, t1 = {y}, step = {z}", x=t0, y=t1, z=step_size)
return _f_test(x,t0,t1,step_size)
print(_f(2., 0. , 0.1, 0.01)) And this works like a charm, no problem here ! The problem arises if we derive a @jax.custom_vjp
def _f_test(x, t0,t1,step_size):
with jax.ensure_compile_time_eval(): # Ensure that the computation is done, even though it is compiled.
N = _get_number_steps(
t0, t1, step_size
) # Static value is needed for linspace when we jit the function.
l = jnp.arange(N) # ConcretizationTypeError here
s = jnp.sum(l)
return x**s
def _f_test_fwd(x, t0,t1,step_size):
return _f_test(x, t0, t1, step_size), (x, t0, t1, step_size)
def _f_test_bwd(residual, g):
x, t0, t1, step_size = residual
with jax.ensure_compile_time_eval(): # Ensure that the computation is done, even though it is compiled.
N = _get_number_steps(
t0, t1, step_size
) # Static value is needed for linspace when we jit the function.
l = jnp.arange(N)
s = jnp.sum(l)
return g*s*x**(s-1), None, None, None
_f_test.defvjp(_f_test_fwd, _f_test_bwd)
@partial(jax.jit, static_argnums = (1,2,3))
def _f(x,t0,t1,step_size):
jax.debug.print("t0 = {x}, t1 = {y}, step = {z}", x=t0, y=t1, z=step_size)
return _f_test(x,t0,t1,step_size)
print(_f(2., 0. , 0.1, 0.01)) This will return a Thank you for any input ! |
Beta Was this translation helpful? Give feedback.
Replies: 1 comment
-
Got an answer here #12047 |
Beta Was this translation helpful? Give feedback.
Got an answer here #12047