I presume this has been discussed before, but I could not find any reference, so I thought I'd bring it up as a question. I think in principle …

```python
import jax
from jax import numpy as jnp

a = jax.random.normal(jax.random.key(9876), shape=(5, 4, 10))
b = jax.random.normal(jax.random.key(976), shape=(1, 10, 4))

# batched matmul
print(jnp.matmul(a, b).shape)  # (5, 4, 4)

# batched via vmap and repeating b
print(jax.vmap(jnp.dot)(a, jnp.repeat(b, 5, axis=0)).shape)  # (5, 4, 4)

# why not directly support broadcasting along the mapped axes?
jax.vmap(jnp.dot)(a, b)  # raises an error: the mapped axes have different sizes (5 vs. 1)
```

I can see that …
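When `b` has a size-1 leading axis, as above, one explicit way to get the same result without materializing five copies via `jnp.repeat` is to drop that axis and mark `b` as unmapped with `in_axes=(0, None)`. This is a minimal sketch that assumes `b.shape[0] == 1`; it is not a general replacement for broadcasting.

```python
import jax
import jax.numpy as jnp

a = jax.random.normal(jax.random.key(9876), shape=(5, 4, 10))
b = jax.random.normal(jax.random.key(976), shape=(1, 10, 4))

# Map over axis 0 of `a`, but pass the same (10, 4) matrix to every call:
# None in in_axes means "do not map this argument".
batched_dot = jax.vmap(jnp.dot, in_axes=(0, None))
result = batched_dot(a, b[0])  # b[0] drops the size-1 leading axis
print(result.shape)  # (5, 4, 4)

# Should agree with the broadcasting jnp.matmul from the question.
print(jnp.allclose(result, jnp.matmul(a, b), atol=1e-4))
```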
It's an interesting idea. Fundamentally, though, broadcasting is a NumPy concept. It's supported in JAX within the `jax.numpy` layer, but is not generally supported in `jax.lax` or other JAX APIs. Implicit broadcasting is sometimes convenient, but for tools like `vmap` we think it's better and less error-prone to be explicit.

If you want something similar to `vmap` that is broadcast-aware, you can use `jnp.vectorize`, which is implemented via `vmap` and does support broadcasting. In your example it would look something like this:

```python
import jax.numpy as jnp

# a has shape (5, 4, 10) and b has shape (1, 10, 4), as in the question
result = jnp.vectorize(jnp.dot, signature="(a,b),(b,c)->(a,c)")(a, b)
print(result.shape)  # (5, 4, 4)
```
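The `signature` argument splits each input into core dimensions (the trailing axes named in the signature) and loop dimensions (everything in front), and only the loop dimensions are broadcast NumPy-style. The sketch below uses hypothetical shapes, chosen only for illustration, where loop dimensions `(3, 1)` and `(5,)` broadcast to `(3, 5)`, and checks the result against `jnp.matmul`, which applies the same broadcasting rules to its leading axes.

```python
import jax
import jax.numpy as jnp

# Hypothetical inputs: loop dims (3, 1) and (5,), core dims (4, 10) and (10, 4).
x = jax.random.normal(jax.random.key(0), shape=(3, 1, 4, 10))
y = jax.random.normal(jax.random.key(1), shape=(5, 10, 4))

# Core dims follow "(a,b),(b,c)->(a,c)"; loop dims broadcast to (3, 5).
batched_dot = jnp.vectorize(jnp.dot, signature="(a,b),(b,c)->(a,c)")
out = batched_dot(x, y)
print(out.shape)  # (3, 5, 4, 4)

# jnp.matmul broadcasts its leading (batch) axes the same way.
print(jnp.allclose(out, jnp.matmul(x, y), atol=1e-4))
```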