Gradient only defined for scalar-output functions. Output had shape: (1,). #22201
Unanswered
bengladwyn asked this question in Q&A
Replies: 1 comment
Thanks for the question! The issue is that your function returns a one-dimensional vector (shape (1,)) rather than a scalar, and jax.grad is only defined for scalar-output functions. Returning the single element instead should fix it:

    def root_function(phi, h):
        return hofphi(phi)[0] - h

Also, side note: in both cases where you use list comprehensions, it will generally be more efficient to use vmap:

    from jax import vmap

    # V_h_values = [V_phi(h) for h in h_values]
    V_h_values = vmap(V_phi)(h_values)

    # V_h_gradient_values = [gradient_V_h(h) for h in h_values]
    V_h_gradient_values = vmap(gradient_V_h)(h_values)
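A minimal sketch that reproduces the error and the fix, assuming a hypothetical `hofphi` that returns a length-1 array (the question's actual code is not shown). `jax.grad` raises a `TypeError` when the function's output has shape `(1,)`; indexing with `[0]` makes the output a scalar. The last lines apply the suggested `vmap` to the gradient, with `in_axes=(0, None)` mapping over `phi` while holding `h` fixed.

```python
import jax.numpy as jnp
from jax import grad, vmap

# Hypothetical stand-in for the question's hofphi: returns shape (1,), not a scalar.
def hofphi(phi):
    return jnp.array([phi**3 + phi])

# Output shape (1,): jax.grad rejects this with a TypeError.
def root_function_vector(phi, h):
    return hofphi(phi) - h

# Output shape (): indexing with [0] yields a scalar, so jax.grad accepts it.
def root_function(phi, h):
    return hofphi(phi)[0] - h

try:
    grad(root_function_vector)(1.0, 2.0)
except TypeError as err:
    print("grad failed:", err)

# d/dphi (phi**3 + phi - h) = 3*phi**2 + 1, which equals 4 at phi = 1
print(grad(root_function)(1.0, 2.0))

# vmap replaces a list comprehension: one vectorized call over all phi values,
# with h broadcast (in_axes=None) rather than mapped.
phi_values = jnp.linspace(0.0, 2.0, 5)
slopes = vmap(grad(root_function), in_axes=(0, None))(phi_values, 2.0)
print(slopes)
```

The same pattern carries over to `gradient_V_h`: once the function being differentiated returns a scalar, `vmap(grad(...))` evaluates the gradient at every point in one call.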
I'm trying to calculate the gradient of a function V(h), which is defined via V(phi) and h(phi). I therefore need to invert h(phi) to find phi(h), which I do with a minimize function. Then, to find V(h), I calculate V(phi(h)). This works well enough to plot the function, but I get the following error when trying to calculate the gradient:

    Gradient only defined for scalar-output functions. Output had shape: (1,).

My code is as follows:
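A minimal sketch of the setup described in the question, not the asker's code (which is not included here). `h_of_phi`, `V_of_phi`, and `phi_of_h` are hypothetical stand-ins. Instead of a general minimizer, the inversion uses a fixed number of Newton steps inside `lax.fori_loop`: with a static trip count, reverse-mode `jax.grad` can differentiate through it, whereas it cannot differentiate through an open-ended `lax.while_loop`. As a cross-check, the inverse function theorem gives dV/dh = V'(phi) / h'(phi) evaluated at phi = phi(h).

```python
import jax.numpy as jnp
from jax import grad, vmap, lax

# Hypothetical monotonic h(phi) and potential V(phi); both return scalars.
def h_of_phi(phi):
    return phi**3 + phi

def V_of_phi(phi):
    return 0.5 * phi**2

dh_dphi = grad(h_of_phi)

def phi_of_h(h, num_steps=30):
    """Invert h(phi) = h with a fixed number of Newton steps (reverse-mode differentiable)."""
    h = jnp.asarray(h)

    def newton_step(_, phi):
        return phi - (h_of_phi(phi) - h) / dh_dphi(phi)

    # Start from phi = 0 with the same dtype as h; since h'(phi) >= 1 here,
    # Newton's method converges for this particular h_of_phi.
    return lax.fori_loop(0, num_steps, newton_step, jnp.zeros_like(h))

def V_of_h(h):
    # Scalar output, so jax.grad applies directly.
    return V_of_phi(phi_of_h(h))

dV_dh = grad(V_of_h)

def dV_dh_implicit(h):
    # Inverse function theorem: dV/dh = V'(phi) / h'(phi) at phi = phi(h).
    phi = phi_of_h(h)
    return grad(V_of_phi)(phi) / dh_dphi(phi)

h_values = jnp.linspace(0.5, 3.0, 6)
print(vmap(dV_dh)(h_values))
print(vmap(dV_dh_implicit)(h_values))
```

Differentiating through the unrolled Newton iterations and applying the implicit formula should agree once the iteration has converged; the implicit version avoids backpropagating through the loop at all.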