There's no way to use vmap directly to map along two axes of an array simultaneously, but JAX does offer a straightforward way to extract the diagonal of an array. I believe this function should do what you want relatively efficiently:

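import jax.numpy as jnp
from jax import vmap
def vector_compute(X, Y):
  # Y.diagonal() puts the diagonal axis last, so transpose it back to the front
  return vmap(jnp.dot)(X, Y.diagonal().T)
assert jnp.allclose(vector_compute(X, Y), loop_compute(X, Y, N))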
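To see why the transpose is there, it may help to look at the shapes. This is only a sketch: it assumes X has shape (N, M) and Y has shape (N, N, M), and the sizes and random inputs are made up for illustration rather than taken from the question.

```python
import jax.numpy as jnp
from jax import random, vmap

N, M = 4, 3  # illustrative sizes (assumption, not from the question)
kx, ky = random.split(random.PRNGKey(0))
X = random.normal(kx, (N, M))
Y = random.normal(ky, (N, N, M))

# diagonal() pairs up axes 0 and 1 and appends the diagonal as the LAST axis
D = Y.diagonal()
print(D.shape)    # (3, 4), i.e. (M, N)
print(D.T.shape)  # (4, 3), i.e. (N, M); row i is Y[i, i, :]

# vmap now maps over the shared leading axis of X and D.T
out = vmap(jnp.dot)(X, D.T)
print(out.shape)  # (4,)
```

After the transpose, both arguments have the same leading axis, which is what vmap maps over by default.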

For this particular operation (dot product along a diagonal), you could also compute it using a single einsum call:

assert jnp.allclose(jnp.einsum('ij,iij->i', X, Y), loop_compute(X, Y, N))
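For anyone who wants to check the two approaches without the code from the question, here is a self-contained sketch. The loop_compute below is a hypothetical stand-in (it assumes the original loop computes the dot product of X[i] with Y[i, i] for each i), and the input shapes and values are invented for the example.

```python
import jax.numpy as jnp
from jax import vmap

def loop_compute(X, Y, N):
    # Hypothetical stand-in for the question's loop: out[i] = X[i] . Y[i, i]
    return jnp.stack([jnp.dot(X[i], Y[i, i]) for i in range(N)])

N, M = 5, 3  # illustrative sizes (assumption)
# Small integer-valued floats keep the comparison free of rounding noise
X = jnp.arange(N * M, dtype=jnp.float32).reshape(N, M)
Y = jnp.arange(N * N * M, dtype=jnp.float32).reshape(N, N, M)

expected = loop_compute(X, Y, N)
via_vmap = vmap(jnp.dot)(X, Y.diagonal().T)
via_einsum = jnp.einsum('ij,iij->i', X, Y)

assert jnp.allclose(via_vmap, expected)
assert jnp.allclose(via_einsum, expected)
```

In the einsum string, the repeated i in 'iij' selects the diagonal of the first two axes of Y, and summing over j gives the dot product.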

Answer selected by mohamad-amin