Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 80 additions & 26 deletions jax/lax_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ def eig(x):
w, vl, vr = eig_p.bind(x)
return w, vl, vr

def eigh(x, lower=True, symmetrize_input=True):
def eigh(x, lower=True, symmetrize_input=True, handle_degeneracies=True):
if symmetrize_input:
x = symmetrize(x)
v, w = eigh_p.bind(x, lower=lower)
v, w = eigh_p.bind(x, lower=lower, handle_degeneracies=handle_degeneracies)
return v, w

def lu(x):
Expand Down Expand Up @@ -205,11 +205,12 @@ def eig_batching_rule(batched_args, batch_dims):

# Symmetric/Hermitian eigendecomposition

def eigh_impl(operand, lower):
v, w = xla.apply_primitive(eigh_p, operand, lower=lower)
def eigh_impl(operand, lower, handle_degeneracies):
v, w = xla.apply_primitive(eigh_p, operand, lower=lower,
handle_degeneracies=handle_degeneracies)
return v, w

def eigh_translation_rule(c, operand, lower):
def eigh_translation_rule(c, operand, lower, handle_degeneracies):
shape = c.GetShape(operand)
dims = shape.dimensions()
if dims[-1] == 0:
Expand All @@ -219,7 +220,7 @@ def eigh_translation_rule(c, operand, lower):
operand = c.Transpose(operand, list(range(n - 2)) + [n - 1, n - 2])
return c.Eigh(operand)

def eigh_abstract_eval(operand, lower):
def eigh_abstract_eval(operand, lower, handle_degeneracies):
if isinstance(operand, ShapedArray):
if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
raise ValueError(
Expand All @@ -234,7 +235,8 @@ def eigh_abstract_eval(operand, lower):
v, w = operand, operand
return v, w

def _eigh_cpu_gpu_translation_rule(syevd_impl, c, operand, lower):
def _eigh_cpu_gpu_translation_rule(syevd_impl, c, operand, lower,
handle_degeneracies):
shape = c.GetShape(operand)
batch_dims = shape.dimensions()[:-2]
v, w, info = syevd_impl(c, operand, lower=lower)
Expand All @@ -245,37 +247,89 @@ def _eigh_cpu_gpu_translation_rule(syevd_impl, c, operand, lower):
_nan_like(c, w))
return c.Tuple(v, w)

def eigh_jvp_rule(primals, tangents, lower):
# Derivative for eigh in the simplest case of distinct eigenvalues.
# This is classic nondegenerate perurbation theory, but also see
# https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
# The general solution treating the case of degenerate eigenvalues is
# considerably more complicated. Ambitious readers may refer to the general
# methods below or refer to degenerate perturbation theory in physics.
# https://www.win.tue.nl/analysis/reports/rana06-33.pdf and
# https://people.orie.cornell.edu/aslewis/publications/99-clarke.pdf
def eigh_jvp_rule(primals, tangents, lower, handle_degeneracies):
# We only do the additional work to properly handle degeneracies if
# handle_degeneracies=True.
#
# The simple case of distinct eigenvalues is classic nondegenerate perurbation
# theory, but also see https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
#
# For non-distinct eigenvalues, derivatives are only valid if the eigenvectors
# *also* diagonalize the perturbation a_dot in any degenerate subspace, which
# makes the primal outputs dependent on the tangent input. We use "modal
# expansion method", which is an efficient way to compute derivatives if we
# know the full eigenbasis. For details, see:
# Bernard, M. L. & Bronowicki, A. J. Modal expansion method for
# eigensensitivity with repeated roots. AIAA Journal 32, 1500-1506 (1994).
# Or look up "degenerate perturbation theory" in physics.
#
# Note that this approach does *not* correctly handle high order degeneracies,
# i.e., the case of degenerate eigenvalue derivatives. This means you cannot
# expect to get the right answer if you do higher order differentiation.
# Fixing this would require native support for higher order forwards-mode
# differentiation in JAX (https://github.com/google/jax/pull/1185).

a, = primals
a_dot, = tangents

v, w = eigh_p.bind(symmetrize(a), lower=lower)
a_sym = symmetrize(a)
v, w = eigh_p.bind(
a_sym, lower=lower, handle_degeneracies=handle_degeneracies)

# for complex numbers we need eigenvalues to be full dtype of v, a:
w = w.astype(a.dtype)
eye_n = np.eye(a.shape[-1], dtype=a.dtype)
# carefully build reciprocal delta-eigenvalue matrix, avoiding NaNs.
Fmat = np.reciprocal(eye_n + w[..., np.newaxis, :] - w[..., np.newaxis]) - eye_n
# eigh impl doesn't support batch dims, but future-proof the grad.
epsilon = 10 * np.finfo(a.dtype).resolution

deltas = w[..., np.newaxis, :] - w[..., np.newaxis]
same_subspace = (abs(deltas) < epsilon
if handle_degeneracies
else np.eye(a.shape[-1], dtype=bool))
# Note: if not handling degeneracies, this intentionally will result in
# nan/inf if there happen to be degenerate eigenvalues.
Fmat = np.where(same_subspace, 0.0, 1.0 / deltas)

dot = lax.dot if a.ndim == 2 else lax.batch_matmul
vdag_adot_v = dot(dot(_H(v), a_dot), v)
dv = dot(v, np.multiply(Fmat, vdag_adot_v))
dw = np.diagonal(vdag_adot_v, axis1=-2, axis2=-1)
vdag_adot = dot(_H(v), a_dot)
vdag_adot_v = dot(vdag_adot, v)
C = Fmat * vdag_adot_v
dv_non_degenerate = dot(v, C)

if handle_degeneracies:
# Diagonalize the perturbation in any degenerate subspaces.
A, dw = eigh_p.bind(vdag_adot_v * same_subspace, lower=lower,
handle_degeneracies=handle_degeneracies)

# Reorder these into sorted order of the original eigenvalues.
# TODO(shoyer): consider rewriting with an explicit loop over degenerate
# subspaces instead?
v2 = dot(v, A)
w2 = np.einsum('ij,jk,ki->i', _H(v2), a_sym, v2).real
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: handle batching here.

order = np.argsort(w2)
A = A[..., :, order]
dw = dw[..., order]

deltas_dot = dw[..., np.newaxis, :] - dw[..., np.newaxis]
same_dot_subspace = abs(deltas_dot) < epsilon
# If there are still some degenerate eigenvalue derivatives, the choice of
# basis is arbitrary (up to first order perturbations), so it's safe to set
# these terms in C_dot to 0.
Fmat_dot = np.where(same_dot_subspace, 0.0, 1.0 / deltas_dot)
vdag_adot_dv_non_degen = dot(vdag_adot, dv_non_degenerate)
C_dot = Fmat_dot * vdag_adot_dv_non_degen
dv = dot(v, np.where(same_subspace, C_dot, C))
v = dot(v, A)
else:
dv = dv_non_degenerate
dw = np.diagonal(vdag_adot_v, axis1=-2, axis2=-1)

return (v, w), (dv, dw)

def eigh_batching_rule(batched_args, batch_dims, lower):
def eigh_batching_rule(batched_args, batch_dims, lower, handle_degeneracies):
x, = batched_args
bd, = batch_dims
x = batching.moveaxis(x, bd, 0)
return eigh_p.bind(x, lower=lower), (0, 0)
result = eigh_p.bind(x, lower=lower, handle_degeneracies=handle_degeneracies)
return result, (0, 0)

eigh_p = Primitive('eigh')
eigh_p.multiple_results = True
Expand Down
17 changes: 15 additions & 2 deletions tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,14 +277,27 @@ def testEighGradVectorComplex(self, shape, dtype, rng, lower, eps):
new_a = (new_a + onp.conj(new_a.T)) / 2
# Assert rtol eigenvalue delta between perturbed eigenvectors vs new true eigenvalues.
RTOL=1e-2
assert onp.max(
onp.abs((onp.diag(onp.dot(onp.conj((v+dv).T), onp.dot(new_a,(v+dv)))) - new_w) / new_w)) < RTOL
onp.testing.assert_allclose(
onp.diag(onp.dot(onp.conj((v+dv).T), onp.dot(new_a,(v+dv)))), new_w,
rtol=RTOL)
# Redundant to above, but also assert rtol for eigenvector property with new true eigenvalues.
assert onp.max(
onp.linalg.norm(onp.abs(new_w*(v+dv) - onp.dot(new_a, (v+dv))), axis=0) /
onp.linalg.norm(onp.abs(new_w*(v+dv)), axis=0)
) < RTOL

def testEighGradDegenerate(self):
rng = jtu.rand_default()
a = np.eye(2)
a_dot = a[:, ::-1]
(w, v), (dw, dv) = jvp(np.linalg.eigh, primals=(a,), tangents=(a_dot,))
# correct eigenvectors are 1/sqrt(2) * [[-1, 1], [1, 1]], up to arbitrary
# choice of phase
onp.testing.assert_allclose(abs(v), onp.ones((2, 2)) / onp.sqrt(2))
onp.testing.assert_allclose(w, onp.ones((2,)))
onp.testing.assert_allclose(abs(dv), onp.zeros((2, 2)))
onp.testing.assert_allclose(dw, onp.array([-1, 1]))
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: more test cases. Currently nothing is checking eigenvector derivatives in the case of degeneracies if they are non-zero.


@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
Expand Down