
I don't know of any built-in operation that allows tensor dot products with dynamic axes. Keep in mind that because of JAX's static shape requirements, this would only be possible in cases where the axes being selected all have the same size, so that the output will have the same shape regardless of the dynamic value.
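
To make the shape restriction concrete, here is a minimal sketch (the function name sum_over is hypothetical, and it uses a reduction rather than a dot product for brevity). When the axis is a traced value under jit, jax.lax.cond traces both branches and requires their outputs to have identical shapes and dtypes, so a square input works while a non-square one is rejected at trace time.

```python
import jax
import jax.numpy as jnp

@jax.jit
def sum_over(x, axis):
    # `axis` is traced under jit, so lax.cond must trace BOTH branches and
    # requires their outputs to have identical shapes and dtypes.
    return jax.lax.cond(
        axis == 0,
        lambda x: jnp.sum(x, axis=0),  # reduces axis 0 -> shape (x.shape[1],)
        lambda x: jnp.sum(x, axis=1),  # reduces axis 1 -> shape (x.shape[0],)
        x,
    )

# Square input: both branches give shape (3,), so this compiles.
square = jnp.arange(9.0).reshape(3, 3)
print(sum_over(square, 0))  # [ 9. 12. 15.]
print(sum_over(square, 1))  # [ 3. 12. 21.]

# Non-square input: branch outputs would be (3,) vs (2,), so tracing fails.
rect = jnp.arange(6.0).reshape(2, 3)
try:
    sum_over(rect, 0)
    rect_rejected = False
except TypeError as err:
    rect_rejected = True
    print("rejected:", type(err).__name__)
```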

Your best bet would probably be to do it manually using jax.lax.cond or similar; for example:

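```python
import jax
import jax.numpy as jnp

@jax.jit
def dynamic_dot(x, y, axis):
  # Contract y against axis 0 of x (i.e. x.T @ y).
  def f1(x, y):
    return x.T @ y
  # Contract y against axis 1 of x (i.e. x @ y).
  def f2(x, y):
    return x @ y
  # `axis` is traced under jit, so the branch is chosen at runtime by lax.cond.
  # Any value other than 0 falls through to f2.
  return jax.lax.cond(axis == 0, f1, f2, x, y)

x = jnp.arange(9).reshape(3, 3)
y = jnp.ones(3)

print(dynamic_dot(x, y, axis=0))  # [ 9. 12. 15.]
print(dynamic_dot(x, y, axis=1))  # [ 3. 12. 21.]
```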
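
The same idea extends to more than two choices with jax.lax.switch, which takes a list of branches and an integer index. The following is a minimal sketch, not a built-in: dynamic_tensordot and make_branch are hypothetical names, and it assumes every axis of x has the same length as y so that all branches return the same shape.

```python
import jax
import jax.numpy as jnp

def make_branch(k):
    # Build a branch that contracts axis k of x with the single axis of y.
    # A factory function avoids Python's late binding of `k` in a loop.
    return lambda x, y: jnp.tensordot(x, y, axes=([k], [0]))

@jax.jit
def dynamic_tensordot(x, y, axis):
    # x.ndim is static at trace time, so we can build one branch per axis.
    branches = [make_branch(k) for k in range(x.ndim)]
    # lax.switch clamps `axis` into [0, len(branches) - 1].
    return jax.lax.switch(axis, branches, x, y)

x = jnp.arange(27.0).reshape(3, 3, 3)  # all axes have length 3
y = jnp.array([1.0, 2.0, 3.0])

print(dynamic_tensordot(x, y, 1).shape)  # (3, 3)
```

Because the index is clamped rather than checked, an out-of-range axis silently selects the first or last branch; validate it beforehand if that matters for your use case.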

Answer selected by erick-xanadu