I think the short answer is that lax.fori_loop is designed for truly sequential cases: each iterate depends on the previous one, and the step function is reasonably complex (think one step of gradient descent, or one action in a reinforcement-learning context). If jit were to statically unroll, say, 1000 steps of gradient descent, XLA would take a huge amount of time to compile the function. In the fori_loop context, XLA is encouraged (possibly required? Disclaimer: I'm not an XLA expert) to treat each iterate as its own black box, so it can't apply simplifications or speedups such as batching individual multiplies into a single faster matrix multiply.

In your…
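To make the trade-off concrete, here's a minimal sketch contrasting the two styles. The step function and constants are made up for illustration; the point is that the fori_loop version compiles the body once, while the Python-level loop is unrolled at trace time, so its compile time grows with the iteration count:

```python
import jax
import jax.numpy as jnp
from jax import lax

def gd_step(i, w):
    # One illustrative gradient-descent step on f(w) = 0.5 * ||w||^2,
    # whose gradient is simply w. Learning rate 0.1 is arbitrary.
    return w - 0.1 * w

@jax.jit
def train_fori(w0):
    # XLA compiles gd_step once and loops at runtime: compile time is
    # independent of the number of iterations, but each iterate is a
    # black box to the optimizer.
    return lax.fori_loop(0, 100, gd_step, w0)

@jax.jit
def train_unrolled(w0):
    # jit traces all 100 steps into one graph: compile time grows with
    # the iteration count, but XLA can optimize across steps.
    w = w0
    for i in range(100):
        w = gd_step(i, w)
    return w

w0 = jnp.ones(3)
print(train_fori(w0))      # same value as train_unrolled(w0)
```

Both functions compute the same result; the difference shows up in compilation time (try raising 100 to 10000 and timing the first call to each).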

Answer selected by gaspardbb