The axes that `vmap` maps over (chosen by its `in_axes` argument, which defaults to 0) are axes of the input arrays, not of the function being mapped. So the relevant axes here have nothing to do with `dfdx2`; they are the axes of the input argument `x`.
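A minimal sketch of that point, assuming a toy input of shape `[100, 1]` like the one in the question (`record_shape` is a hypothetical helper, not from the original code): the mapped function never sees the whole array, only one slice along the mapped input axis.

```python
import jax
import jax.numpy as jnp

seen_shapes = []

def record_shape(row):
    # Hypothetical helper: records the shape of the argument vmap passes in.
    # `row` is one slice of the input along the mapped axis, not the whole array.
    seen_shapes.append(row.shape)
    return jnp.sum(row)

x = jnp.ones((100, 1))  # assumed stand-in for the question's input

# in_axes defaults to 0, so vmap maps over the length-100 axis of x.
out = jax.vmap(record_shape)(x)

print(seen_shapes)  # each recorded shape is (1,)
print(out.shape)    # (100,)
```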

The input argument `x` has shape `[100, 1]`. The outermost `vmap` maps over its first axis (length 100) and the second, nested `vmap` maps over its second axis (length 1), so the function inside them receives a 0-d scalar. The third nested `vmap` has no axis left to map over, which is why it raises this error.
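To illustrate, here is a sketch that reproduces the failure under assumed inputs: `f` is a hypothetical scalar function standing in for the one in the question, and `jax.grad` stands in for whatever derivative is being mapped. Two nested `vmap`s consume both axes of `x`; a third one raises a `ValueError` because its argument is already a scalar.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical scalar -> scalar function (not from the original question).
    return jnp.sin(x) * x**2

x = jnp.linspace(0.0, 1.0, 100).reshape(100, 1)  # shape [100, 1], as in the question
g = jax.grad(f)

# Outer vmap: axis 0 (length 100). Inner vmap: axis 1 (length 1).
# The innermost g therefore receives a 0-d scalar.
two_level = jax.vmap(jax.vmap(g))(x)
print(two_level.shape)  # (100, 1)

# A third vmap would try to map over axis 0 of a 0-d scalar, which does not exist.
third_level_failed = False
try:
    jax.vmap(jax.vmap(jax.vmap(g)))(x)
except ValueError as err:
    third_level_failed = True
    print("ValueError:", err)
```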

From your example, it's not entirely clear to me what you're trying to compute (what exactly is a triply nested vmapped jacobian? I'm honestly not sure). That said, if you want to compute the element-wise third derivative of a scalar function app…
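For reference, one common JAX pattern for an element-wise third derivative (a sketch under assumptions, not necessarily what the truncated sentence above goes on to suggest) is to compose `jax.grad` three times on a scalar-to-scalar function and then `vmap` once over a flattened copy of the input. Here `f = sin` is a hypothetical choice so the result can be checked against the known third derivative `-cos`.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical scalar -> scalar function; its third derivative is -cos(x).
    return jnp.sin(x)

# Third derivative at a single scalar point.
d3f = jax.grad(jax.grad(jax.grad(f)))

x = jnp.linspace(0.0, 1.0, 100).reshape(100, 1)  # assumed [100, 1] input

# One vmap over the flattened array applies d3f element-wise;
# reshape the result back to the input's shape afterwards.
d3 = jax.vmap(d3f)(x.ravel()).reshape(x.shape)

print(d3.shape)                                    # (100, 1)
print(jnp.allclose(d3, -jnp.cos(x), atol=1e-5))    # True
```

Equivalently, `jax.vmap(jax.vmap(d3f))(x)` uses exactly one `vmap` per input axis and gives the same `[100, 1]` result without the reshape.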

Answer selected by jdongg