Break a fori_loop #8375
Is there any way to break out of a `fori_loop` early? I understand why JAX would need to precompile a maximum number of loop iterations, but is there any way to have it run for only a subset of those loop iterations at runtime?
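There's not any way to break a `fori_loop`; if you need to break based on some condition, you can do so using a `while_loop`. It might look something like this:

```python
from jax import lax

def func(i, x):
    return x + i

# fori_loop always runs body_fun for every i in [lower, upper).
print(lax.fori_loop(lower=0, upper=10, body_fun=func, init_val=0))
# 45

# The equivalent while_loop carries the counter i explicitly,
# so cond_fun can stop the loop early.
def body_fun(carry):
    i, x = carry
    return i + 1, func(i, x)

def cond_fun(carry):
    i, x = carry
    break_condition = (i == 20)  # replace with your own early-exit test
    # Keep going only while we haven't "broken" and are still below the bound.
    return ~break_condition & (i < 10)

print(lax.while_loop(cond_fun, body_fun, init_val=(0, 0))[1])
# 45
```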
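For the early exit to actually fire, the break test just has to depend on values in the carry. Below is a minimal sketch, assuming the goal is to sum array elements until a running total passes a threshold; `sum_until_threshold` and its arguments are hypothetical names for illustration, not part of JAX. Keep in mind that `lax.while_loop` supports forward-mode differentiation but not reverse-mode (`jax.grad`).

```python
import jax
import jax.numpy as jnp
from jax import lax

def sum_until_threshold(xs, threshold):
    """Hypothetical example: add xs[0], xs[1], ... until the running total
    first exceeds `threshold` (the element that crosses it is included),
    or until the array runs out. Returns (steps_taken, total)."""
    n = xs.shape[0]  # static length, known at trace time

    def cond_fun(carry):
        i, total = carry
        # Continue while elements remain AND the "break" has not happened.
        return (i < n) & (total <= threshold)

    def body_fun(carry):
        i, total = carry
        return i + 1, total + xs[i]

    init = (jnp.int32(0), jnp.zeros((), xs.dtype))
    return lax.while_loop(cond_fun, body_fun, init)

steps, total = jax.jit(sum_until_threshold)(jnp.arange(1.0, 11.0), 10.0)
print(steps, total)
# 5 15.0
```

Here the number of iterations is decided at runtime by the data, so only as many steps as needed are executed, even under `jax.jit`.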
Sorry, I meant to post this as a proper answer, not just in the comments: [TL;DR links and explanation for an implementation of a reverse-mode differentiable bounded while loop.]
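One common way to get a bounded loop that still works with `jax.grad` is to run a `fori_loop` with a static (Python int) `max_steps`, which JAX lowers to `lax.scan`, and turn every iteration after the stop condition into a no-op. The sketch below assumes that approach; `bounded_while_loop` and `grow` are hypothetical helpers and not necessarily what the linked implementation does.

```python
import jax
from jax import lax

def bounded_while_loop(cond_fun, body_fun, init_val, max_steps):
    """Hypothetical helper: like lax.while_loop, but runs at most
    `max_steps` iterations. Because max_steps is a Python int, the
    fori_loop below lowers to lax.scan, which supports reverse-mode AD."""
    def step(_, val):
        # After cond_fun first returns False, the identity branch is taken
        # on every remaining step, so the carry stops changing.
        return lax.cond(cond_fun(val), body_fun, lambda v: v, val)
    return lax.fori_loop(0, max_steps, step, init_val)

def grow(x):
    # Double x until it reaches 10, capped at 100 iterations.
    return bounded_while_loop(lambda v: v < 10.0, lambda v: v * 2.0, x, max_steps=100)

print(grow(1.0))            # 16.0  (1 -> 2 -> 4 -> 8 -> 16)
print(jax.grad(grow)(1.0))  # 16.0  (four doublings, so d/dx = 2**4)
```

The trade-off is that all `max_steps` iterations always execute (skipped ones just take the identity branch), and under `jax.vmap` a `lax.cond` with a batched predicate is converted into a select that evaluates both branches.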