This wouldn't be straightforwardly expressible as a nested scan or fori_loop: scan needs a static length (and fori_loop only lowers to a reverse-differentiable scan when its bounds are static), whereas in your algorithm the limits of the second loop change from iteration to iteration. However, the innermost loop can be replaced by a vectorized expression:

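import jax.numpy as jnp

def quick_mean_jax(R, G, v):
  R, G, v = map(jnp.asarray, (R, G, v))
  N = R.shape[0]
  # F[c, 0] = 1 for every row; the remaining entries are filled in below
  F = jnp.zeros((N, N)).at[:, 0].set(1)
  for c in range(0, N):
    for t in range(1, c+1):
      # vectorized innermost loop: the sum over u = 1..t in a single expression
      u = jnp.arange(1, t + 1)
      F = F.at[c, t].set(1 + (R[c-t+u]*G[u]*v[u]*F[c, t-u]).sum())
  # only the diagonal entries F[c, c] are needed
  return jnp.diag(F)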
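A minimal sketch of one workaround for the changing limits, assuming the recurrence is exactly the one in quick_mean_jax above: run every loop over its largest possible range, which is static, and mask out the terms the original loops would skip. The name `quick_mean_masked` is hypothetical, and clipping the gather indices only keeps the masked-out reads in bounds. Because both `lax.fori_loop` calls receive Python-int bounds, JAX lowers them to `scan`, so tracing produces a fixed-size loop instead of unrolling O(N²) Python iterations, at the cost of extra arithmetic on the masked-out terms.

```python
import jax.numpy as jnp
from jax import lax

def quick_mean_masked(R, G, v):
  R, G, v = map(jnp.asarray, (R, G, v))
  N = R.shape[0]
  u = jnp.arange(1, N)  # fixed-length index vector u = 1 .. N-1 (largest range needed)

  def row(c, F):
    def inner(t, F):
      # keep only the terms the original loop would visit (u <= t)
      mask = u <= t
      # clip so that masked-out gathers stay in range; their values are discarded
      r_idx = jnp.clip(c - t + u, 0, N - 1)
      f_idx = jnp.clip(t - u, 0, N - 1)
      terms = jnp.where(mask, R[r_idx] * G[u] * v[u] * F[c, f_idx], 0.0)
      new = 1 + terms.sum()
      # only write for t <= c, matching the original triangular iteration space
      return F.at[c, t].set(jnp.where(t <= c, new, F[c, t]))
    # static bounds 1..N-1 for every row, instead of 1..c
    return lax.fori_loop(1, N, inner, F)

  F0 = jnp.zeros((N, N)).at[:, 0].set(1)
  F = lax.fori_loop(0, N, row, F0)
  return jnp.diag(F)
```

For example, `quick_mean_masked([1., 2., 3.], [1., 1., 1.], [.5, .5, .5])` should give the same diagonal as `quick_mean_jax` on the same inputs.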

To improve it further, I'd probably avoid scan and instead think about whether the whole operation could be vect…
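One possible fully vectorized formulation, sketched under the assumption that the recurrence is exactly the one in quick_mean_jax: substituting s = t - u rewrites each row as `F[c, t] = 1 + sum over s < t of R[c-s] * G[t-s] * v[t-s] * F[c, s]`, which is a lower-triangular linear system `(I - L_c) x = 1` with `x = F[c, :]`. Rows are independent, so each system can be solved with `jax.scipy.linalg.solve_triangular` and the rows batched with `jax.vmap`. The name `quick_mean_tri` is hypothetical, and zeroing the entries with s > c is an assumption made here to keep every gather in bounds; it leaves `x[c]` unchanged because forward substitution only uses rows 0..c to compute that entry.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

def quick_mean_tri(R, G, v):
  R, G, v = map(jnp.asarray, (R, G, v))
  N = R.shape[0]
  w = G[:N] * v[:N]            # w[k] = G[k] * v[k]
  t = jnp.arange(N)[:, None]   # row index of the per-c system
  s = jnp.arange(N)[None, :]   # column index of the per-c system

  def diag_entry(c):
    # L[t, s] = R[c - s] * w[t - s] for s < t; entries with s > c are zeroed
    # (they only affect rows t > c, which x[c] does not depend on)
    valid = (s < t) & (s <= c)
    vals = R[jnp.clip(c - s, 0, N - 1)] * w[jnp.clip(t - s, 0, N - 1)]
    L = jnp.where(valid, vals, 0.0)
    # forward substitution for (I - L) x = 1; x[s] equals F[c, s] for s <= c
    x = solve_triangular(jnp.eye(N) - L, jnp.ones(N), lower=True, unit_diagonal=True)
    return x[c]

  # one independent triangular solve per row c, batched with vmap
  return jax.vmap(diag_entry)(jnp.arange(N))
```

This removes all Python loops; the trade-off is memory, since `vmap` materializes N systems of size N×N (O(N³) floats), while the arithmetic stays at roughly O(N³) like the loop version.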

Answer selected by bhattsamir
Category: Q&A · 2 participants