Best way to do a “shifted” matrix multiplication #21072
marcofrancis asked this question in Q&A
-
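I suspect the best way to accomplish this is to first shift the 2D matrix, then perform the full reduction with a single einsum:
import jax.numpy as jnp
T, J = g.shape  # g has shape (T, J); f has shape (N, T, J)
t = jnp.arange(T)[:, None]  # row indices as a column, shape (T, 1)
a = jnp.arange(J)  # column indices, shape (J,)
g_shifted = g[jnp.maximum(t - a, 0), a]  # g[max(t - a, 0), a], shape (T, J)
s = jnp.einsum("nta,ta->t", f, g_shifted)  # sum over n and a, shape (T,)
Compared to the original solution, this reduces the number of index operations into |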
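As a sanity check, the snippet above can be wrapped in a function and compared against explicit loops. This is only a minimal sketch: the function names `shifted_contract` and `shifted_contract_loops` and the small test arrays are made up for illustration, and the shapes are assumed to be (N, T, J) for f and (T, J) for g as described in the question.

```python
import jax.numpy as jnp


def shifted_contract(f, g):
    """Return s[t] = sum over n, a of f[n, t, a] * g[max(t - a, 0), a].

    Hypothetical helper name; f has shape (N, T, J), g has shape (T, J).
    """
    T, J = g.shape
    t = jnp.arange(T)[:, None]               # (T, 1) row indices
    a = jnp.arange(J)                        # (J,) column indices
    g_shifted = g[jnp.maximum(t - a, 0), a]  # (T, J), row index clamped at 0
    return jnp.einsum("nta,ta->t", f, g_shifted)


def shifted_contract_loops(f, g):
    """Slow reference with explicit Python loops, for checking only."""
    N, T, J = f.shape
    out = []
    for t in range(T):
        total = 0.0
        for n in range(N):
            for a in range(J):
                total += float(f[n, t, a]) * float(g[max(t - a, 0), a])
        out.append(total)
    return jnp.array(out)


# Small made-up inputs, chosen only to exercise the clamped indices.
f = jnp.arange(2 * 3 * 4, dtype=jnp.float32).reshape(2, 3, 4)
g = jnp.arange(3 * 4, dtype=jnp.float32).reshape(3, 4) + 1.0
s = shifted_contract(f, g)
s_ref = shifted_contract_loops(f, g)
```

The gather that builds `g_shifted` touches only T*J entries, independent of N, which is presumably where the saving over indexing every (n, t, a) combination comes from.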
-
I have two arrays, say f and g: f is N by T by J dimensional and g is T by J dimensional. I’m trying to compute the following in JAX (for all 0<=t<T):
Note that if t-a<0, I’d like the index t-a to default to 0.
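Read back from the einsum in the reply, the quantity in question is presumably the one below. Summing over n as well as a is an assumption taken from that code ("nta,ta->t"), not something stated in the question text, and the name s is likewise borrowed from the reply.

```latex
s_t = \sum_{n=0}^{N-1} \sum_{a=0}^{J-1} f_{n,t,a}\, g_{\max(t-a,\,0),\,a},
\qquad 0 \le t < T
```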
What would be the fastest approach?
Right now I create a list of all possible indexes, multiply the two arrays elementwise at the relevant indexes, and sum the results:
This seems neither particularly efficient nor elegant, and I would appreciate a better solution.
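The code for this index-list approach is not shown. Purely as an illustrative guess at what it might look like (the function name `shifted_contract_indexed` and the test arrays are hypothetical, and the sum over n is assumed from the reply's einsum), one could enumerate every (n, t, a) triple, gather from both arrays, and accumulate per t:

```python
import jax.numpy as jnp


def shifted_contract_indexed(f, g):
    """Index-list version: gather every (n, t, a) term, then sum per t.

    Hypothetical helper name; f has shape (N, T, J), g has shape (T, J).
    """
    N, T, J = f.shape
    # All (n, t, a) index triples as three flat integer arrays of length N*T*J.
    n_idx, t_idx, a_idx = jnp.indices((N, T, J)).reshape(3, -1)
    # One gather into f and one into g per triple; row index of g clamped at 0.
    products = f[n_idx, t_idx, a_idx] * g[jnp.maximum(t_idx - a_idx, 0), a_idx]
    # Scatter-add each product into the output slot for its t.
    return jnp.zeros(T, dtype=products.dtype).at[t_idx].add(products)


# Small made-up inputs, for illustration only.
f = jnp.arange(2 * 3 * 4, dtype=jnp.float32).reshape(2, 3, 4)
g = jnp.arange(3 * 4, dtype=jnp.float32).reshape(3, 4) + 1.0
s = shifted_contract_indexed(f, g)
```

This version performs N*T*J gathers into g, which is the cost a shift-then-contract formulation would avoid.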