Skip to content

Commit b38a916

Browse files
committed
add pullback!, rrule and test for eig(h)_vals
1 parent 4cf443b commit b38a916

File tree

4 files changed

+79
-11
lines changed

4 files changed

+79
-11
lines changed

ext/MatrixAlgebraKitChainRulesCoreExt.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,11 @@ for eig in (:eig, :eigh)
9595
eig_t! = Symbol(eig, "_trunc!")
9696
eig_t_pb = Symbol(eig, "_trunc_pullback")
9797
_make_eig_t_pb = Symbol("_make_", eig_t_pb)
98+
eig_v = Symbol(eig, "_vals")
99+
eig_v! = Symbol(eig_v, "!")
100+
eig_v_pb = Symbol(eig_v, "_pullback")
101+
eig_v_pb! = Symbol(eig_v_pb, "!")
102+
98103
@eval begin
99104
function ChainRulesCore.rrule(::typeof($eig_f!), A, DV, alg)
100105
Ac = copy_input($eig_f, A)
@@ -131,6 +136,18 @@ for eig in (:eig, :eigh)
131136
end
132137
return $eig_t_pb
133138
end
139+
function ChainRulesCore.rrule(::typeof($eig_v!), A, D, alg) # reverse rule for the in-place eigenvalue-only function; `D` is the output buffer argument
140+
DV = $eig_f(A, alg) # decomposition result, captured by the pullback closures below
141+
function $eig_v_pb(ΔD) # pullback for a generic cotangent ΔD of the returned eigenvalues
142+
ΔA = zero(A) # fresh accumulator; the pullback! call below writes into it
143+
MatrixAlgebraKit.$eig_v_pb!(ΔA, A, DV, unthunk(ΔD)) # unthunk materializes a possibly lazy cotangent first
144+
return NoTangent(), ΔA, ZeroTangent(), NoTangent() # one tangent each for (function, A, D, alg)
145+
end
146+
function $eig_v_pb(::ZeroTangent) # is this extra definition useful? (it returns a zero tangent for A without calling the pullback!)
147+
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
148+
end
149+
return diagview(DV[1]), $eig_v_pb # primal output is the diagonal view of the first element of DV
150+
end
134151
end
135152
end
136153

src/pullbacks/eig.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,3 +151,27 @@ function eig_trunc_pullback!(
151151
end
152152
return ΔA
153153
end
154+
155+
"""
156+
eig_vals_pullback!(
157+
ΔA, A, DV, ΔD, [ind];
158+
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
159+
)
160+
161+
Adds the pullback from the eigenvalues of `A` to `ΔA`, given the output
162+
`DV` of `eig_full` and the cotangent `ΔD` of `eig_vals`.
163+
164+
In particular, it is assumed that `A ≈ V * D * inv(V)`, and thus `size(A) == size(V) == size(D)`
165+
and `D` diagonal. For the cotangents, an arbitrary number of eigenvalues can be missing, i.e.
166+
for a matrix `A` of size `(n, n)`, `diagview(ΔD)` can have length `pD`. In those cases,
167+
additionally `ind` is required to specify which eigenvalues are present in `ΔD`.
168+
By default, it is assumed that all eigenvectors and eigenvalues are present.
169+
"""
170+
function eig_vals_pullback!(
171+
ΔA, A, DV, ΔD, ind = Colon();
172+
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
173+
)
174+
175+
ΔDV = (diagonal(ΔD), nothing) # eigenvalue cotangent only; `nothing` fills the eigenvector-cotangent slot
176+
return eig_pullback!(ΔA, A, DV, ΔDV, ind; degeneracy_atol) # delegate to eig_pullback!, forwarding ind and the tolerance
177+
end

src/pullbacks/eigh.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,27 @@ function eigh_trunc_pullback!(
141141
end
142142
return ΔA
143143
end
144+
145+
"""
146+
eigh_vals_pullback!(
147+
ΔA, A, DV, ΔD, [ind];
148+
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
149+
)
150+
151+
Adds the pullback from the eigenvalues of `A` to `ΔA`, given the output
152+
`DV` of `eigh_full` and the cotangent `ΔD` of `eigh_vals`.
153+
154+
In particular, it is assumed that `A ≈ V * D * inv(V)`, and thus `size(A) == size(V) == size(D)`
155+
and `D` diagonal. For the cotangents, an arbitrary number of eigenvalues can be missing, i.e.
156+
for a matrix `A` of size `(n, n)`, `diagview(ΔD)` can have length `pD`. In those cases,
157+
additionally `ind` is required to specify which eigenvalues are present in `ΔD`.
158+
By default, it is assumed that all eigenvectors and eigenvalues are present.
159+
"""
160+
function eigh_vals_pullback!(
161+
ΔA, A, DV, ΔD, ind = Colon();
162+
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
163+
)
164+
165+
ΔDV = (diagonal(ΔD), nothing) # eigenvalue cotangent only; `nothing` fills the eigenvector-cotangent slot
166+
return eigh_pullback!(ΔA, A, DV, ΔDV, ind; degeneracy_atol) # delegate to eigh_pullback!, forwarding ind and the tolerance
167+
end

test/chainrules.jl

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,23 @@ include("ad_utils.jl")
1111
for f in
1212
(
1313
:qr_compact, :qr_full, :qr_null, :lq_compact, :lq_full, :lq_null,
14-
:eig_full, :eig_trunc, :eigh_full, :eigh_trunc,
14+
:eig_full, :eig_trunc, :eig_vals, :eigh_full, :eigh_trunc, :eigh_vals,
1515
:svd_compact, :svd_trunc, :svd_vals,
1616
:left_polar, :right_polar,
1717
)
1818
copy_f = Symbol(:copy_, f)
1919
f! = Symbol(f, '!')
20+
_hermitian = startswith(string(f), "eigh")
2021
@eval begin
2122
function $copy_f(input, alg)
22-
if $f === eigh_full || $f === eigh_trunc
23+
if $_hermitian
2324
input = (input + input') / 2
2425
end
2526
return $f(input, alg)
2627
end
2728
function ChainRulesCore.rrule(::typeof($copy_f), input, alg)
2829
output = MatrixAlgebraKit.initialize_output($f!, input, alg)
29-
if $f === eigh_full || $f === eigh_trunc
30+
if $_hermitian
3031
input = (input + input') / 2
3132
else
3233
input = copy(input)
@@ -229,12 +230,13 @@ end
229230
ΔD2 = Diagonal(randn(rng, complex(T), m))
230231
for alg in (LAPACK_Simple(), LAPACK_Expert())
231232
test_rrule(
232-
copy_eig_full, A, alg ⊢ NoTangent();
233-
output_tangent = (ΔD, ΔV), atol = atol, rtol = rtol
233+
copy_eig_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD, ΔV), atol, rtol
234234
)
235235
test_rrule(
236-
copy_eig_full, A, alg ⊢ NoTangent();
237-
output_tangent = (ΔD2, ΔV), atol = atol, rtol = rtol
236+
copy_eig_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD2, ΔV), atol, rtol
237+
)
238+
test_rrule(
239+
copy_eig_vals, A, alg ⊢ NoTangent(); output_tangent = diagview(ΔD), atol, rtol
238240
)
239241
for r in 1:4:m
240242
truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs))
@@ -305,12 +307,13 @@ end
305307
)
306308
# copy_eigh_full includes a projector onto the Hermitian part of the matrix
307309
test_rrule(
308-
copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD, ΔV),
309-
atol = atol, rtol = rtol
310+
copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD, ΔV), atol, rtol
310311
)
311312
test_rrule(
312-
copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD2, ΔV),
313-
atol = atol, rtol = rtol
313+
copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD2, ΔV), atol, rtol
314+
)
315+
test_rrule(
316+
copy_eigh_vals, A, alg ⊢ NoTangent(); output_tangent = diagview(ΔD), atol, rtol
314317
)
315318
for r in 1:4:m
316319
truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs))

0 commit comments

Comments
 (0)