How to use jax.lax.fori_loop with a list argument? #12405
Answered by jakevdp · yiminghwang asked this question in Q&A
Hi there, when I try to use `jax.lax.fori_loop` with a list argument, it raises a `TypeError`. Here is an example:

```python
import jax

def _loop(x):
    # some process code
    return y

x = [1, 2, 3]
y = jax.lax.fori_loop(0, 10, _loop, x)
```

How can I use the list variable as the `init_val`?
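One likely source of the `TypeError` (an assumption, since the processing code is elided): `fori_loop` calls the body as `body_fun(i, val)`, passing the loop index first, so a body defined with a single parameter fails regardless of whether `init_val` is a list. A minimal sketch comparing the two signatures; `_loop_one_arg` and `_loop_two_args` are hypothetical stand-ins for the body in the question:

```python
import jax

def _loop_one_arg(x):
    # Hypothetical stand-in for the question's body: fori_loop will still call it with (i, x)
    return x

def _loop_two_args(i, x):
    # Same body, but it accepts the loop index that fori_loop supplies as the first argument
    return x

x = [1, 2, 3]

try:
    jax.lax.fori_loop(0, 10, _loop_one_arg, x)
    failed = False
except TypeError as e:
    # Python rejects the call because two positional arguments are passed to a one-argument function
    failed = True
    print("TypeError:", e)

# With the (i, x) signature, a list carry is accepted and the output is also a list
y = jax.lax.fori_loop(0, 10, _loop_two_args, x)
print(y)
```

The list itself is not the problem in this sketch; only the body's signature changes between the failing and the working call.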
Answered by jakevdp on Sep 19, 2022
The input and output of the loop body function must have the same structure, so you can pass a list to `fori_loop` as long as the body returns a list of the same length whose values keep the same types. For example:

```python
import jax

def _loop(i, x):
    # Add 1 to every element; the returned list has the same length and types as x
    y = [val + 1 for val in x]
    return y

x = [1, 2, 3]
y = jax.lax.fori_loop(0, 10, _loop, x)
print(jax.numpy.array(y))
# [11 12 13]
```

Answer selected by yiminghwang
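The "same structure" rule can be seen in a small sketch (the function names `grow` and `step` are illustrative, not from the thread). A body that changes the length of the list is rejected; in the JAX versions we are aware of, this mismatch surfaces as a `TypeError` raised while tracing. A carry whose elements have different shapes is fine, as long as each element keeps its own shape and dtype from one iteration to the next:

```python
import jax
import jax.numpy as jnp

def grow(i, x):
    # Returns a list one element longer than its input, so the carry structure changes
    return x + [0]

try:
    jax.lax.fori_loop(0, 3, grow, [1, 2, 3])
    grow_failed = False
except TypeError as e:
    grow_failed = True
    print("TypeError:", str(e).splitlines()[0])

def step(i, carry):
    # A scalar and a length-3 vector: different shapes, but each one is preserved
    total, vec = carry
    return [total + 1.0, vec * 2.0]

init = [jnp.float32(0.0), jnp.ones(3, dtype=jnp.float32)]
total, vec = jax.lax.fori_loop(0, 4, step, init)
# After 4 iterations: total == 4.0 and every entry of vec == 16.0
print(total, vec)
```

Here the index `i` is unused, but the body must still accept it because `fori_loop` always passes it as the first argument.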