
There's no way to break out of a fori_loop early. If you need to stop based on some condition, use a while_loop instead and carry the loop counter in the loop state yourself. It might look something like this:

from jax import lax

def func(i, x):
  return x + i
print(lax.fori_loop(lower=0, upper=10, body_fun=func, init_val=0))
# 45

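def body_fun(carry):
  # The loop counter i now lives in the carry alongside x.
  i, x = carry
  return i + 1, func(i, x)

def cond_fun(carry):
  i, x = carry
  # Placeholder early-exit test; substitute your own condition.
  # (Because of i < 10 below, i == 20 is never reached in this example.)
  break_condition = (i == 20)
  # Keep looping while the break condition is False and i is still in range.
  return ~break_condition & (i < 10)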

print(lax.while_loop(cond_fun, body_fun, init_val=(0, 0))[1])
# 45
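The same pattern can be packaged as a small reusable helper. The sketch below assumes a hypothetical `fori_loop_with_break` function (not part of JAX) that mirrors `fori_loop`'s `(lower, upper, body_fun, init_val)` arguments and adds a `break_fun(i, val)` predicate, checked before each iteration. Unlike the placeholder above, the break condition here actually fires: the loop stops once the running total reaches 20.

```python
import jax.numpy as jnp
from jax import lax


def fori_loop_with_break(lower, upper, body_fun, init_val, break_fun):
    """Hypothetical helper (not part of JAX): run body_fun(i, val) for
    i in [lower, upper), stopping early once break_fun(i, val) is True.
    Returns (number_of_iterations_index, final_val)."""

    def cond_fun(carry):
        i, val = carry
        # Continue while i is in range and the break test has not fired.
        return (i < upper) & jnp.logical_not(break_fun(i, val))

    def body(carry):
        i, val = carry
        # Advance the counter and apply one step of the user's body.
        return i + 1, body_fun(i, val)

    return lax.while_loop(cond_fun, body, (lower, init_val))


def func(i, x):
    return x + i


# Stop as soon as the running total reaches at least 20.
n_steps, total = fori_loop_with_break(
    0, 10, func, 0, break_fun=lambda i, x: x >= 20
)
print(int(n_steps), int(total))  # 7 21
```

One design caveat: `lax.while_loop` is not reverse-mode differentiable, whereas `fori_loop` with static bounds is, so this swap trades `jax.grad` support for the ability to exit early.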

Answer selected by aespielberg
Category: Q&A. 6 participants.