I'd like to perform a bunch of einsums for arrays with different shapes. I tried to use `scan` and got an error. Am I doing it wrong, or is there a workaround? Thank you in advance.
Hi - thanks for the question. `scan` can only be used to scan over JAX arrays, not Python lists. Since JAX does not support ragged arrays, your use case cannot be implemented via `scan`.

For what you have in mind, I'd suggest a simple Python loop. For example:

```python
import jax
from jax import numpy as jnp

# Two lists of arrays whose last dimension differs per entry (2 and 3).
aa = [jnp.zeros((5, 5, 2)), jnp.zeros((5, 5, 3))]
bb = [jnp.zeros((2,)), jnp.zeros((3,))]

# One einsum per (a, b) pair; each result has shape (5, 5).
result = [jnp.einsum('ijk,k->ij', a, b) for a, b in zip(aa, bb)]
```

This will be compatible with both forward and reverse autodiff. If your lists are large, then wrapping the outer function in `jit` will likely lead to long compile times, because the `for` loop is flattened into a long sequence …

Alternatively, you may be able to pad your values so that they can all be stored in a single array and processed together.

Hope that helps!
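As a rough illustration of the padding idea mentioned in the reply, here is a minimal sketch. It assumes the arrays differ only along the contracted axis `k`, so zero-padding that axis leaves each einsum result unchanged (the extra zeros contribute nothing to the sum). The helper name `pad_last`, the example values, and the choice of `jax.vmap` for batching are assumptions made for illustration, not something stated in the thread.

```python
import jax
from jax import numpy as jnp

# Illustrative inputs: the last axis has length 2 for the first pair, 3 for the second.
aa = [jnp.ones((5, 5, 2)), 2.0 * jnp.ones((5, 5, 3))]
bb = [jnp.arange(2.0), jnp.arange(3.0)]

# Largest size of the contracted axis across all inputs.
k_max = max(a.shape[-1] for a in aa)

def pad_last(x, size):
    """Zero-pad the last axis of x up to `size` (hypothetical helper)."""
    pad_width = [(0, 0)] * (x.ndim - 1) + [(0, size - x.shape[-1])]
    return jnp.pad(x, pad_width)

# Stack the padded arrays into single, rectangular JAX arrays.
a_stacked = jnp.stack([pad_last(a, k_max) for a in aa])  # shape (2, 5, 5, 3)
b_stacked = jnp.stack([pad_last(b, k_max) for b in bb])  # shape (2, 3)

# Map the einsum over the leading (batch) axis of both arrays.
batched = jax.vmap(lambda a, b: jnp.einsum('ijk,k->ij', a, b))(a_stacked, b_stacked)
```

Because the inputs are now ordinary rectangular arrays, the computation no longer depends on Python-level iteration over a list. The trade-off is wasted work and memory on the padded entries, which may matter if the sizes vary widely.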