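The loop body function receives the carry value and must return a carry with exactly the same structure: the same pytree structure (list, tuple, dict, ...) with leaves of matching shape and dtype. So you can pass a list to fori_loop, as long as the body returns a list of the same length whose elements keep their shapes and types. For example: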

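```python
import jax
import jax.numpy as jnp

def _loop(i, x):
  # x is the carry: a list of values. Return a list of the same length
  # whose elements keep the same shape and dtype as the input.
  y = [val + 1 for val in x]
  return y

x = [1, 2, 3]
y = jax.lax.fori_loop(0, 10, _loop, x)
print(jnp.array(y))
# [11 12 13]
```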
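The same rule applies to any pytree carry, not just a flat list. The sketch below (the names body, bad_body and structure_mismatch_rejected are illustrative, not part of the JAX API) uses a tuple holding a list of arrays plus a counter, and then shows that a body which changes the carry's structure, here by growing the list, is rejected at trace time. The sketch assumes JAX reports this mismatch as a TypeError, which is what current versions do.

```python
import jax
import jax.numpy as jnp

# Carry is a tuple: (list of two float32 arrays, int32 counter).
# The body must return the same structure with the same shapes and dtypes.
def body(i, carry):
    values, count = carry
    values = [v * 2 for v in values]  # shapes and dtypes unchanged
    return values, count + 1

init = ([jnp.float32(1.0), jnp.ones(3, dtype=jnp.float32)], jnp.int32(0))
values, count = jax.lax.fori_loop(0, 3, body, init)
print(values[0], values[1], count)
# 8.0 [8. 8. 8.] 3

# Hypothetical counter-example: appending to the list changes the carry's
# structure (length 1 in, length 2 out), which fori_loop refuses to trace.
def bad_body(i, x):
    return x + [i]

def structure_mismatch_rejected():
    try:
        jax.lax.fori_loop(0, 3, bad_body, [jnp.int32(0)])
    except TypeError:
        return True
    return False

rejected = structure_mismatch_rejected()
print("mismatch rejected:", rejected)
```

If you need to collect a growing sequence of results, a common alternative is to preallocate an array of the final size and fill it inside the loop, so the carry's shape never changes.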

Answer selected by yiminghwang
Category
Q&A