vmap over axis with different sizes #12207
-
Hello! I have a function that receives as input a tuple of two arrays.
When I apply vmap to it, I get the following error, which I don't know how to solve.
How can I specify that I want to iterate over the dimension of size 20 and not the one of size 100? I cannot pad the inputs to make them the same size.
Thanks!
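For reference, here is a minimal sketch of the kind of failure described. It assumes the two arrays have shapes (100, 25) and (20, 50), borrowed from the second reply (the original shapes and error text are not shown), and it uses a hypothetical `my_function` body. With the default `in_axes=0`, `vmap` tries to map axis 0 of both arrays and rejects the mismatched sizes:

```python
import jax
import jax.numpy as jnp

def my_function(params):
    # Hypothetical body: unpack the tuple and combine the two arrays.
    x, y = params
    return x.sum() * y.sum()

# Assumed shapes: leading axes of size 100 and 20.
params = (jnp.ones((100, 25)), jnp.ones((20, 50)))

# The default in_axes=0 applies to every array in the tuple, so vmap sees
# one batch of size 100 and another of size 20 and raises a ValueError.
try:
    jax.vmap(my_function)(params)
    error = None
except ValueError as e:
    error = e

print(type(error).__name__)  # ValueError
```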
-
Try the following, as `my_function` is unary: `vmap(my_function, in_axes=((None, 0),))(tup)`
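Here is a runnable sketch of this suggestion. It assumes `tup` holds a (100, 25) array and a (20, 50) array and uses a hypothetical `my_function` body. The outer tuple in `in_axes` has one entry per positional argument. Its single entry, `(None, 0)`, mirrors the structure of `tup`: it leaves the first array unmapped and maps axis 0 (size 20) of the second.

```python
import jax
import jax.numpy as jnp

def my_function(params):
    # Hypothetical body: unpack the tuple and combine the two arrays.
    x, y = params
    return x.sum() * y.sum()

# Assumed shapes: axis 0 of the second array has size 20.
tup = (jnp.ones((100, 25)), jnp.ones((20, 50)))

# my_function takes one argument (the tuple), so in_axes is a 1-tuple.
# Its entry (None, 0) means: don't map x, map axis 0 of y.
out = jax.vmap(my_function, in_axes=((None, 0),))(tup)
print(out.shape)  # (20,)
```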
-
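```python
import jax
import jax.numpy as jnp

def my_function(params):
    x, y = params
    return x.sum() * y.sum()

params = (
    jnp.ones((100, 25)),
    jnp.ones((20, 50))
)

# One in_axes entry per positional argument; (0, None) maps axis 0 of the
# first array and leaves the second array unmapped.
out = jax.vmap(my_function, in_axes=[(0, None)])(params)
print(out.shape)
# (100,)
```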
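The extra nesting in `in_axes` is needed only because `my_function` takes a single tuple argument. A top-level list such as `[(0, None)]` is treated like the tuple `((0, None),)`. The sketch below assumes a hypothetical two-argument variant, `my_function_2`. With it, the same mapping can be written with a flat `in_axes`, and swapping the entries selects which array's axis 0 is mapped:

```python
import jax
import jax.numpy as jnp

# Hypothetical variant of my_function taking the two arrays as separate arguments.
def my_function_2(x, y):
    return x.sum() * y.sum()

x = jnp.ones((100, 25))
y = jnp.ones((20, 50))

# One in_axes entry per positional argument: map axis 0 of x, leave y unmapped.
out_x = jax.vmap(my_function_2, in_axes=(0, None))(x, y)
print(out_x.shape)  # (100,)

# Map axis 0 (size 20) of y instead, leaving x unmapped.
out_y = jax.vmap(my_function_2, in_axes=(None, 0))(x, y)
print(out_y.shape)  # (20,)
```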