-
Hello! I'm pretty new to JAX, and I'm trying to iteratively apply a function a variable number of times depending on the time difference between two points, but I keep running into `jax._src.errors.ConcretizationTypeError` errors regardless of what I try. My main intent is to do something similar to the Euler method, where I have a sample at time t0 and want to estimate a value at time tn with step size h (specifically where t0 and tn are not evenly spaced). Using simpler tools than JAX I would just throw this in a for loop, but for my purposes I need to differentiate through this operation, and from what I can tell that requires use of `lax.scan`. Included are three reproducible examples of the issue I'm running into.
#!/usr/bin/env python3
import numpy as onp
import jax.numpy as jnp
import jax
def plus_one(elem):
    """Return ``elem + 1``.

    Used by the examples in this file as the function applied at each
    integration step.
    """
    return elem + 1
@jax.jit
def foo1(t0, z0, tn):
    """Repro 1: derive the scan length from ``jnp.arange(t0, tn, 0.01)``.

    Intended behaviour: Euler-integrate ``z`` from time ``t0`` to ``tn``
    with a fixed step of 0.01, applying ``plus_one`` at every step.

    NOTE: this is a deliberate reproduction of the reported problem and is
    expected to raise ``ConcretizationTypeError`` when traced. Under
    ``jax.jit`` the arguments ``t0`` and ``tn`` are tracers, so the length
    of the ``arange`` (an array shape) is not known at trace time.
    ``lax.scan`` needs a static ``length``; for a data-dependent number of
    iterations see ``lax.while_loop``.

    Args:
        t0: Start time.
        z0: State at ``t0``.
        tn: End time.

    Returns:
        The state after the final step (never reached while the trip
        count depends on traced values).
    """
    # Fails here under jit: the result's shape depends on traced values.
    steps = jnp.arange(t0, tn, 0.01)

    def body_fun(carry, xs):
        # One Euler step from the running state. The original used ``z0``
        # and the whole ``steps`` array here, so the state never advanced
        # and the carry changed shape.
        out = carry + 0.01 * plus_one(carry)
        return out, xs

    res, _ = jax.lax.scan(body_fun, z0, None, length=len(steps))
    # The original returned the undefined name ``results``.
    return res
@jax.jit
def foo2(t0, z0, tn):
    """Repro 2: compute the scan length as a traced integer.

    Intended behaviour: Euler-integrate ``z`` from time ``t0`` to ``tn``
    with a fixed step of 0.01, applying ``plus_one`` at every step.

    NOTE: this is a deliberate reproduction of the reported problem and is
    expected to raise ``ConcretizationTypeError`` when traced. ``num_steps``
    is computed from the traced arguments ``t0`` and ``tn``, so it is itself
    a tracer, while ``lax.scan`` needs ``length`` to be a concrete Python
    int. Building it with ``jnp.array(..., int)`` instead of ``.astype(int)``
    was reported to fail in a similar way. For a data-dependent number of
    iterations see ``lax.while_loop``.

    Args:
        t0: Start time.
        z0: State at ``t0``.
        tn: End time.

    Returns:
        The state after the final step (never reached while the trip
        count depends on traced values).
    """
    num_steps = jnp.divide(jnp.subtract(tn, t0), 0.01).astype(int)

    def body_fun(carry, xs):
        # One Euler step from the running state. The original used ``z0``
        # here, so every iteration produced the same value.
        out = carry + 0.01 * plus_one(carry)
        return out, xs

    # Fails here under jit: ``length`` is a tracer, not a static int.
    res, _ = jax.lax.scan(body_fun, z0, None, length=num_steps)
    # The original returned the undefined name ``results``.
    return res
@jax.jit
def foo3(t0, z0, tn, num_steps):
    """Repro 3: pass the scan length in as an argument.

    Intended behaviour: Euler-integrate ``z`` from time ``t0`` to ``tn`` in
    ``num_steps`` equal steps, applying ``plus_one`` at every step.

    NOTE: this is a deliberate reproduction of the reported problem and is
    expected to raise ``ConcretizationTypeError`` when traced. ``num_steps``
    is an ordinary jitted argument and therefore a tracer, while
    ``lax.scan`` needs ``length`` to be a concrete Python int. For a
    data-dependent number of iterations see ``lax.while_loop``.

    Args:
        t0: Start time.
        z0: State at ``t0``.
        tn: End time.
        num_steps: Number of Euler steps to take between ``t0`` and ``tn``.

    Returns:
        The state after the final step (never reached while the trip
        count is a traced value).
    """
    def body_fun(carry, xs):
        # One Euler step of size (tn - t0) / num_steps from the running
        # state. The original used ``z0`` here, so every iteration produced
        # the same value.
        out = carry + ((tn - t0) / num_steps) * plus_one(carry)
        return out, xs

    # Fails here under jit: ``length`` is a tracer, not a static int.
    res, _ = jax.lax.scan(body_fun, z0, None, length=num_steps)
    # The original returned the undefined name ``results``.
    return res
# Driver: build 20 random (t0, z0, tn) triples with irregular time gaps and
# map each example over them.
step_size = 0.01
onp.random.seed(0)  # fixed seed so every run of the repro uses the same inputs
time_zeros = jnp.array(onp.random.gamma(1, 1, 20))
z0s = jnp.array(onp.random.normal(0, 1, 20))
# Each end time is its start time plus a gamma-distributed gap.
time_ns = time_zeros + jnp.array(onp.random.gamma(1, 2, 20))
# Per-sample number of Euler steps of size ``step_size``.
steps = ((time_ns - time_zeros) / step_size).astype(int)
# Each call below is expected to raise ConcretizationTypeError, so run them
# one at a time to see each failure.
jax.vmap(foo1, in_axes=(0, 0, 0))(time_zeros, z0s, time_ns)
jax.vmap(foo2, in_axes=(0, 0, 0))(time_zeros, z0s, time_ns)
jax.vmap(foo3, in_axes=(0,0,0,0))(time_zeros, z0s, time_ns, steps)
If anyone knows how to dynamically adapt the scan iteration length (or if it's even possible), I'd really appreciate any help!
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 1 reply
-
|
Beta Was this translation helpful? Give feedback.
`lax.scan` can only scan for a static number of steps. To do a variable number of steps, you should try `lax.while_loop`.