Hi JAX team, I am reworking the JAX interface for PennyLane, a quantum computing framework. I am registering custom JVP functions to evaluate gradients and Jacobians with techniques such as finite differences and parameter shift. For first derivatives everything works as expected. I run into issues when I try to implement a higher-order derivative, e.g. the Hessian. I am using recursion to define the n-th derivative rule; here is pseudocode of the implementation:
When I try to take the Hessian with jax.hessian(circuit, argnums = [0, 1, ...])(params, max_diff=2), I get the following error:
jax.interpreters.ad.CustomJVPException: Detected differentiation of a custom_jvp function with respect to a closed-over value. That isn't supported because the custom JVP rule only specifies how to differentiate the custom_jvp function with respect to explicit input parameters. Try passing the closed-over value into the custom_jvp function as an argument, and adapting the custom_jvp rule.
Could you tell me what a closed-over value is? Are those tracers that were used for backprop and that cannot be differentiated? Is there an existing example of how to solve this kind of issue?
Thank you all in advance!
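For anyone reading along, here is a minimal sketch of what the error message seems to describe. It is not the PennyLane code (the pseudocode is not shown above), and all names in it (outer_closed, scaled, scaled_jvp, outer_explicit) are hypothetical. In outer_closed the inner function captures x from the enclosing scope, so x is a closed-over value and the JVP rule has no tangent for it. The second version follows the suggestion in the message: x becomes an explicit argument and the rule is adapted to cover it. Because that rule is written with ordinary JAX operations, it can itself be differentiated, which is what a Hessian needs.

```python
import jax
import jax.numpy as jnp


def outer_closed(x):
    # `x` is captured from the enclosing scope: it is a closed-over value.
    @jax.custom_jvp
    def f(y):
        return jnp.sin(x) * y

    @f.defjvp
    def f_jvp(primals, tangents):
        (y,), (dy,) = primals, tangents
        # The rule only knows about `y`; it says nothing about `x`.
        return f(y), jnp.sin(x) * dy

    return f(1.0)

# Differentiating outer_closed with respect to x is the situation the error
# message describes, so it is deliberately not called here.


# Suggested fix from the message: pass the value as an explicit argument
# and extend the JVP rule to cover it.
@jax.custom_jvp
def scaled(x, y):
    return jnp.sin(x) * y


@scaled.defjvp
def scaled_jvp(primals, tangents):
    x, y = primals
    dx, dy = tangents
    # Plain JAX ops, so this rule can be differentiated again.
    tangent_out = jnp.cos(x) * y * dx + jnp.sin(x) * dy
    return scaled(x, y), tangent_out


def outer_explicit(x):
    return scaled(x, 1.0)


first = jax.grad(outer_explicit)(0.5)      # analytically cos(0.5)
second = jax.hessian(outer_explicit)(0.5)  # analytically -sin(0.5)
```

Whether a recursive n-th derivative rule hits the same situation depends on whether the inner custom_jvp function captures traced parameters instead of receiving them as arguments, which cannot be told from the post alone.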