vmap over a callable[[float], Array] #20631
Unanswered
gautierronan asked this question in Q&A
-
The arguments to a vmapped function must be arrays (or pytrees of arrays), not functions. The fix would be to do something like your second solution, where you pass arrays rather than functions as arguments to the vmapped function. I'm not entirely sure what else to suggest, because it's not clear to me what outcome you want when passing a function in place of an array.
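A minimal sketch of the "pass arrays instead of functions" idea, assuming `g` only needs `f` evaluated on a fixed grid of times: sample `f` once, then vmap an array-based version of `g` over the output axis. The functions `f` and `g_from_samples`, the 3-component output, and the trapezoid integration are hypothetical stand-ins, not code from this thread.

```python
import jax
import jax.numpy as jnp

def f(t):
    # Hypothetical user function f(t: float) -> Array with 3 output components.
    return jnp.stack([t, t**2, jnp.sin(t)])

def g_from_samples(ys, ts):
    # Hypothetical array-based version of g: instead of a callable h(t), it
    # receives h already sampled at times ts and integrates it (trapezoid rule).
    return jnp.sum(0.5 * (ys[1:] + ys[:-1]) * jnp.diff(ts))

ts = jnp.linspace(0.0, 1.0, 101)
# Evaluate f once per time point: shape (101, 3).
samples = jax.vmap(f)(ts)
# Map over the output axis (axis 1); ts is shared by every batch element (in_axes=None).
result = jax.vmap(g_from_samples, in_axes=(1, None))(samples, ts)
```

This only works if `g` can be rewritten to consume samples rather than a callable; if `g` needs to call its argument at times it chooses itself, the samples cannot be precomputed this way.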
-
Let's assume I have a function `f` given by the user (i.e. I am allowed to wrap it, but not allowed to modify it or to know its internals). This function has signature `f(t: float) -> Array`. I would like to `jax.vmap` over the returned array of this function. In other words, I would like to achieve this kind of behavior.
For standard arrays, it would work fine.
I'm guessing I should be able to wrap `f` and `g` in some way to allow this kind of `vmap`. How could I achieve something like that? I know it should involve `jax.eval_shape` in some way, but would love any pointers to get me started.
Thanks for the help.
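One possible way to keep `g` taking a callable, sketched under assumptions: use `jax.eval_shape` to read the size of the leading axis of `f`'s output without running `f` on real data, then vmap over an integer index `i` (which is an array, so vmap accepts it) and build each component function `lambda t: f(t)[i]` by closure. The helper `vmap_over_output`, the example `f` and `g`, and the assumption that the leading axis of `f(t)` is the batch axis are all hypothetical.

```python
import jax
import jax.numpy as jnp

def f(t):
    # Hypothetical user-supplied function f(t: float) -> Array of shape (3,).
    return jnp.stack([t, t**2, jnp.sin(t)])

def g(h):
    # Hypothetical consumer that takes a scalar callable h(t) and integrates it on [0, 1].
    ts = jnp.linspace(0.0, 1.0, 101)
    ys = jax.vmap(h)(ts)
    return jnp.sum(0.5 * (ys[1:] + ys[:-1]) * jnp.diff(ts))

def vmap_over_output(g, f):
    # jax.eval_shape traces f abstractly on a float32 scalar, so the batch size
    # is read off the output shape without running f on real data.
    out = jax.eval_shape(f, jax.ShapeDtypeStruct((), jnp.float32))
    n = out.shape[0]  # assumes f(t) has a leading batch axis

    def g_on_component(i):
        # Instead of vmapping over callables, vmap over the integer index i
        # and build the i-th component function f_i(t) = f(t)[i] by closure.
        return g(lambda t: f(t)[i])

    return jax.vmap(g_on_component)(jnp.arange(n))

batched = vmap_over_output(g, f)
# Reference: apply g to each component function one at a time in Python.
looped = jnp.stack([g(lambda t, k=k: f(t)[k]) for k in range(3)])
```

Because `f(t)` does not depend on `i`, vmap should leave the call to `f` itself unbatched and only batch the indexing, so `f` is not re-evaluated once per component. If `t` has a different dtype in practice, the `ShapeDtypeStruct` passed to `jax.eval_shape` would need to match it.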