The easiest way to convert it to a scan would probably be to use lax.fori_loop, which is implemented via scan when the start and end bounds are concrete (known at trace time). For your function, it might look like this:

```python
from jax import lax

# Adds 1 to the carry on each of 101 iterations (i = 0..100), starting from 0.
lax.fori_loop(0, 101, lambda i, x: x + 1, 0)
```
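
For comparison, the same loop written directly with lax.scan might look like the sketch below. The body function name is illustrative only. The key point is that scan takes the iteration count (length=101) up front, which is the same information fori_loop relies on when its bounds are concrete.

```python
from jax import lax

def body(carry, _):
    # Add 1 to the running carry; no per-step output is collected (None).
    return carry + 1, None

# No input sequence (xs=None), so the number of steps comes from length.
final, _ = lax.scan(body, 0, None, length=101)
# final holds the same value as the fori_loop version above.
```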

In general, scan is only applicable when you know a priori how many loop iterations will be required, so converting to a scan may not be possible, depending on your actual use case.
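
When the iteration count depends on runtime values, lax.while_loop is the usual alternative. The function below is a hypothetical example (not from the question): it doubles x until it exceeds a threshold and counts the steps, which a fixed-length scan cannot express directly.

```python
import jax
from jax import lax

@jax.jit
def steps_until_exceeds(x, threshold):
    # Hypothetical example: the number of iterations depends on the inputs.
    def cond(state):
        value, _ = state
        return value <= threshold

    def body(state):
        value, n = state
        return value * 2, n + 1  # double the value, count one more step

    _, n = lax.while_loop(cond, body, (x, 0))
    return n
```

For example, steps_until_exceeds(1.0, 100.0) takes 7 steps (1 → 128). Note that lax.while_loop supports forward-mode but not reverse-mode differentiation, which is one practical reason to prefer scan when the trip count is known.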

Answer selected by mjhoover1
Category
Q&A