Removing linear dependencies in orthogonalization #18982
Hello, I've started using JAX for my research and have run into trouble performing a symmetric orthogonalization and getting the derivative I expect.
When computing the derivative, for the evals I get
but for the sqrtm matrix I get
which seems to neglect the |
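The original code is not reproduced in this thread, so here is a minimal sketch of what a symmetric (Löwdin-style) orthogonalization with an eigenvalue cutoff might look like in JAX. The helper name `inv_sqrtm_with_cutoff`, the `cutoff` default and the example matrix are hypothetical; the mask `abs(evals) > cutoff` mirrors the one mentioned in the reply. It assumes a real, positive semidefinite overlap matrix `S`, and it uses the common "double `jnp.where`" pattern so discarded eigenvalues never reach `1 / sqrt(...)` in either the forward or the backward pass.

```python
import jax
import jax.numpy as jnp


def inv_sqrtm_with_cutoff(S, cutoff=1e-6):
    """Hypothetical helper: S^(-1/2) for a real symmetric PSD overlap matrix S,
    discarding eigen-directions with |eigenvalue| <= cutoff (linear dependencies).
    Assumes every kept eigenvalue is positive (true for PSD S)."""
    evals, evecs = jnp.linalg.eigh(S)
    # Boolean mask computed from the primal eigenvalues.
    keep = jnp.abs(evals) > cutoff
    # "Double where": substitute a harmless value before the square root so that
    # discarded entries never produce inf/nan in either the value or the gradient.
    safe_evals = jnp.where(keep, evals, 1.0)
    inv_sqrt = jnp.where(keep, 1.0 / jnp.sqrt(safe_evals), 0.0)
    # V diag(inv_sqrt) V^T, written with broadcasting over the columns of V.
    return (evecs * inv_sqrt) @ evecs.T


# Usage: an overlap matrix with one exact linear dependency (eigenvalues 0 and 2).
S = jnp.array([[1.0, 1.0], [1.0, 1.0]])
X = inv_sqrtm_with_cutoff(S)
# X @ S @ X is the projector onto the kept subspace rather than the identity.
P = X @ S @ X
# Gradients stay finite because the discarded eigenvalue never reaches 1/sqrt.
g = jax.grad(lambda s: inv_sqrtm_with_cutoff(s).sum())(S)
```

Because the mask is built from the primal eigenvalues, the tangent of a discarded eigenvalue is simply zeroed, which is the behavior the reply below describes.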
Thanks for the question! Just to make sure I understand: you're surprised to see values smaller than |
Indeed, when differentiating the `jnp.where`, the primal values of the boolean array `abs(evals) > cutoff` are used to filter both the primal and tangent values. It's the same idea as if we were differentiating `lambda x: x ** 2 if x > 0 else x` at primal value `x=1.0` but with tangent value `x_dot=-1.0`: we're linearizing around the primal point, so we want to switch based on the primal value only, then have the tangent value follow along (i.e. go through the `x ** 2` function) rather than taking its own path. In this case we're writing a `jnp.where` instead of an `if`, but it's the same logic (like differentiating `lambda x: jnp.where(x > 0, x ** 2, x)`).

So, super concretely, when we differentiate l…
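The primal-switching behavior can be checked directly with `jax.jvp`. This is a minimal sketch; the function names `f_if`, `f_where`, `keep_large` and the vector values are illustrative, not from the thread. Both the Python-`if` and the `jnp.where` versions of the example function should give a tangent of -2.0 at `x=1.0` with `x_dot=-1.0`, and a masked vector keeps only the tangent entries whose primal passes the cutoff.

```python
import jax
import jax.numpy as jnp


def f_if(x):
    # Python `if`: the branch is picked from the concrete primal value of x.
    return x ** 2 if x > 0 else x


def f_where(x):
    # Same function via jnp.where; the mask x > 0 is also evaluated on the primal.
    return jnp.where(x > 0, x ** 2, x)


# Linearize both at primal x = 1.0 with tangent x_dot = -1.0.
p_if, t_if = jax.jvp(f_if, (1.0,), (-1.0,))
p_where, t_where = jax.jvp(f_where, (1.0,), (-1.0,))
# Both tangents go through the x ** 2 branch: 2 * 1.0 * (-1.0) = -2.0.

# Vector analogue of the `abs(evals) > cutoff` mask (hypothetical values).
cutoff = 0.5


def keep_large(evals):
    return jnp.where(jnp.abs(evals) > cutoff, evals, 0.0)


evals = jnp.array([1.0, 0.1])
evals_dot = jnp.array([-1.0, 1.0])
p_vec, t_vec = jax.jvp(keep_large, (evals,), (evals_dot,))
# p_vec is [1.0, 0.0]; t_vec is [-1.0, 0.0]: the second tangent is dropped
# because the primal 0.1 is below the cutoff, whatever the tangent direction.
```

Note that the tangent `1.0` on the second entry is discarded even though moving along it would eventually push that eigenvalue past the cutoff: the linearization only sees which branch the primal point is on.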