vmap with non vmappable parameter #8297
aniquetahir asked this question in Q&A · Answered by jakevdp
Is there a way to create a `vmap` inside a jitted function where one of the parameters is broadcast across the vmapped function? For example:

```python
from jax import jit, vmap

def inner_function(x, y):
    return x * y

@jit
def outer_function(X, y_):
    vmapped_function = vmap(inner_function)
    return vmapped_function(X, y_)  # here I don't want y_ to be vmapped

@jit
def outer_outer_function(XX, Y):
    vmapped_function = vmap(outer_function)
    return vmapped_function(XX, Y)  # here Y is vmapped
```

Shape of `XX`: (10, 10, 10, 1). Basically, here I don't want
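A minimal sketch of why the default does not work, assuming (hypothetically) that `Y` has shape `(10,)`, so each `y_` that reaches `outer_function` is a scalar. With the default `in_axes=0`, `vmap` tries to map over axis 0 of every argument, and a rank-0 argument has no axis 0, so JAX raises a `ValueError`. If `y_` instead had a leading axis of the same length as `X`, the call would succeed but would pair `X[i]` with `y_[i]` rather than broadcasting `y_`.

```python
import jax.numpy as jnp
from jax import vmap

def inner_function(x, y):
    return x * y

X = jnp.ones((10, 10, 1))  # one slice XX[i] of the question's XX
y_ = jnp.array(2.0)        # assumption: a scalar, i.e. Y has shape (10,)

try:
    # Default in_axes=0 asks vmap to map axis 0 of *both* arguments,
    # which is impossible for the rank-0 y_.
    vmap(inner_function)(X, y_)
    raised = False
except ValueError as err:
    raised = True
    print("ValueError:", err)
```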
Answer by jakevdp, Oct 20, 2021:
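I think `in_axes` is what you're looking for:

```python
from jax import jit, vmap

def inner_function(x, y):
    return x * y

@jit
def outer_function(X, y_):
    # in_axes=(0, None): map over axis 0 of X, and pass y_ unmapped
    # (the same value) to every call of inner_function
    vmapped_function = vmap(inner_function, in_axes=(0, None))
    return vmapped_function(X, y_)
```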
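A sketch of the full nested case from the question with `in_axes` applied. The shape of `XX` comes from the question; the shape of `Y` is an assumption (one scalar per outer batch element), since the question is cut off before stating it. The outer `vmap` keeps the default `in_axes=0`, so `XX[i]` is paired with `Y[i]`; the inner `vmap` maps only over `X` and passes `y_` through unchanged.

```python
import jax.numpy as jnp
from jax import jit, vmap

def inner_function(x, y):
    return x * y

@jit
def outer_function(X, y_):
    # Map over the leading axis of X only; y_ (in_axes=None) is broadcast.
    return vmap(inner_function, in_axes=(0, None))(X, y_)

@jit
def outer_outer_function(XX, Y):
    # Map over the leading axis of both XX and Y: XX[i] is paired with Y[i].
    return vmap(outer_function)(XX, Y)

XX = jnp.ones((10, 10, 10, 1))  # shape taken from the question
Y = jnp.arange(10.0)            # assumption: one scalar per outer batch element
out = outer_outer_function(XX, Y)
print(out.shape)  # (10, 10, 10, 1)
```

An equivalent alternative is to close over `y_`, e.g. `vmap(lambda x: inner_function(x, y_))(X)`; `in_axes=(0, None)` just states the intent explicitly in the `vmap` call.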
Answer selected by aniquetahir