Unfortunately, I don't think it's possible for JAX to guarantee deterministic eigendecompositions, even on CPU.

JAX's eigh relies on LAPACK's syevd on CPU and on vendor-provided GPU implementations from NVIDIA and AMD:
https://github.com/google/jax/blob/53318a2a7a644e5ed1ac657f408e31eeb1fe5a0d/jax/_src/lax/linalg.py#L562-L576

Even if there are no degeneracies, eigenvectors are only uniquely defined up to a global phase (for real matrices, up to a sign). Each platform is free to choose how it calculates eigenvectors, in whatever manner is most convenient. You will even find that different LAPACK distributions on CPU (e.g., Intel MKL vs. OpenBLAS) may calculate eigenvectors differently.
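
To make the sign ambiguity concrete, here is a minimal sketch (the matrix A and the sign pattern are illustrative choices, not from the discussion). It flips the sign of one eigenvector returned by jnp.linalg.eigh and checks that both sets of eigenvectors satisfy A V = V diag(w) equally well, so either one is a correct answer for a backend to return.

```python
import jax.numpy as jnp

# A small real symmetric matrix; being tridiagonal with nonzero
# off-diagonals, it has distinct eigenvalues (no degeneracies).
A = jnp.array([[2.0, 1.0, 0.0],
               [1.0, 3.0, 1.0],
               [0.0, 1.0, 4.0]])

w, V = jnp.linalg.eigh(A)  # eigenvalues ascending, eigenvectors as columns

# Flip the sign of the middle eigenvector column.
signs = jnp.array([1.0, -1.0, 1.0])
V_flipped = V * signs  # scales column j by signs[j]

def residual(A, w, V):
    # Max abs entry of A @ V - V @ diag(w); V * w scales column j by w[j].
    return jnp.max(jnp.abs(A @ V - V * w))

# Both residuals are tiny (float32 rounding error only).
print(residual(A, w, V))
print(residual(A, w, V_flipped))
```

The two decompositions differ entry-wise, yet neither is "more correct" than the other, which is why bitwise agreement across backends cannot be expected.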

You certainly could pick your own convention for norma…
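
As one possible convention, chosen here purely for illustration, the sketch below flips each eigenvector so that its largest-magnitude entry is positive. The helper name canonicalize_signs is hypothetical. It assumes real eigenvectors (e.g., from eigh on a real symmetric matrix) with no degenerate eigenvalues; complex Hermitian inputs would need a phase rather than a sign normalization.

```python
import jax.numpy as jnp

def canonicalize_signs(V):
    """Flip each column of V so its largest-magnitude entry is positive.

    Hypothetical helper; assumes real eigenvectors and no degeneracies.
    Ties in magnitude are broken by argmax picking the first index.
    """
    idx = jnp.argmax(jnp.abs(V), axis=0)       # row of the largest |entry| per column
    pivots = V[idx, jnp.arange(V.shape[1])]    # that entry, including its sign
    signs = jnp.where(pivots < 0, -1.0, 1.0)   # -1 where the pivot is negative
    return V * signs                           # flip the offending columns

A = jnp.array([[2.0, 1.0, 0.0],
               [1.0, 3.0, 1.0],
               [0.0, 1.0, 4.0]])
w, V = jnp.linalg.eigh(A)

# Simulate a backend that returned the middle eigenvector with opposite sign.
V_flipped = V * jnp.array([1.0, -1.0, 1.0])

canonical = canonicalize_signs(V)
print(jnp.allclose(canonical, canonicalize_signs(V_flipped)))
```

This removes the sign freedom, but it is not bulletproof: if two entries of an eigenvector have nearly equal magnitude, rounding differences between backends can change which one is chosen as the pivot.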

shoyer Dec 22, 2021
Collaborator

Answer selected by Edenhofer
Category: Q&A
Labels: None yet
3 participants