Hi all, what would be the best way to implement a JIT'able version of this example, where
This works, but takes a very long time to trace (for large
If you want to avoid long compile times, you should avoid using Python `for` loops: under `jit`, a Python loop is unrolled while tracing, so the traced program grows with the number of iterations. One strategy is to replace them with `lax.fori_loop`, but the loop length needs to be static for `fori_loop` to lower to `lax.scan` and support reverse-mode differentiation (with a traced length it falls back to `lax.while_loop`), so the size of the inner loop should not depend on the counter of the outer loop.

In this case, I'd probably avoid loops altogether and instead express your computation in a vectorized manner; then it will be both jit-compatible and performant. For example:

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=['n'])
def f(n, a):
  i = jnp.arange(n)[:, None]  # row indices, shape (n, 1)
  j = jnp.arange(n)[None, :]  # column indices, shape (1, n)
  x = jnp.zeros((n, n))
  x = x.at[i, j].set(jnp.exp(-(i - j)) * a[j])  # x[i, j] = exp(-(i - j)) * a[j]
  return jnp.where(i <= j, 0, x).sum(1)  # keep only terms with j < i, sum over j
```