Below is an exemplary matrix from the model for which the decomposition yields different signs for the eigenvectors depending on the backend.

```python
from jax import numpy as jnp

info_prop_inv = jnp.array([[ 0.86263333, -0.25096436, -0.25096437, -0.15021721,
-0.004183 , -0.00377802, -0.00377802, -0.00337444],
[-0.25096436, 0.86263333, -0.15021721, -0.25096437,
-0.00377802, -0.004183 , -0.00337444, -0.00377802],
[-0.25096437, -0.15021721, 0.86263333, -0.25096436,
-0.00377802, -0.00337444, -0.004183 , -0.00377802],
[-0.15021721, -0.25096437, -0.25096436, 0.86263333,
-0.00337444, -0.00377802, -0.00377802, -0.004183 ],
[-0.004183 , -0.00377802, -0.00377802, -0.00337444,
0.8614207 , -0.25057809, -0.25057809, -0.15001035],
[-0.00377802, -0.004183 , -0.00337444, -0.00377802,
-0.25057809, 0.8614207 , -0.15001035, -0.25057809],
[-0.00377802, -0.00337444, -0.004183 , -0.00377802,
-0.25057809, -0.15001035, 0.8614207 , -0.25057809],
[-0.00337444, -0.00377802, -0.00377802, -0.004183 ,
-0.15001035, -0.25057809, -0.25057809, 0.8614207 ]])
v, w = jnp.linalg.eigh(info_prop_inv)  # eigh returns (eigenvalues, eigenvectors)
```

On the CPU, the eigenvectors have the following signs:
On the GPU, the eigenvectors sometimes have a different sign:
Background

Part of my model in JAX relies on having access to the eigenvectors of matrices. These matrices are constant within the model and fully determined by a simple function. They are intricately involved in applying the model, and the parameters of the model depend on the sign of the eigenvectors. When saving the model, I do not store the resulting eigenvectors of the matrix; instead I only save the few parameters of the function with which I can regenerate the matrix and thus its eigenvectors. However, since the eigenvectors differ depending on the backend, parameters learned with the code compiled for the GPU are incompatible with the code compiled for the CPU.
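For completeness, here is a minimal sketch (not from the original post) of how the two backends could be compared within a single process. It assumes JAX can see both a CPU and a GPU device and reuses `info_prop_inv` from above:

```python
import jax
from jax import numpy as jnp

# Pick one device per backend; jax.devices("gpu") raises if no GPU backend is installed.
cpu = jax.devices("cpu")[0]
gpu = jax.devices("gpu")[0]

# Committing the input to a device makes eigh run on that device's backend.
w_cpu, v_cpu = jnp.linalg.eigh(jax.device_put(info_prop_inv, cpu))
w_gpu, v_gpu = jnp.linalg.eigh(jax.device_put(info_prop_inv, gpu))

# Relative sign of each eigenvector column: +1 where the backends agree, -1 where flipped.
rel_sign = jnp.sign(jnp.sum(v_cpu * jax.device_put(v_gpu, cpu), axis=0))
print(rel_sign)
```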
Replies: 1 comment 5 replies
Unfortunately, I don't think it's possible for JAX to guarantee deterministic eigendecompositions, even on CPU. JAX's `eigh` relies on LAPACK's `syevd` and on GPU implementations by Nvidia and AMD: https://github.com/google/jax/blob/53318a2a7a644e5ed1ac657f408e31eeb1fe5a0d/jax/_src/lax/linalg.py#L562-L576

Even if there are no degeneracies, eigenvectors are only uniquely defined up to a global phase. Each platform has the freedom to pick its own way of calculating eigenvectors, in whichever manner is most convenient. You will even find that different LAPACK distributions on CPU (e.g., Intel MKL vs OpenBLAS) may calculate eigenvectors differently.

You certainly could pick your own convention for normalizing eigenvectors across platforms, but my main suggestion is to look for ways to reformulate your model so that it is invariant to this global phase. In my experience, depending on the sign of an eigenvector is a sign that something is wrong in the model, because the model is then no longer a mathematically well-defined function.
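As a concrete illustration of the normalization idea, here is a minimal sketch (the helper name `eigh_fixed_sign` is made up for this example) of one possible convention: flip each eigenvector so that its largest-magnitude component is positive. Applied after `eigh` on any backend, this makes the result deterministic as long as the eigenvalues are non-degenerate:

```python
from jax import numpy as jnp

def eigh_fixed_sign(a):
    # Standard eigendecomposition; eigenvectors are the columns of v.
    w, v = jnp.linalg.eigh(a)
    # Row index of the largest-magnitude entry in each eigenvector (column).
    idx = jnp.argmax(jnp.abs(v), axis=0)
    # Sign of that entry per column; flip columns so the largest entry is positive.
    signs = jnp.sign(v[idx, jnp.arange(v.shape[-1])])
    return w, v * signs

w, v = eigh_fixed_sign(info_prop_inv)
```

The alternative suggested above, making the model invariant to the phase, could for example mean working with outer products such as `jnp.outer(v[:, k], v[:, k])` instead of the raw eigenvectors, since those are unchanged when an eigenvector is multiplied by -1.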