Thanks for the question!

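The problem here is that you cannot use JAX runtime conditionals such as jnp.where, lax.switch, or lax.cond to terminate a recursion. The reason is that JAX's tracing mechanism must abstractly evaluate the output of every branch (lax.cond and lax.switch trace all of their branches, and jnp.where needs both of its inputs computed), so the recursive branch is always traced, and there's no way for this abstract evaluation to terminate if the termination condition is dynamic.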
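A small sketch can make the "every branch is traced" point concrete. The function and branch names below (step, true_branch, false_branch) are hypothetical; the list append runs only while JAX traces the branches, so it records which branches were traced rather than which one actually executes at run time.

```python
import jax
from jax import lax

traced = []  # records which branch functions JAX traced

def step(x):
    def true_branch(x):
        traced.append("true")   # Python side effect: happens at trace time only
        return x + 1

    def false_branch(x):
        traced.append("false")  # also traced, even though x > 0 below
        return x - 1

    return lax.cond(x > 0, true_branch, false_branch, x)

result = jax.jit(step)(1.0)
print(sorted(traced))  # both branch names appear
print(result)
```

If false_branch instead called step again, tracing it would trace another lax.cond, which traces step again, and so on: the runtime value of x never gets a chance to stop the recursion.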

So when operate_helper_add1 recursively calls itself, it necessarily leads to infinite recursion at trace time (in practice this usually surfaces as a Python RecursionError). I would suggest finding a way to express your computation that does not involve recursion, perhaps using lax.while_loop instead.
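As a rough illustration of the lax.while_loop approach, here is a sketch built on an assumed behavior, since the original operate_helper_add1 is not shown: a hypothetical helper that keeps adding 1 to x until it reaches limit. The state that each recursive call would have passed along becomes the loop carry, and the dynamic base case becomes cond_fun.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical stand-in for a recursive "add 1 until done" helper.
@jax.jit
def add_until(x, limit):
    def cond_fun(carry):
        x, limit = carry
        return x < limit        # dynamic (runtime) termination condition

    def body_fun(carry):
        x, limit = carry
        return x + 1, limit     # one "recursive step"; carry types stay fixed

    x_final, _ = lax.while_loop(cond_fun, body_fun, (x, limit))
    return x_final

print(add_until(jnp.int32(3), jnp.int32(10)))
```

Because cond_fun and body_fun are each traced once, tracing terminates regardless of how many iterations run later; the trade-off is that the carry must keep the same shape and dtype on every iteration.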

If you want to use recursion in JAX, you can only do so using static (i.e. trace-time) conditionals. For example, something like this …
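One way recursion with a static conditional can look is sketched below. This is an illustrative assumption, not the example from the original reply: the hypothetical function add_n_times recurses on a plain Python int n, which is marked static, so the Python if is resolved at trace time and the recursion simply unrolls into n additions.

```python
import jax
import jax.numpy as jnp

def _add_n_times(x, n):
    # n is a Python int, so this `if` is a static, trace-time conditional
    if n == 0:
        return x
    return _add_n_times(x + 1, n - 1)  # recursion depth is known while tracing

# Mark n as static so it is not converted into a traced value.
add_n_times = jax.jit(_add_n_times, static_argnums=1)

print(add_n_times(jnp.int32(2), 5))
```

Each distinct value of n triggers a separate compilation, so this pattern suits small, fixed recursion depths; for data-dependent depths the lax.while_loop approach is the better fit.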

Replies: 2 comments 3 replies

@jakevdp
@dui1234
@jakevdp
Answer selected by dui1234
Category
Q&A
3 participants