jax.vmap output: an explanation is required to get a 2D output array. #14054
jecampagne asked this question in Q&A
-
Roughly, the shape signatures of the functions are (postulated):

```python
import jax.numpy as jnp

# X: tuple[int, ...]
# n: int
# x, y: ndarray
# x.shape == y.shape

# X, X -> ()
# (n, *X), (n, *X) -> (n,) under vmap
def fext1(x, y):
    return jnp.asarray(-1.)

# X, X -> X
# (n, *X), (n, *X) -> (n, *X) under vmap
def fext2(x, y):
    return -jnp.ones_like(x)

# X, X -> X
# (n, *X), (n, *X) -> (n, *X) under vmap
def fext3(x, y):
    return x + y
```
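A minimal sketch of these shape rules. It assumes the functions are vmapped over 10x10 inputs: the arrays `x` and `y` are made up for illustration, and `fext3` takes `(x, y)` to match its signature comment. With the default `in_axes=0`, each call sees one row of shape `(10,)`:

```python
import jax
import jax.numpy as jnp

def fext1(x, y):
    # Constant scalar: output shape () whatever x.shape is.
    return jnp.asarray(-1.)

def fext2(x, y):
    # Output has the same shape as x.
    return -jnp.ones_like(x)

def fext3(x, y):
    # Elementwise sum: same shape as x and y.
    return x + y

# Hypothetical 10x10 inputs; vmap maps over axis 0, so each call gets a row of shape (10,).
x = jnp.arange(100.).reshape(10, 10)
y = jnp.zeros((10, 10))

out1 = jax.vmap(fext1)(x, y)  # () per row    -> (10,)
out2 = jax.vmap(fext2)(x, y)  # (10,) per row -> (10, 10)
out3 = jax.vmap(fext3)(x, y)  # (10,) per row -> (10, 10)

print(out1.shape, out2.shape, out3.shape)  # prints: (10,) (10, 10) (10, 10)
```

The scalar returned by `fext1` is broadcast along the mapped axis only. vmap therefore yields one value per row, `(10,)`, rather than a `(10, 10)` array.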
-
So, if I want f(x, y) to return a number computed from kpos.x and kpos.y, should I multiply it by jnp.ones_like(kpos.x) to get the right output shape?
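A minimal sketch comparing the multiply-by-ones idea with `jnp.broadcast_to`, which states the same intent directly. `f_scalar`, `f_times_ones` and `f_broadcast` are hypothetical stand-ins, since the real computation on `kpos.x` and `kpos.y` is not shown. Note that both variants fill each row with that row's single number. This only matches the expected 10x10 result if one number per row was really intended.

```python
import jax
import jax.numpy as jnp

def f_scalar(x, y):
    # Hypothetical computation that reduces its inputs to one number (shape ()).
    return jnp.sum(x * y)

def f_times_ones(x, y):
    # The approach from the question: scale a ones-array shaped like x.
    return f_scalar(x, y) * jnp.ones_like(x)

def f_broadcast(x, y):
    # Same values, but explicitly broadcasting the scalar to x.shape.
    return jnp.broadcast_to(f_scalar(x, y), x.shape)

x = jnp.arange(100.).reshape(10, 10)
y = jnp.ones((10, 10))

a = jax.vmap(f_scalar)(x, y)      # (10,): one number per row
b = jax.vmap(f_times_ones)(x, y)  # (10, 10): each row filled with its number
c = jax.vmap(f_broadcast)(x, y)   # (10, 10): identical to b
print(a.shape, b.shape, c.shape)  # prints: (10,) (10, 10) (10, 10)
```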
-
Hello,
Here is a snippet that exhibits different vmap outputs. One of them is not what I expect, and I cannot manage to get it right either, so I need help. Thanks.
Then with
I get
which I do not expect, while with
I get what I was expecting as a 10x10 array
and fext3 is OK too.
So, after some other trials, it seems that the output of the function fext(kpos) must be of a certain type, but I cannot manage to get the correct one. Any idea is welcome.
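The question's code is not shown, so here is a hedged sketch under stated assumptions. `kpos` is modelled as a hypothetical `KPos` NamedTuple with 10x10 fields `x` and `y`. `fext` is a hypothetical function written for a single point. vmap maps over axis 0 of every pytree leaf, so one vmap hands `fext` whole rows. Two nested vmaps hand it single points, giving one scalar per grid point and hence a 10x10 result.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

class KPos(NamedTuple):
    # Hypothetical stand-in for the question's `kpos` (NamedTuples are pytrees).
    x: jax.Array
    y: jax.Array

def fext(kpos):
    # Written for ONE point: stacks (x, y) into a vector and returns its norm (shape ()).
    # Given whole rows instead, it silently returns the norm of the full (2, 10) array.
    return jnp.linalg.norm(jnp.stack([kpos.x, kpos.y]))

xs, ys = jnp.meshgrid(jnp.linspace(-1., 1., 10), jnp.linspace(-1., 1., 10))
kpos = KPos(x=xs, y=ys)  # both fields have shape (10, 10)

row_wise = jax.vmap(fext)(kpos)               # fext sees rows of shape (10,) -> (10,)
point_wise = jax.vmap(jax.vmap(fext))(kpos)   # fext sees scalars -> (10, 10)
print(row_wise.shape, point_wise.shape)       # prints: (10,) (10, 10)
```

Nesting vmap once per array axis keeps the per-point function simple, instead of reshaping its scalar result by hand.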