static_argnums for lax.fori_loop #10516
I have the following function:

```python
from jax import jit, lax
import jax.numpy as jnp

def multiple_iter():
    def single_iter(i, data):
        data = jnp.pad(data, i)
        return data

    data = jnp.zeros((5, 5))
    lax.fori_loop(
        1, 5, single_iter, data
    )

multiple_iter()
```

This gives an error. I tried to use `jit` with `static_argnums`:

```python
from jax import jit, lax
import jax.numpy as jnp

def multiple_iter():
    def single_iter(i, data):
        data = jnp.pad(data, i)
        return data

    data = jnp.zeros((5, 5))
    # Attempt: mark the loop index `i` (argument 0) as static
    single_iter_jit = jit(single_iter, static_argnums=0)
    lax.fori_loop(
        1, 5, single_iter_jit, data
    )

multiple_iter()
```

After the modification, I still get an error.
Replies: 1 comment
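`fori_loop` traces the loop body with tracers, so `single_iter` is called with tracer arguments for both `i` and `data`. Wrapping it in `jit(..., static_argnums=0)` does not help: the `i` that `fori_loop` passes in is already a tracer rather than a concrete Python integer, and `jnp.pad` needs a concrete pad width. Moreover, `fori_loop` cannot change the shape of `val`: the loop carry must keep the same shape and dtype on every iteration, while `jnp.pad` makes the array larger each time. You should use a Python `for` loop in such a case.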
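A minimal sketch of the suggested fix, assuming the goal is simply to apply `jnp.pad` with pad widths 1 through 4 in turn. With a Python `for` loop, `i` is an ordinary `int`, so each `jnp.pad` call gets a concrete width; the loop still works under `jax.jit` because it is unrolled during tracing. The helper `fixed_shape_loop` is a hypothetical contrast case (not from the original thread) showing that `fori_loop` is fine when the traced `i` is only used as a value and the carry shape stays fixed.

```python
import jax
import jax.numpy as jnp
from jax import lax


def multiple_iter():
    data = jnp.zeros((5, 5))
    # Plain Python loop: `i` is an ordinary int, so jnp.pad gets a concrete
    # pad width and the array is allowed to grow on every iteration.
    for i in range(1, 5):
        data = jnp.pad(data, i)
    return data


# Under jax.jit the Python loop is unrolled while tracing, so every
# intermediate shape is known at trace time and growing shapes are fine.
multiple_iter_jit = jax.jit(multiple_iter)
result = multiple_iter_jit()
print(result.shape)  # (25, 25): 5 + 2 * (1 + 2 + 3 + 4) per axis


def fixed_shape_loop():
    # Hypothetical contrast case: fori_loop works when the body uses the
    # traced `i` only as a value and the carry keeps its shape and dtype.
    def body(i, data):
        return data + i

    return lax.fori_loop(1, 5, body, jnp.zeros((5, 5)))


fixed = fixed_shape_loop()
print(float(fixed[0, 0]))  # 10.0 (1 + 2 + 3 + 4)
```

The trade-off is that an unrolled Python loop grows the traced program with the number of iterations, which is acceptable here since there are only four.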