A triple loop using fori or scan #11738
I'm sorry to be another person asking a "please help with a loop" question, but I'm making no progress, and this might take someone experienced only a few seconds to answer. I should also note that I'm new-ish to Python, so apologies in advance for poor code practice. The raw NumPy loop is this, for input vectors R, G, v. I'm happy to share a fully self-contained example, but really you can consider R, G and v to be any positive vectors of size N.
I can eliminate the inner loop via vectorisation, and in JAX do something like this, maybe?
Can I use `fori_loop` or `scan` to remove one of the other loops? I don't even know whether it's possible to run these loops efficiently in JAX. Finally, I should add that we do know there is a way to do this computation via matrices, and this is what we have done in the past (https://arxiv.org/abs/2107.05579, Algorithm 2), but for a few reasons I want to know whether it's possible to keep this as a loop. I really appreciate any advice or tips from those more experienced, and once again apologies for the trivial question.
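The asker's own NumPy snippet isn't reproduced in this thread. As a ground truth for checking faster versions, here is a plain three-loop NumPy transcription of the recurrence implemented in the reply's JAX code: `F[c, 0] = 1` and `F[c, t] = 1 + sum over u = 1..t of R[c-t+u] * G[u] * v[u] * F[c, t-u]`, returning the diagonal. This is a sketch, not the asker's code, and the function name `quick_mean_reference` is hypothetical.

```python
import numpy as np


def quick_mean_reference(R, G, v):
    """Hypothetical reference: explicit c, t, u loops for the recurrence in the reply."""
    R, G, v = (np.asarray(x, dtype=float) for x in (R, G, v))
    N = R.shape[0]
    F = np.zeros((N, N))
    F[:, 0] = 1.0  # first column is fixed to 1
    for c in range(N):  # outer loop: one row of F at a time
        for t in range(1, c + 1):  # middle loop: its length depends on c
            acc = 0.0
            for u in range(1, t + 1):  # inner loop: the sum that can be vectorised
                acc += R[c - t + u] * G[u] * v[u] * F[c, t - u]
            F[c, t] = 1.0 + acc
    return F.diagonal().copy()
```

With all-ones inputs every entry doubles along a row, so `quick_mean_reference(np.ones(5), np.ones(5), np.ones(5))` gives `[1, 2, 4, 8, 16]`, which is a quick sanity check for any rewrite.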
This wouldn't be straightforwardly expressible in terms of a nested `scan` or `fori_loop`, because these require static limits, and in your algorithm the limits of the second loop change from iteration to iteration. However, the inner loop can be computed in terms of a vectorized expression:

```python
import jax.numpy as jnp


def quick_mean_jax(R, G, v):
    R, G, v = map(jnp.asarray, (R, G, v))
    N = R.shape[0]
    # F[c, 0] = 1 for every row; the remaining entries are filled in below
    F = jnp.zeros((N, N)).at[:, 0].set(1)
    for c in range(0, N):
        for t in range(1, c + 1):
            # The innermost loop over u = 1..t, replaced by one vectorized sum
            u = jnp.arange(1, t + 1)
            F = F.at[c, t].set(1 + (R[c - t + u] * G[u] * v[u] * F[c, t - u]).sum())
    return jnp.diag(F)
```

To improve it further, I'd probably avoid `scan` and instead think about whether the whole operation could be vectorized.
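One possible way to keep a loop while giving JAX the fixed shapes it needs (a sketch, not something proposed in the thread): in the recurrence above, `F[c, t]` only reads entries from the same row `c`, so each row can be computed independently. Padding the inner sum to length `N` and masking out `u` outside `1..t` removes the data-dependent slice, so the middle loop can become a `jax.lax.fori_loop` with static bounds, and the outer loop a `jax.vmap` over rows. The names `_diag_entry` and `quick_mean_masked` are hypothetical.

```python
import jax
import jax.numpy as jnp


def _diag_entry(c, R, G, v):
    """Hypothetical helper: F[c, c] for one row c, using only fixed-size arrays."""
    N = R.shape[0]
    u = jnp.arange(N)  # every candidate inner index; the mask keeps 1 <= u <= t

    def body(t, row):
        mask = (u >= 1) & (u <= t)
        # Clip indices so the gathers stay in bounds; clipped terms are masked out.
        r_idx = jnp.clip(c - t + u, 0, N - 1)
        f_idx = jnp.clip(t - u, 0, N - 1)
        terms = jnp.where(mask, R[r_idx] * G * v * row[f_idx], 0.0)
        return row.at[t].set(1.0 + terms.sum())

    row0 = jnp.zeros(N, dtype=R.dtype).at[0].set(1.0)
    # Static bounds 1..N: iterations with t > c only write entries that row[c] never reads.
    row = jax.lax.fori_loop(1, N, body, row0)
    return row[c]


@jax.jit
def _quick_mean_masked(R, G, v):
    rows = jnp.arange(R.shape[0])
    # vmap replaces the outer Python loop over c
    return jax.vmap(_diag_entry, in_axes=(0, None, None, None))(rows, R, G, v)


def quick_mean_masked(R, G, v):
    R, G, v = (jnp.asarray(x, dtype=jnp.float32) for x in (R, G, v))
    return _quick_mean_masked(R, G, v)
```

The trade-off is extra arithmetic (the masked-out terms and the `t > c` iterations are still computed), in exchange for a traced program whose size does not grow with `N`, unlike the Python double loop, which unrolls into O(N²) operations. Because the `fori_loop` bounds are Python integers, JAX lowers it to `scan`, which keeps it reverse-mode differentiable. On small inputs it can be checked against `quick_mean_jax` above with `jnp.allclose`.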