ConcretizationTypeError raised when using fori_loop #12551
Hi there, I am trying to use `fori_loop`, but it raises a `ConcretizationTypeError`. The simplified code is as follows:

```python
import jax
import jax.numpy as jnp

def random_binary_array(key, n):
    key1, key2 = jax.random.split(key, num=2)
    return key1, jax.random.choice(key2, jnp.asarray([0, 1]), shape=(n,))

def loop(i, s):
    key1, key2 = jax.random.split(s["key"], num=2)
    s["key"], bin_array = random_binary_array(key2, s["n"])
    return s

n = 5
key = jax.random.PRNGKey(0)
key, bin_array = random_binary_array(key, 5)
s = {
    "n": n,
    "key": key,
    "bin_array": bin_array
}
new_bin_array = jax.lax.fori_loop(0, 5, loop, s)
```

Does anyone know the reason for this error and how to fix it?
Within the `loop` function, all arguments will be traced. This means that `s['n']` will be traced, and so when you write `shape=(n,)` (inside `random_binary_array`) you are requesting a dynamically-shaped array, which is not currently supported in JAX.

To fix this, you'll need to use a static value for `n`. One way to do this would be to use the global `n` value; that is, change this line:

```python
s["key"], bin_array = random_binary_array(key2, s['n'])
```

to this:

```python
s["key"], bin_array = random_binary_array(key2, n)
```
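Putting the fix together, here is a minimal runnable sketch of the corrected reproducer. It assumes `n` is fixed at 5 for every iteration and drops it from the loop carry entirely, so it is never traced. Two small changes go beyond the original code and are assumptions about what was intended: each new `bin_array` is written back into the state so the loop has a visible result, and the extra `split` inside `loop` is skipped because `random_binary_array` already splits the key.

```python
import jax
import jax.numpy as jnp

n = 5  # plain Python int: stays static because it is never put in the loop carry

def random_binary_array(key, n):
    key1, key2 = jax.random.split(key, num=2)
    return key1, jax.random.choice(key2, jnp.asarray([0, 1]), shape=(n,))

def loop(i, s):
    # `n` is read from the enclosing (global) scope, so shape=(n,) is concrete
    # when fori_loop traces this function; only the carry `s` is traced.
    new_key, bin_array = random_binary_array(s["key"], n)
    # Return a dict with the same structure, shapes and dtypes as the input carry
    return {"key": new_key, "bin_array": bin_array}

key = jax.random.PRNGKey(0)
key, bin_array = random_binary_array(key, n)
s = jax.lax.fori_loop(0, 5, loop, {"key": key, "bin_array": bin_array})
print(s["bin_array"])
```

The carry keeps the same shapes and dtypes on every iteration, which `fori_loop` requires; only its values change.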
Of course this assumes that `n` is not changing from iteration to iteration; if it is, then you really do have dynamic array sizes and you'll have to find a different approach to your algorithm within JAX.
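If `n` really does change between iterations, one common workaround (not part of the original reply) is to allocate a fixed-size buffer at a static upper bound and carry a boolean mask marking which entries are valid. The sketch below assumes a hypothetical bound `MAX_N = 8` and, purely for illustration, lets `n = i + 1` grow each iteration; `random_binary_array_masked` is a hypothetical helper, not a JAX API.

```python
import jax
import jax.numpy as jnp

MAX_N = 8  # hypothetical static upper bound on n; array shapes never exceed it

def random_binary_array_masked(key, n):
    # Hypothetical helper: always draws MAX_N values, so the shape stays static
    key1, key2 = jax.random.split(key)
    bits = jax.random.choice(key2, jnp.asarray([0, 1], dtype=jnp.int32), shape=(MAX_N,))
    # Entries at positions >= n are padding; the comparison works on a traced n
    mask = jnp.arange(MAX_N) < n
    return key1, jnp.where(mask, bits, 0), mask

def loop(i, s):
    n = i + 1  # illustrative assumption: n grows each iteration and is traced
    key, bits, mask = random_binary_array_masked(s["key"], n)
    return {"key": key, "bin_array": bits, "mask": mask}

init = {
    "key": jax.random.PRNGKey(0),
    "bin_array": jnp.zeros(MAX_N, dtype=jnp.int32),
    "mask": jnp.zeros(MAX_N, dtype=bool),
}
out = jax.lax.fori_loop(0, 5, loop, init)
print(out["bin_array"], out["mask"])
```

The trade-off is that downstream code has to respect `mask` (for example via `jnp.where` or masked sums) instead of relying on the array's length, and memory is always sized for `MAX_N`.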