97 changes: 71 additions & 26 deletions jax/lax_linalg.py
@@ -51,10 +51,10 @@ def eig(x):
w, vl, vr = eig_p.bind(x)
return w, vl, vr

def eigh(x, lower=True, symmetrize_input=True):
def eigh(x, lower=True, symmetrize_input=True, handle_degeneracies=True):
if symmetrize_input:
x = symmetrize(x)
v, w = eigh_p.bind(x, lower=lower)
v, w = eigh_p.bind(x, lower=lower, handle_degeneracies=handle_degeneracies)
return v, w

def lu(x):
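
The new keyword is threaded through every eigh rule below and defaults to True, so existing callers are unaffected. A minimal usage sketch of the changed entry point (values illustrative, not from the PR; module path as in this PR's tree):

import jax.numpy as np
from jax.lax_linalg import eigh

a = np.eye(2)  # fully degenerate spectrum: both eigenvalues equal 1
# handle_degeneracies=True (the default) makes JVPs through eigh pick
# eigenvectors that also diagonalize the tangent in degenerate subspaces;
# handle_degeneracies=False keeps the old distinct-eigenvalues-only behavior.
v, w = eigh(a, lower=True, handle_degeneracies=True)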
@@ -205,11 +205,12 @@ def eig_batching_rule(batched_args, batch_dims):

# Symmetric/Hermitian eigendecomposition

def eigh_impl(operand, lower):
v, w = xla.apply_primitive(eigh_p, operand, lower=lower)
def eigh_impl(operand, lower, handle_degeneracies):
v, w = xla.apply_primitive(eigh_p, operand, lower=lower,
handle_degeneracies=handle_degeneracies)
return v, w

def eigh_translation_rule(c, operand, lower):
def eigh_translation_rule(c, operand, lower, handle_degeneracies):
shape = c.GetShape(operand)
dims = shape.dimensions()
if dims[-1] == 0:
@@ -219,7 +220,7 @@ def eigh_translation_rule(c, operand, lower):
operand = c.Transpose(operand, list(range(n - 2)) + [n - 1, n - 2])
return c.Eigh(operand)

def eigh_abstract_eval(operand, lower):
def eigh_abstract_eval(operand, lower, handle_degeneracies):
if isinstance(operand, ShapedArray):
if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
raise ValueError(
@@ -234,7 +235,8 @@ def eigh_abstract_eval(operand, lower):
v, w = operand, operand
return v, w

def _eigh_cpu_gpu_translation_rule(syevd_impl, c, operand, lower):
def _eigh_cpu_gpu_translation_rule(syevd_impl, c, operand, lower,
handle_degeneracies):
shape = c.GetShape(operand)
batch_dims = shape.dimensions()[:-2]
v, w, info = syevd_impl(c, operand, lower=lower)
@@ -245,37 +247,80 @@ def _eigh_cpu_gpu_translation_rule(syevd_impl, c, operand, lower):
_nan_like(c, w))
return c.Tuple(v, w)

def eigh_jvp_rule(primals, tangents, lower):
# Derivative for eigh in the simplest case of distinct eigenvalues.
# This is classic nondegenerate perturbation theory, but also see
# https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
# The general solution treating the case of degenerate eigenvalues is
# considerably more complicated. Ambitious readers may refer to the general
# methods below or refer to degenerate perturbation theory in physics.
# https://www.win.tue.nl/analysis/reports/rana06-33.pdf and
# https://people.orie.cornell.edu/aslewis/publications/99-clarke.pdf
def eigh_jvp_rule(primals, tangents, lower, handle_degeneracies):
# We only do the additional work to properly handle degeneracies if
# handle_degeneracies=True.
#
# The simple case of distinct eigenvalues is classic nondegenerate perturbation
# theory, but also see https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
#
# For non-distinct eigenvalues, derivatives are only valid if the eigenvectors
# *also* diagonalize the perturbation a_dot in any degenerate subspace, which
# makes the primal outputs dependent on the tangent input. We use the "modal
# expansion method", which is an efficient way to compute derivatives if we
# know the full eigenbasis. For details, see:
# Bernard, M. L. & Bronowicki, A. J. Modal expansion method for
# eigensensitivity with repeated roots. AIAA Journal 32, 1500-1506 (1994).
# Or look up "degenerate perturbation theory" in physics.
#
# Note that this approach does *not* correctly handle higher-order degeneracies,
# i.e., the case of degenerate eigenvalue derivatives. This means you cannot
# expect to get the right answer if you do higher order differentiation.
# Fixing this would require native support for higher-order forward-mode
# differentiation in JAX (https://github.com/google/jax/pull/1185).

a, = primals
a_dot, = tangents

v, w = eigh_p.bind(symmetrize(a), lower=lower)
a_sym = symmetrize(a)
v, w = eigh_p.bind(
a_sym, lower=lower, handle_degeneracies=handle_degeneracies)

# for complex numbers we need the eigenvalues to have the full dtype of v and a:
w = w.astype(a.dtype)
eye_n = np.eye(a.shape[-1], dtype=a.dtype)
# carefully build reciprocal delta-eigenvalue matrix, avoiding NaNs.
Fmat = np.reciprocal(eye_n + w[..., np.newaxis, :] - w[..., np.newaxis]) - eye_n
# eigh impl doesn't support batch dims, but future-proof the grad.
epsilon = 10 * np.finfo(a.dtype).resolution

dot = lax.dot if a.ndim == 2 else lax.batch_matmul
vdag_adot_v = dot(dot(_H(v), a_dot), v)
dv = dot(v, np.multiply(Fmat, vdag_adot_v))
dw = np.diagonal(vdag_adot_v, axis1=-2, axis2=-1)
vdag_adot = dot(_H(v), a_dot)
vdag_adot_v = dot(vdag_adot, v)

deltas = w[..., np.newaxis, :] - w[..., np.newaxis]
same_subspace = (abs(deltas) < epsilon
if handle_degeneracies
else np.eye(a.shape[-1], dtype=bool))

if handle_degeneracies:
# Note: this only works for the JVP rule -- we can't transpose it.
# Diagonalize the perturbation in any degenerate subspaces.
v_dot, w_dot = eigh_p.bind(vdag_adot_v * same_subspace, lower=lower,
handle_degeneracies=handle_degeneracies)
# Reorder these into sorted order of the original eigenvalues.
# TODO(shoyer): consider rewriting with an explicit loop over degenerate
# subspaces instead?
v2 = dot(v, v_dot)
w2 = np.einsum('...ij,...jk,...ki->...i', _H(v2), a_sym, v2).real
order = np.argsort(w2, axis=-1)
v = np.take_along_axis(v2, order[..., np.newaxis, :], axis=-1)
dw = np.take_along_axis(w_dot, order, axis=-1)
deltas = w[..., np.newaxis, :] - w[..., np.newaxis]
same_subspace = abs(deltas) < epsilon
else:
dw = np.diagonal(vdag_adot_v, axis1=-2, axis2=-1)

# Note: if not handling degeneracies, this intentionally will result in
# nan/inf if there happen to be degenerate eigenvalues.
Fmat = np.where(same_subspace, 0.0, 1.0 / deltas)
C = Fmat * vdag_adot_v
dv = dot(v, C)

return (v, w), (dv, dw)

def eigh_batching_rule(batched_args, batch_dims, lower):
def eigh_batching_rule(batched_args, batch_dims, lower, handle_degeneracies):
x, = batched_args
bd, = batch_dims
x = batching.moveaxis(x, bd, 0)
return eigh_p.bind(x, lower=lower), (0, 0)
result = eigh_p.bind(x, lower=lower, handle_degeneracies=handle_degeneracies)
return result, (0, 0)

eigh_p = Primitive('eigh')
eigh_p.multiple_results = True
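To make the modal-expansion step in eigh_jvp_rule concrete: for a = I and tangent a_dot = [[0, 1], [1, 0]], every orthonormal basis diagonalizes a, but only the basis (1, ±1)/sqrt(2) also diagonalizes a_dot, and the eigenvalue derivatives are the eigenvalues of the projected tangent. A standalone NumPy sketch of the subspace fix-up (illustration only, not PR code):

import numpy as onp

a_dot = onp.array([[0., 1.],
                   [1., 0.]])
# For a = I the whole space is one degenerate subspace and v = I, so the
# projected tangent v^T a_dot v is just a_dot; the JVP rule re-diagonalizes it.
dw, v_rot = onp.linalg.eigh(a_dot)
print(dw)     # [-1.  1.]   eigenvalue derivatives, as in testEighGradDegenerate
print(v_rot)  # columns are (1, ±1)/sqrt(2): the degeneracy-adapted eigenvectors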
27 changes: 25 additions & 2 deletions tests/linalg_test.py
@@ -277,14 +277,37 @@ def testEighGradVectorComplex(self, shape, dtype, rng, lower, eps):
new_a = (new_a + onp.conj(new_a.T)) / 2
# Assert that the eigenvalues implied by the perturbed eigenvectors match the new true eigenvalues to relative tolerance RTOL.
RTOL=1e-2
assert onp.max(
onp.abs((onp.diag(onp.dot(onp.conj((v+dv).T), onp.dot(new_a,(v+dv)))) - new_w) / new_w)) < RTOL
onp.testing.assert_allclose(
onp.diag(onp.dot(onp.conj((v+dv).T), onp.dot(new_a,(v+dv)))), new_w,
rtol=RTOL)
# Redundant with the above, but also assert that the eigenvector property holds for the new true eigenvalues to within RTOL.
assert onp.max(
onp.linalg.norm(onp.abs(new_w*(v+dv) - onp.dot(new_a, (v+dv))), axis=0) /
onp.linalg.norm(onp.abs(new_w*(v+dv)), axis=0)
) < RTOL

def testEighGradDegenerate(self):
a = np.eye(2)
a_dot = a[:, ::-1]
(w, v), (dw, dv) = jvp(np.linalg.eigh, primals=(a,), tangents=(a_dot,))
# correct eigenvectors are 1/sqrt(2) * [[-1, 1], [1, 1]], up to arbitrary
# choice of phase
onp.testing.assert_allclose(abs(v), onp.ones((2, 2)) / onp.sqrt(2))
onp.testing.assert_allclose(w, onp.ones((2,)))
onp.testing.assert_allclose(abs(dv), onp.zeros((2, 2)))
onp.testing.assert_allclose(dw, onp.array([-1, 1]))
Review comment (Collaborator, Author):
TODO: more test cases. Currently nothing is checking eigenvector derivatives in the case of degeneracies if they are non-zero.
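
One phase- and ordering-robust way to close that gap is to differentiate the reconstruction a = v diag(w) v^T, whose tangent must equal the symmetric input tangent no matter which basis or signs eigh picks inside degenerate subspaces. A hedged sketch of such a test (name and tolerance hypothetical):

def testEighGradDegenerateNonzeroDv(self):
    # Degenerate pair (1, 1) coupled to the eigenvalue 2 by the tangent,
    # so the eigenvector derivative dv is nonzero.
    a = onp.diag(onp.array([1., 1., 2.], onp.float32))
    a_dot = onp.array([[0., 0., 1.],
                       [0., 0., 0.],
                       [1., 0., 0.]], onp.float32)
    (w, v), (dw, dv) = jvp(np.linalg.eigh, (a,), (a_dot,))
    # Phase-invariant check: d/dt [v diag(w) v^T] must equal a_dot.
    reconstructed = (onp.dot(dv, onp.dot(onp.diag(w), v.T))
                     + onp.dot(v, onp.dot(onp.diag(dw), v.T))
                     + onp.dot(v, onp.dot(onp.diag(w), dv.T)))
    onp.testing.assert_allclose(reconstructed, a_dot, atol=1e-5)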


def testEighGradBatchDim(self):
rng = jtu.rand_default()
a = rng((2, 3, 3), onp.float32)
a = (a + onp.conj(T(a))) / 2
a_dot = rng((2, 3, 3), onp.float32)
a_dot = (a_dot + onp.conj(T(a_dot))) / 2
_, (dw, dv) = jvp(jsp.linalg.eigh, (a,), (a_dot,))
_, (dw_expected, dv_expected) = jvp(jsp.linalg.eigh, (a[0],), (a_dot[0],))
onp.testing.assert_allclose(dv[0], dv_expected)
onp.testing.assert_allclose(dw[0], dw_expected)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),