
It's an interesting idea. Fundamentally, though, broadcasting is a NumPy concept: JAX supports it within the jax.numpy layer, but not in general in jax.lax or other JAX APIs. Implicit broadcasting is sometimes convenient, but for transformations like vmap we think it's better and less error-prone to be explicit.
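To make the distinction concrete, here is a minimal sketch (the shapes are hypothetical, chosen only for illustration). A jax.numpy function such as jnp.matmul broadcasts a missing batch dimension automatically, whereas jax.vmap has to be told explicitly, via in_axes=None, that an argument is not batched, and it raises an error instead of broadcasting mismatched batch sizes.

```python
import jax
import jax.numpy as jnp

x = jnp.ones((5, 4, 3))  # hypothetical: a batch of five 4x3 matrices
y = jnp.ones((3, 4))     # hypothetical: one 3x4 matrix, no batch dimension

# jax.numpy follows NumPy broadcasting rules, so matmul broadcasts y over x's batch.
broadcast_result = jnp.matmul(x, y)
print(broadcast_result.shape)  # (5, 4, 4)

# vmap maps over axis 0 of every argument by default, so y must be
# marked as unbatched explicitly with in_axes=None.
vmap_result = jax.vmap(jnp.dot, in_axes=(0, None))(x, y)
print(vmap_result.shape)  # (5, 4, 4)

# vmap does not broadcast a size-1 batch axis against a size-5 one; it raises.
z = jnp.ones((1, 3, 4))
try:
    jax.vmap(jnp.dot)(x, z)
    mismatch_raised = False
except ValueError:
    mismatch_raised = True
print(mismatch_raised)  # True
```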

If you want something similar to vmap that is broadcast-aware, you can use jnp.vectorize, which is implemented via vmap and does support broadcasting. In your example it would look something like this:

import jax.numpy as jnp

# a and b are the arrays from your example
result = jnp.vectorize(jnp.dot, signature="(a,b),(b,c)->(a,c)")(a, b)
print(result.shape)  # (5, 4, 4)
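As a further sketch (again with hypothetical shapes), the leading "loop" dimensions that jnp.vectorize adds around the core signature broadcast NumPy-style, so a (5, 1) batch combines with a (6,) batch to give (5, 6). Reproducing that with vmap takes two explicit, nested calls, each stating which argument it maps over:

```python
import jax
import jax.numpy as jnp

# Matrix multiply over the two trailing ("core") dims; leading dims are loop dims.
matmul = jnp.vectorize(jnp.dot, signature="(a,b),(b,c)->(a,c)")

x = jnp.ones((5, 1, 4, 3))  # hypothetical input, loop dims (5, 1)
w = jnp.ones((6, 3, 4))     # hypothetical input, loop dims (6,)

# Loop dims broadcast: (5, 1) with (6,) -> (5, 6).
out = matmul(x, w)
print(out.shape)  # (5, 6, 4, 4)

# The same computation written with explicit vmap calls:
inner = jax.vmap(jnp.dot, in_axes=(None, 0))  # (4,3), (6,3,4) -> (6,4,4)
outer = jax.vmap(inner, in_axes=(0, None))    # (5,4,3), (6,3,4) -> (5,6,4,4)
explicit = outer(x[:, 0], w)                  # drop x's size-1 axis by hand
print(explicit.shape)  # (5, 6, 4, 4)
```

The explicit version makes every mapped axis visible at the call site, which is the trade-off described above: more verbose, but a shape mismatch cannot silently turn into a broadcast.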

@jakevdp
@adonath
Answer selected by adonath
Category
Ideas
2 participants