Following what I learned from @yonghakim, I used the code below to compute gradients with respect to degenerate eigenvectors (which I asked for help with a few days ago in #19030). Although the gradient itself is accessible, when I tried to optimize the parameters an error occurred:

I'm new to JAX and my question might be ambiguous or inappropriate, but I would still really appreciate it if anyone could offer me some clues.
Hi - thanks for the question! In JAX, forward-mode autodiff is handled by `jvp` (Jacobian-vector product), while reverse-mode autodiff is handled by `vjp` (vector-Jacobian product). When you define a `custom_vjp`, you are defining how reverse-mode autodiff should behave, but you have not defined how forward-mode autodiff should behave.

To fix this, you can either stick to reverse-mode autodiff transformations (i.e. use `vjp`, `jacrev`, `jacobian`, or `grad`, all of which use reverse-mode autodiff), or you can define a `custom_jvp` for your function, which defines how it should behave with forward-mode autodiff.

For more details, see Custom derivative rules for Python code.