
Hi! This is a fundamental issue with autodiff of eigenvalue problems: the gradient in the presence of degenerate (repeated) eigenvalues is mathematically ill-defined (see e.g. #669 and #4646), so NaN is a reasonable result for a quantity that has no well-defined value.
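To see the failure mode concretely, here is a minimal sketch (not the original poster's exact code, which isn't shown here): a multiple of the identity has a fully repeated eigenvalue, and the `eigh` eigenvector derivative divides by eigenvalue gaps, so the gradient comes out NaN.

```python
import jax
import jax.numpy as jnp


def f(theta):
    # theta * I has an 8-fold repeated eigenvalue, so the eigenvector
    # Jacobian (which involves 1 / (lambda_i - lambda_j)) is ill-defined.
    H = theta * jnp.eye(8)
    Q = jnp.linalg.eigh(H)[1]
    return jnp.abs(jnp.linalg.slogdet(Q)[1])


g = jax.grad(f)(1.0)
print(g)  # NaN: the eigenvector gradient is undefined for repeated eigenvalues
```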

If you replace your function with one that uses a non-degenerate input matrix, you should see a non-NaN output. Try this, for example:

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg


def parameter(theta):
    # A random matrix has distinct eigenvalues almost surely, so the
    # eigendecomposition is differentiable here (eigh reads one triangle
    # of its input and treats it as symmetric).
    key = jax.random.key(0)
    H = jax.random.normal(key, (8, 8)) + theta * jnp.eye(8)
    Q = jax.scipy.linalg.eigh(H)[1]
    return jnp.abs(jnp.linalg.slogdet(Q)[1])
```
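As a quick check (a sketch, not from the original thread), differentiating this non-degenerate version produces a finite gradient rather than NaN, since all eigenvalue gaps are nonzero:

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg


def parameter(theta):
    key = jax.random.key(0)
    H = jax.random.normal(key, (8, 8)) + theta * jnp.eye(8)
    Q = jax.scipy.linalg.eigh(H)[1]
    return jnp.abs(jnp.linalg.slogdet(Q)[1])


g = jax.grad(parameter)(0.5)
# The spectrum of a random symmetric matrix is simple almost surely,
# so the eigh derivative never divides by a zero eigenvalue gap.
print(jnp.isfinite(g))
```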

Answer selected by miaoL1