Hi, I have created a working `lax.while_loop` whose loop guard depends on a float and `<=`, such as this:

```python
step = 0
while .1 * step <= 10.:
    # Do something
    step += 1
```

Here it is:

```python
from jax import lax

def body(carry):
    # Do something
    return carry + 1

cond_fun = lambda carry: .1 * carry <= 10.
lax.while_loop(cond_fun, body, 0)
```

In an attempt to make it differentiable, is it possible to also convert this code to use `lax.scan`?
Replies: 1 comment 6 replies
The easiest way to convert it to a scan would probably be to use `lax.fori_loop`, which is implemented via `scan` when the start and end ranges are concrete. For your function, it might look like this:

```python
from jax import lax

lax.fori_loop(0, 101, lambda i, x: x + 1, 0)
```

In general, `scan` is only applicable when you know a priori how many loop iterations will be required, so it may not be possible depending on what your actual use-case is.
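If the iteration count really is fixed, the same loop can also be written with `lax.scan` directly. The following is only a minimal sketch: `NUM_STEPS`, `total`, and the `carry + 0.1 * x` update are hypothetical stand-ins for whatever "Do something" is in the real code, chosen so that the gradient is easy to check by hand.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Assumption: steps 0..100 satisfy 0.1 * step <= 10.0, so the trip count is 101.
NUM_STEPS = 101

def total(x):
    def scan_body(carry, _):
        # Hypothetical stand-in for "Do something": add 0.1 * x each step.
        return carry + 0.1 * x, None

    # xs=None with an explicit length runs the body a fixed number of times.
    final, _ = lax.scan(scan_body, jnp.zeros_like(x), xs=None, length=NUM_STEPS)
    return final

x0 = jnp.asarray(2.0)
value = total(x0)            # roughly 101 * 0.1 * 2.0
slope = jax.grad(total)(x0)  # roughly 101 * 0.1
```

Because the trip count is static here, reverse-mode differentiation with `jax.grad` works, which is not the case for `lax.while_loop`.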