I have a tree-like data structure and want to do some operations over the tree. However, when I use …
A part of the results from running the above code shows that even though the … Just wondering what is happening here.
-
```python
import jax.numpy as jnp
import numpy as np
from jax import jit

@jit
def f(x):
    print('evaluate here!')
    return x + 1

jnp.where(True, 0, f(1))
print('----')
np.where(True, 0, f(1))
```

Output is

For this case, you can use
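The snippet above relies on the fact that `jnp.where` (like `np.where`) is an ordinary function call, so Python computes every argument, including the branch that is not selected, before `jnp.where` runs. A minimal sketch of that behavior follows, using a hypothetical counting helper `branch_value`. It also shows `lax.cond`, which receives the branches as functions; treating `lax.cond` as the alternative is an assumption here, since the suggestion above is cut off.

```python
import jax.numpy as jnp
from jax import lax

calls = []

def branch_value(tag, x):
    # Hypothetical helper: records that Python evaluated it, then returns x + 1.
    calls.append(tag)
    return x + 1

x = jnp.asarray(1.0)

# jnp.where is an ordinary function call: Python evaluates every argument
# (including the unselected "false" value) before jnp.where runs.
where_result = jnp.where(True, x, branch_value("where_false_branch", x))

# lax.cond receives the branches as functions; the predicate selects
# which one produces the result.
cond_result = lax.cond(True, lambda v: v, lambda v: v + 1, x)

print(calls)                # ['where_false_branch']
print(float(where_result))  # 1.0
print(float(cond_result))   # 1.0
```

Note that `lax.cond` still traces both branch functions, so Python-side code such as `print` inside either branch can run at trace time. Only one branch is executed at run time.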
-
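Thanks for the question!

The problem here is that you cannot use JAX runtime conditionals like `jnp.where`, `lax.switch`, `lax.cond`, etc. in a recursive fashion. The reason is that JAX's tracing mechanism must abstractly evaluate the output, and there's no way for this abstract evaluation to terminate if the termination condition is dynamic.

So when `operate_helper_add1` recursively calls itself, it necessarily leads to an infinite loop. I would suggest finding a way to express your computation that does not involve recursion, perhaps using `lax.while_loop` instead.

If you want to use recursion in JAX, you can only do so using static (i.e. trace-time) conditionals. For example, something like this would work:

```python
from functools import partial

import jax

@partial(jax.jit, static_argnames=['i'])
def f(x, i):
    # i is static, so this Python `if` is resolved at trace time
    if i <= 0:
        return x
    return f(i * x, i - 1)

f(1.0, 10)
```

but something like this, though it looks equivalent, leads to an infinite loop during tracing:

```python
import jax.numpy as jnp
from jax import jit

@jit
def f(x, i):
    # i is traced, so the recursive call is always traced and never stops
    return jnp.where(i <= 0, x, f(i * x, i - 1))
```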
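Since the answer suggests `lax.while_loop`, here is a minimal sketch of how its recursive `f(x, i)` could be rewritten as a loop whose stopping test is evaluated at run time. The names `f_loop`, `cond_fun`, and `body_fun` are hypothetical, and the explicit float32/int32 casts are an assumption made so the loop carry keeps a fixed dtype. This does not reconstruct the question's `operate_helper_add1`, whose code is not shown.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def f_loop(x, i):
    # Iterative rewrite of f(x, i) = x if i <= 0 else f(i * x, i - 1).
    # The loop state is the pair (accumulator, counter); the stopping
    # test runs at run time, so i can be a traced (non-static) value.
    def cond_fun(state):
        _, n = state
        return n > 0

    def body_fun(state):
        acc, n = state
        return n.astype(acc.dtype) * acc, n - 1

    init = (jnp.asarray(x, dtype=jnp.float32), jnp.asarray(i, dtype=jnp.int32))
    acc_final, _ = lax.while_loop(cond_fun, body_fun, init)
    return acc_final

print(f_loop(1.0, 10))  # 3628800.0
print(f_loop(2.5, 0))   # 2.5
```

Because `i` is a traced argument here rather than a static one, calling `f_loop` with a different `i` reuses the same compiled function instead of retracing, unlike the `static_argnames` version.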