Specifying double indices for vmap #8354
-
Hey! I have a function that I want to vectorize using
Now if I wanna rewrite this using
Here, I'm wondering how I should replace the question mark in the
-
I found a way:

```python
import jax.numpy as jnp
from jax import random, vmap

N = 100
M = 50
key1, key2 = random.split(random.PRNGKey(123))
X = random.normal(key1, (N, M))
Y = random.normal(key2, (N, N, M))

def loop_compute(X, Y, N):
    # Reference loop: z[i] = dot(X[i], Y[i, i])
    Z = []
    for i in range(N):
        Z.append(jnp.dot(X[i:i+1], Y[i:i+1, i:i+1].squeeze()))
    return jnp.stack(Z).squeeze()

def compute(x, y):
    return jnp.dot(x, y)

z_loop = loop_compute(X, Y, N)
# The nested vmap builds an (N, N) array; its diagonal holds the values we want
z_vector = jnp.diag(vmap(vmap(compute), in_axes=(None, 0))(X, Y))
print(jnp.allclose(z_vector, z_loop, atol=1e-5))  # True
```

But I'm not sure whether this is really efficient or the best way to do it, as we're imposing a new …
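A small sketch that makes the efficiency concern concrete. The sizes are arbitrary, and `full` and `reference` are names introduced here for illustration. The outer `vmap` broadcasts `X` and maps over the first axis of `Y`, while the inner `vmap` pairs the rows of `X` with the rows of each `Y[j]`. So the nested call computes all N × N dot products, and `jnp.diag` then keeps only N of them:

```python
import jax.numpy as jnp
from jax import random, vmap

N, M = 4, 3  # small sizes, chosen only for illustration
key1, key2 = random.split(random.PRNGKey(0))
X = random.normal(key1, (N, M))
Y = random.normal(key2, (N, N, M))

def compute(x, y):
    return jnp.dot(x, y)

# Outer vmap: X is not mapped (None), Y is mapped over axis 0, so each y_j has shape (N, M).
# Inner vmap: maps compute over the rows of X and y_j together.
full = vmap(vmap(compute), in_axes=(None, 0))(X, Y)
print(full.shape)  # (4, 4), i.e. (N, N); full[j, i] == dot(X[i], Y[j, i])

# Only full[i, i] == dot(X[i], Y[i, i]) is kept; the other N * (N - 1)
# entries are computed and then thrown away.
z_vector = jnp.diag(full)

# Cross-check by gathering the diagonal pairs Y[i, i] directly.
reference = jnp.einsum('ij,ij->i', X, Y[jnp.arange(N), jnp.arange(N)])
print(jnp.allclose(z_vector, reference, atol=1e-5))
```

The wasted work grows quadratically in N, which is why pulling out the diagonal first (as in the reply below) is preferable.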
-
There's no way to use `vmap` directly to map along two axes of an array, but JAX does offer a straightforward way to extract the diagonal of an array. I believe this function should do what you want relatively efficiently:

```python
def vector_compute(X, Y):
    # Y.diagonal() has shape (M, N); transposing gives (N, M) with row i == Y[i, i]
    return vmap(jnp.dot)(X, Y.diagonal().T)

assert jnp.allclose(vector_compute(X, Y), loop_compute(X, Y, N))
```

For this particular operation (dot product along a diagonal), you could also compute it using a single `einsum` call:

```python
# 'iij' reads Y along its diagonal in the first two axes
assert jnp.allclose(jnp.einsum('ij,iij->i', X, Y), loop_compute(X, Y, N))
```
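If you literally want to give `vmap` a "double index", one workaround (a suggestion beyond what the thread shows) is to map over a single index array and use that index for both axes inside the function. This is a sketch assuming the same shapes as above, `X` of shape (N, M) and `Y` of shape (N, N, M); the helper `pick_and_dot` is hypothetical. It also shows why the `.T` is needed after `Y.diagonal()`: by default the diagonal becomes the last axis.

```python
import jax.numpy as jnp
from jax import random, vmap

N, M = 5, 3  # illustrative sizes
key1, key2 = random.split(random.PRNGKey(0))
X = random.normal(key1, (N, M))
Y = random.normal(key2, (N, N, M))

def pick_and_dot(i):
    # Hypothetical helper: the same index i selects X's row and both leading axes of Y.
    return jnp.dot(X[i], Y[i, i])

# vmap over one index vector; each i pairs X[i] with Y[i, i].
z_index = vmap(pick_and_dot)(jnp.arange(N))

# Y.diagonal() uses axis1=0, axis2=1 by default and appends the diagonal as the
# LAST axis, so its shape is (M, N); .T turns it into (N, M) with row i == Y[i, i].
diag = Y.diagonal()
print(diag.shape)  # (3, 5), i.e. (M, N)
z_diag = vmap(jnp.dot)(X, diag.T)

print(jnp.allclose(z_index, z_diag, atol=1e-5))
```

Both versions only touch the N diagonal pairs, unlike the nested `vmap` + `diag` approach, which computes all N × N products first.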