I've encountered a challenge calculating the gradient value after the eigenvectors were computed. No matter what `theta` value is given, the result is always NaN.
Hi! This is a fundamental issue with autodiff of eigenvalue problems: the gradient in the presence of degenerate eigenvalues is mathematically ill-defined (see e.g. #669 and #4646), and so NaN is a reasonable output for a quantity with no well-defined value. If you replace your function with one that uses a non-degenerate input matrix, you should see a non-NaN output. Try this, for example:

```python
import jax
import jax.numpy as jnp

def parameter(theta):
    key = jax.random.key(0)
    # The random term breaks the eigenvalue degeneracy of theta * I.
    H = jax.random.normal(key, (8, 8)) + theta * jnp.eye(8)
    Q = jax.scipy.linalg.eigh(H)[1]  # eigenvectors of H
    return jnp.abs(jnp.linalg.slogdet(Q)[1])
```
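For contrast, here is a minimal sketch of the failure mode described above; `degenerate_parameter` is a hypothetical stand-in for the original code, which isn't shown in the question. A matrix like `theta * I` has a single eigenvalue of multiplicity 8, so the eigenvector gradient is ill-defined and `jax.grad` returns NaN:

```python
import jax
import jax.numpy as jnp

def degenerate_parameter(theta):
    # theta * I has one eigenvalue repeated 8 times, so the eigenvector
    # derivative (which divides by eigenvalue gaps) is ill-defined.
    H = theta * jnp.eye(8)
    Q = jax.scipy.linalg.eigh(H)[1]
    return jnp.abs(jnp.linalg.slogdet(Q)[1])

print(jax.grad(degenerate_parameter)(1.0))  # nan
```

With the non-degenerate `parameter` above, `jax.grad(parameter)(1.0)` should return a finite value instead.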