Good question. I think the issue is that the input you pass to vjp_fun does not have the expected shape. The VJP is a backward pass, so vjp_fun expects a cotangent whose shape matches the output y of the function you differentiated. Since you've defined dJ_du to have the same shape as y, this should work:

```python
from jax import vjp

y, vjp_fun = vjp(mapping, x)  # dJ_du must have the same shape as y
print(vjp_fun(dJ_du))
```
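A minimal self-contained sketch of the shape rule. The question's `mapping`, `x` and `dJ_du` aren't shown here, so the versions below are hypothetical stand-ins: `mapping` sends a length-3 vector to a length-2 vector. The cotangent fed to `vjp_fun` has y's shape `(2,)`, and the result has x's shape `(3,)`. It equals the vector-Jacobian product `dJ_du @ J`, which the sketch checks against `jax.jacrev`.

```python
import jax
import jax.numpy as jnp
from jax import vjp

# Hypothetical stand-in for the question's `mapping`: R^3 -> R^2.
def mapping(x):
    return jnp.stack([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([1.0, 2.0, 3.0])
y, vjp_fun = vjp(mapping, x)

# The cotangent must match y: shape (2,), same dtype.
dJ_du = jnp.ones_like(y)
(dJ_dx,) = vjp_fun(dJ_du)  # vjp_fun returns a tuple, one entry per primal input

# Sanity check: the VJP equals dJ_du^T J, where J is the (2, 3) Jacobian.
J = jax.jacrev(mapping)(x)
print(dJ_dx.shape)                          # (3,)
print(bool(jnp.allclose(dJ_dx, dJ_du @ J)))  # True
```

The output has the shape of the *input* x, not of y. That is the "backward" direction of the VJP.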

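But suppose you vmap vjp_fun and then call the vmapped function on a one-dimensional dJ_du. vmap maps over the leading axis, so each underlying call to vjp_fun receives a single scalar element rather than an array of shape y.shape. That is equivalent to passing a scalar to vjp_fun, which fails because the cotangent no longer matches the shape of y.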
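A sketch of both cases, reusing the same hypothetical `mapping` as a stand-in for the question's function. vmapping over a `(2,)` cotangent hands vjp_fun scalar slices and raises an error (a `ValueError` about an unexpected shape in recent JAX versions; the exact type and message may vary). To batch correctly, stack the cotangents so that each mapped slice has y's shape, i.e. a `(batch, *y.shape)` array. Using the identity matrix as the batch of cotangents recovers the Jacobian row by row, which is essentially what `jax.jacrev` does.

```python
import jax
import jax.numpy as jnp
from jax import vjp

# Hypothetical stand-in for the question's `mapping`: R^3 -> R^2.
def mapping(x):
    return jnp.stack([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([1.0, 2.0, 3.0])
y, vjp_fun = vjp(mapping, x)

# Wrong: vmap splits the (2,) cotangent along axis 0, so each call gets a scalar.
dJ_du = jnp.ones_like(y)
failed = False
try:
    jax.vmap(vjp_fun)(dJ_du)
except (TypeError, ValueError) as err:
    failed = True
    print("vmap over a (2,) cotangent fails:", str(err).splitlines()[0])

# Right: a batch of cotangents with shape (batch, *y.shape).
basis = jnp.eye(y.shape[0])          # two cotangents, each of shape (2,)
(rows,) = jax.vmap(vjp_fun)(basis)   # rows[i] = e_i^T J
print(rows.shape)                    # (2, 3)
print(bool(jnp.allclose(rows, jax.jacrev(mapping)(x))))  # True
```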

Does that answer your question?

Answer selected by smartalecH