If you want to avoid long compile times, you should avoid Python for loops: under jit they are unrolled during tracing, so the compiled program grows with the number of iterations. One strategy is to replace them with lax.fori_loop, but the length of such loops must be static (so the size of the inner loop could not depend on the counter of the outer loop).
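
One way the lax.fori_loop route might look, sketched under an assumption: that the nested loop computes y[i] = sum over j < i of exp(-(i - j)) * a[j], which is what the vectorized `f` in this answer computes. The outer loop becomes a lax.fori_loop whose trip count `n` is a static Python int, and the variable-length inner loop is replaced by a masked sum over all `n` values of `j`, so every iteration has the same shape. The name `f_fori` is hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


@partial(jax.jit, static_argnames=['n'])
def f_fori(n, a):
    # Fixed-length index vector standing in for the "inner loop" counter.
    j = jnp.arange(n)

    def body(i, y):
        # Inner loop over j < i, expressed as a masked sum of static length n.
        terms = jnp.where(j < i, jnp.exp(-(i - j)) * a, 0.0)
        return y.at[i].set(terms.sum())

    # Outer loop: n is a Python int, so the trip count is static.
    return lax.fori_loop(0, n, body, jnp.zeros(n, dtype=a.dtype))
```

Because the loop body is traced once rather than unrolled, compile time should not grow with `n`, although each iteration still does O(n) work.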

In this case, I'd probably avoid loops altogether and instead express your computation in a vectorized manner; then it will be both jit-compatible and performant. For example:

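from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=['n'])
def f(n, a):
  # Row indices as a column, column indices as a row: they broadcast to (n, n).
  i = jnp.arange(n)[:, None]
  j = jnp.arange(n)[None, :]
  # x[i, j] = exp(-(i - j)) * a[j] for every pair (i, j).
  x = jnp.exp(-(i - j)) * a[j]
  # Keep only the terms with j < i, then sum over j for each i.
  return jnp.where(i <= j, 0, x).sum(1)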
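
A possible variant, not part of the original answer: apply the mask to the weights before multiplying by `a`, then finish with a matrix-vector product. In the version above, entries with j > i evaluate exp(j - i), which overflows to inf in float32 once j - i reaches about 89; the forward result is still correct because jnp.where discards them, but a gradient with respect to `a` through those entries becomes 0 * inf = NaN. The name `f_matvec` is hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=['n'])
def f_matvec(n, a):
    i = jnp.arange(n)[:, None]
    j = jnp.arange(n)[None, :]
    # Strictly lower-triangular weights: w[i, j] = exp(-(i - j)) if j < i, else 0.
    # The mask is applied before `a` is involved, so any overflowed exp values
    # above the diagonal never enter the product or its gradient.
    w = jnp.where(i > j, jnp.exp(-(i - j)), 0.0)
    # y[i] = sum_j w[i, j] * a[j]
    return w @ a
```

For example, `jax.grad(lambda v: f_matvec(100, v).sum())(jnp.ones(100))` should be finite everywhere.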

Answer selected by rdaems