Skip to content

Commit cea4178

Browse files
authored
[AD] add chainrules support for svd_vals and eig(h)_vals and add diagonal (#107)
* add and export `diagonal` and `diagview` * add `pullback!`, `rrule` and test for `svd_vals` * add `pullback!`, `rrule` and test for `eig(h)_vals` * add zygote tests * also update mooncake rules * add exports
1 parent e72dca4 commit cea4178

File tree

8 files changed

+172
-26
lines changed

8 files changed

+172
-26
lines changed

ext/MatrixAlgebraKitChainRulesCoreExt.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,11 @@ for eig in (:eig, :eigh)
9595
eig_t! = Symbol(eig, "_trunc!")
9696
eig_t_pb = Symbol(eig, "_trunc_pullback")
9797
_make_eig_t_pb = Symbol("_make_", eig_t_pb)
98+
eig_v = Symbol(eig, "_vals")
99+
eig_v! = Symbol(eig_v, "!")
100+
eig_v_pb = Symbol(eig_v, "_pullback")
101+
eig_v_pb! = Symbol(eig_v_pb, "!")
102+
98103
@eval begin
99104
function ChainRulesCore.rrule(::typeof($eig_f!), A, DV, alg)
100105
Ac = copy_input($eig_f, A)
@@ -131,6 +136,18 @@ for eig in (:eig, :eigh)
131136
end
132137
return $eig_t_pb
133138
end
139+
# Reverse rule for the in-place eigenvalue routine `$eig_v!`. The decomposition `DV` is
# obtained from `$eig_f` (bound earlier in this loop) because the pullback needs it; the
# returned primal is the diagonal of its first factor. The preallocated output `D` is not
# written to here and receives a zero tangent.
function ChainRulesCore.rrule(::typeof($eig_v!), A, D, alg)
    DV = $eig_f(A, alg)
    # is this extra definition useful?
    $eig_v_pb(::ZeroTangent) = (NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent())
    function $eig_v_pb(ΔD)
        # Accumulate into a fresh zero array shaped like `A`.
        ΔA = zero(A)
        MatrixAlgebraKit.$eig_v_pb!(ΔA, A, DV, unthunk(ΔD))
        return (NoTangent(), ΔA, ZeroTangent(), NoTangent())
    end
    return (diagview(DV[1]), $eig_v_pb)
end
134151
end
135152
end
136153

@@ -176,6 +193,19 @@ function _make_svd_trunc_pullback(A, USVᴴ, ind)
176193
return svd_trunc_pullback
177194
end
178195

196+
# Reverse rule for `svd_vals!`. The compact SVD `USVᴴ` is computed once because the
# pullback needs all three factors; the returned primal is the diagonal of the second
# factor. The preallocated output `S` is not written to here and receives a zero tangent.
function ChainRulesCore.rrule(::typeof(svd_vals!), A, S, alg)
    USVᴴ = svd_compact(A, alg)
    function svd_vals_pullback(ΔS)
        # Accumulate into a fresh zero array shaped like `A`.
        ΔA = zero(A)
        MatrixAlgebraKit.svd_vals_pullback!(ΔA, A, USVᴴ, unthunk(ΔS))
        return NoTangent(), ΔA, ZeroTangent(), NoTangent()
    end
    # This must be a method of `svd_vals_pullback` itself (it was previously defined under
    # the unrelated name `svd_pullback` and therefore never called), so that a `ZeroTangent`
    # cotangent short-circuits instead of being passed on to `svd_vals_pullback!`.
    function svd_vals_pullback(::ZeroTangent)
        return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
    end
    return diagview(USVᴴ[2]), svd_vals_pullback
end
208+
179209
function ChainRulesCore.rrule(::typeof(left_polar!), A, WP, alg)
180210
Ac = copy_input(left_polar, A)
181211
WP = left_polar!(Ac, WP, alg)

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@ using MatrixAlgebraKit
66
using MatrixAlgebraKit: inv_safe, diagview, copy_input
77
using MatrixAlgebraKit: qr_pullback!, lq_pullback!
88
using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!
9-
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_trunc_pullback!, eigh_trunc_pullback!
9+
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!
10+
using MatrixAlgebraKit: eig_trunc_pullback!, eigh_trunc_pullback!, eigh_vals_pullback!
1011
using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!
11-
using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!
12+
using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!, svd_vals_pullback!
1213
using LinearAlgebra
1314

1415

@@ -122,8 +123,8 @@ for (f!, f, pb, adj) in (
122123
end
123124

124125
for (f!, f, f_full, pb, adj) in (
125-
(:eig_vals!, :eig_vals, :eig_full, :eig_pullback!, :eig_vals_adjoint),
126-
(:eigh_vals!, :eigh_vals, :eigh_full, :eigh_pullback!, :eigh_vals_adjoint),
126+
(:eig_vals!, :eig_vals, :eig_full, :eig_vals_pullback!, :eig_vals_adjoint),
127+
(:eigh_vals!, :eigh_vals, :eigh_full, :eigh_vals_pullback!, :eigh_vals_adjoint),
127128
)
128129
@eval begin
129130
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
@@ -136,7 +137,7 @@ for (f!, f, f_full, pb, adj) in (
136137
copy!(D, diagview(DV[1]))
137138
V = DV[2]
138139
function $adj(::NoRData)
139-
$pb(dA, A, (Diagonal(D), V), (Diagonal(dD), nothing))
140+
$pb(dA, A, DV, dD)
140141
MatrixAlgebraKit.zero!(dD)
141142
return NoRData(), NoRData(), NoRData(), NoRData()
142143
end
@@ -153,7 +154,7 @@ for (f!, f, f_full, pb, adj) in (
153154
output_codual = CoDual(output, Mooncake.zero_tangent(output))
154155
function $adj(::NoRData)
155156
D, dD = arrayify(output_codual)
156-
$pb(dA, A, (Diagonal(D), V), (Diagonal(dD), nothing))
157+
$pb(dA, A, DV, dD)
157158
MatrixAlgebraKit.zero!(dD)
158159
return NoRData(), NoRData(), NoRData()
159160
end
@@ -272,10 +273,10 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDua
272273
# compute primal
273274
A, dA = arrayify(A_dA)
274275
S, dS = arrayify(S_dS)
275-
U, nS, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg))
276-
copy!(S, diagview(nS))
276+
USVᴴ = svd_compact(A, Mooncake.primal(alg_dalg))
277+
copy!(S, diagview(USVᴴ[2]))
277278
function svd_vals_adjoint(::NoRData)
278-
svd_pullback!(dA, A, (U, Diagonal(S), Vᴴ), (nothing, Diagonal(dS), nothing))
279+
svd_vals_pullback!(dA, A, USVᴴ, dS)
279280
MatrixAlgebraKit.zero!(dS)
280281
return NoRData(), NoRData(), NoRData(), NoRData()
281282
end
@@ -286,15 +287,16 @@ end
286287
function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::CoDual)
287288
# compute primal
288289
A, dA = arrayify(A_dA)
289-
U, S, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg))
290+
USVᴴ = svd_compact(A, Mooncake.primal(alg_dalg))
290291
# fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
291292
# of ComplexF32) into the correct **forwards** data type (since we are now in the forward
292293
# pass). For many types this is done automatically when the forward step returns, but
293294
# not for nested structs with various fields (like Diagonal{Complex})
294-
S_codual = CoDual(diagview(S), Mooncake.fdata(Mooncake.zero_tangent(diagview(S))))
295+
S = diagview(USVᴴ[2])
296+
S_codual = CoDual(S, Mooncake.fdata(Mooncake.zero_tangent(S)))
295297
function svd_vals_adjoint(::NoRData)
296298
S, dS = arrayify(S_codual)
297-
svd_pullback!(dA, A, (U, Diagonal(S), Vᴴ), (nothing, Diagonal(dS), nothing))
299+
svd_vals_pullback!(dA, A, USVᴴ, dS)
298300
MatrixAlgebraKit.zero!(dS)
299301
return NoRData(), NoRData(), NoRData()
300302
end

src/MatrixAlgebraKit.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using LinearAlgebra: UpperTriangular, LowerTriangular
1010
using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt
1111

1212
export isisometric, isunitary, ishermitian, isantihermitian
13+
export diagview, diagonal
1314

1415
export project_hermitian, project_antihermitian, project_isometric
1516
export project_hermitian!, project_antihermitian!, project_isometric!
@@ -62,8 +63,9 @@ export notrunc, truncrank, trunctol, truncerror, truncfilter
6263
Expr(
6364
:public, :left_polar_pullback!, :right_polar_pullback!,
6465
:qr_pullback!, :qr_null_pullback!, :lq_pullback!, :lq_null_pullback!,
65-
:eig_pullback!, :eig_trunc_pullback!, :eigh_pullback!, :eigh_trunc_pullback!,
66-
:svd_pullback!, :svd_trunc_pullback!
66+
:eig_pullback!, :eig_trunc_pullback!, :eig_vals_pullback!,
67+
:eigh_pullback!, :eigh_trunc_pullback!, :eigh_vals_pullback!,
68+
:svd_pullback!, :svd_trunc_pullback!, :svd_vals_pullback!
6769
)
6870
)
6971
eval(Expr(:public, :is_left_isometric, :is_right_isometric))

src/common/view.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,25 @@
11
# diagind: provided by LinearAlgebra.jl
2+
@doc """
    diagview(D)

Return the diagonal entries of the matrix `D` without copying them: the stored `diag`
field for a `Diagonal`, and a `view` over `diagind(D)` for any other `AbstractMatrix`.

See also [`diagonal`](@ref).
""" diagview

function diagview(D::Diagonal)
    return D.diag
end
function diagview(D::AbstractMatrix)
    return view(D, diagind(D))
end
412

13+
@doc """
    diagonal(v)

Wrap the vector `v` in a `Diagonal` matrix that has `v` as its diagonal.

See also [`diagview`](@ref).
""" diagonal

function diagonal(v::AbstractVector)
    return Diagonal(v)
end
22+
523
# triangularind
624
function lowertriangularind(A::AbstractMatrix)
725
Base.require_one_based_indexing(A)

src/pullbacks/eig.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,3 +151,27 @@ function eig_trunc_pullback!(
151151
end
152152
return ΔA
153153
end
154+
"""
    eig_vals_pullback!(
        ΔA, A, DV, ΔD, [ind];
        degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
    )

Adds the pullback from the eigenvalues of `A` to `ΔA`, given the output
`DV` of `eig_full` and the cotangent `ΔD` of `eig_vals`.

In particular, it is assumed that `A ≈ V * D * inv(V)` with thus
`size(A) == size(V) == size(D)` and `D` diagonal. The cotangent `ΔD` is a vector of
eigenvalue cotangents from which an arbitrary number of eigenvalues can be missing, i.e.
for a matrix `A` of size `(n, n)`, `ΔD` can have length `pD`. In those cases, additionally
`ind` is required to specify which eigenvalues are present in `ΔD`.
By default, it is assumed that all eigenvalues are present.

The work is delegated to `eig_pullback!`, called with the cotangent
`(diagonal(ΔD), nothing)` and with `ind` and `degeneracy_atol` passed through unchanged.
"""
function eig_vals_pullback!(
        ΔA, A, DV, ΔD, ind = Colon();
        degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
    )
    # No eigenvector cotangent: only the eigenvalues contribute.
    ΔDV = (diagonal(ΔD), nothing)
    return eig_pullback!(ΔA, A, DV, ΔDV, ind; degeneracy_atol)
end

src/pullbacks/eigh.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,27 @@ function eigh_trunc_pullback!(
141141
end
142142
return ΔA
143143
end
144+
"""
    eigh_vals_pullback!(
        ΔA, A, DV, ΔD, [ind];
        degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
    )

Adds the pullback from the eigenvalues of `A` to `ΔA`, given the output
`DV` of `eigh_full` and the cotangent `ΔD` of `eigh_vals`.

In particular, it is assumed that `A ≈ V * D * inv(V)` with thus
`size(A) == size(V) == size(D)` and `D` diagonal. The cotangent `ΔD` is a vector of
eigenvalue cotangents from which an arbitrary number of eigenvalues can be missing, i.e.
for a matrix `A` of size `(n, n)`, `ΔD` can have length `pD`. In those cases, additionally
`ind` is required to specify which eigenvalues are present in `ΔD`.
By default, it is assumed that all eigenvalues are present.

The work is delegated to `eigh_pullback!`, called with the cotangent
`(diagonal(ΔD), nothing)` and with `ind` and `degeneracy_atol` passed through unchanged.
"""
function eigh_vals_pullback!(
        ΔA, A, DV, ΔD, ind = Colon();
        degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
    )
    # No eigenvector cotangent: only the eigenvalues contribute.
    ΔDV = (diagonal(ΔD), nothing)
    return eigh_pullback!(ΔA, A, DV, ΔDV, ind; degeneracy_atol)
end

src/pullbacks/svd.jl

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3])
77
)
88
9-
Adds the pullback from the SVD of `A` to `ΔA` given the output USVᴴ of `svd_compact` or
9+
Adds the pullback from the SVD of `A` to `ΔA` given the output `USVᴴ` of `svd_compact` or
1010
`svd_full` and the cotangent `ΔUSVᴴ` of `svd_compact`, `svd_full` or `svd_trunc`.
1111
1212
In particular, it is assumed that `A ≈ U * S * Vᴴ`, or thus, that no singular values with
@@ -201,3 +201,29 @@ function svd_trunc_pullback!(
201201
ΔA = mul!(ΔA, U, Y' * Ṽᴴ, 1, 1)
202202
return ΔA
203203
end
204+
"""
    svd_vals_pullback!(
        ΔA, A, USVᴴ, ΔS, [ind];
        rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
        degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2])
    )

Adds the pullback from the singular values of `A` to `ΔA`, given the output
`USVᴴ` of `svd_compact`, and the cotangent `ΔS` of `svd_vals`.

In particular, it is assumed that `A ≈ U * S * Vᴴ`, or thus, that no singular values with
magnitude less than `rank_atol` are missing from `S`. The cotangent `ΔS` is a vector of
singular value cotangents from which an arbitrary number of singular values can be
missing, i.e. for a matrix `A` with size `(m, n)`, `ΔS` can have length `pS`. In those
cases, additionally `ind` is required to specify which singular values are present in `ΔS`.

The work is delegated to `svd_pullback!`, called with the cotangent
`(nothing, diagonal(ΔS), nothing)` and with `ind`, `rank_atol` and `degeneracy_atol`
passed through unchanged.
"""
function svd_vals_pullback!(
        ΔA, A, USVᴴ, ΔS, ind = Colon();
        rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
        degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2])
    )
    # No singular vector cotangents: only the singular values contribute.
    ΔUSVᴴ = (nothing, diagonal(ΔS), nothing)
    return svd_pullback!(ΔA, A, USVᴴ, ΔUSVᴴ, ind; rank_atol, degeneracy_atol)
end

test/chainrules.jl

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,23 @@ include("ad_utils.jl")
1111
for f in
1212
(
1313
:qr_compact, :qr_full, :qr_null, :lq_compact, :lq_full, :lq_null,
14-
:eig_full, :eig_trunc, :eigh_full, :eigh_trunc, :svd_compact, :svd_trunc,
14+
:eig_full, :eig_trunc, :eig_vals, :eigh_full, :eigh_trunc, :eigh_vals,
15+
:svd_compact, :svd_trunc, :svd_vals,
1516
:left_polar, :right_polar,
1617
)
1718
copy_f = Symbol(:copy_, f)
1819
f! = Symbol(f, '!')
20+
_hermitian = startswith(string(f), "eigh")
1921
@eval begin
2022
function $copy_f(input, alg)
21-
if $f === eigh_full || $f === eigh_trunc
23+
if $_hermitian
2224
input = (input + input') / 2
2325
end
2426
return $f(input, alg)
2527
end
2628
function ChainRulesCore.rrule(::typeof($copy_f), input, alg)
2729
output = MatrixAlgebraKit.initialize_output($f!, input, alg)
28-
if $f === eigh_full || $f === eigh_trunc
30+
if $_hermitian
2931
input = (input + input') / 2
3032
else
3133
input = copy(input)
@@ -228,12 +230,13 @@ end
228230
ΔD2 = Diagonal(randn(rng, complex(T), m))
229231
for alg in (LAPACK_Simple(), LAPACK_Expert())
230232
test_rrule(
231-
copy_eig_full, A, alg ⊢ NoTangent();
232-
output_tangent = (ΔD, ΔV), atol = atol, rtol = rtol
233+
copy_eig_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD, ΔV), atol, rtol
233234
)
234235
test_rrule(
235-
copy_eig_full, A, alg ⊢ NoTangent();
236-
output_tangent = (ΔD2, ΔV), atol = atol, rtol = rtol
236+
copy_eig_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD2, ΔV), atol, rtol
237+
)
238+
test_rrule(
239+
copy_eig_vals, A, alg ⊢ NoTangent(); output_tangent = diagview(ΔD), atol, rtol
237240
)
238241
for r in 1:4:m
239242
truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs))
@@ -284,6 +287,10 @@ end
284287
config, last ∘ eig_full, A;
285288
output_tangent = ΔV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false
286289
)
290+
test_rrule(
291+
config, eig_vals, A;
292+
output_tangent = diagview(ΔD), atol, rtol, rrule_f = rrule_via_ad, check_inferred = false
293+
)
287294
end
288295

289296
@timedtestset "EIGH AD Rules with eltype $T" for T in (Float64, ComplexF64, Float32)
@@ -304,12 +311,13 @@ end
304311
)
305312
# copy_eigh_full includes a projector onto the Hermitian part of the matrix
306313
test_rrule(
307-
copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD, ΔV),
308-
atol = atol, rtol = rtol
314+
copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD, ΔV), atol, rtol
309315
)
310316
test_rrule(
311-
copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD2, ΔV),
312-
atol = atol, rtol = rtol
317+
copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD2, ΔV), atol, rtol
318+
)
319+
test_rrule(
320+
copy_eigh_vals, A, alg ⊢ NoTangent(); output_tangent = diagview(ΔD), atol, rtol
313321
)
314322
for r in 1:4:m
315323
truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs))
@@ -361,6 +369,10 @@ end
361369
config, last ∘ eigh_full ∘ Matrix ∘ Hermitian, A;
362370
output_tangent = ΔV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false
363371
)
372+
test_rrule(
373+
config, eigh_vals ∘ Matrix ∘ Hermitian, A;
374+
output_tangent = diagview(ΔD), atol, rtol, rrule_f = rrule_via_ad, check_inferred = false
375+
)
364376
eigh_trunc2(A; kwargs...) = eigh_trunc(Matrix(Hermitian(A)); kwargs...)
365377
for r in 1:4:m
366378
trunc = truncrank(r; by = real)
@@ -404,6 +416,10 @@ end
404416
copy_svd_compact, A, alg ⊢ NoTangent();
405417
output_tangent = (ΔU, ΔS2, ΔVᴴ), atol = atol, rtol = rtol
406418
)
419+
test_rrule(
420+
copy_svd_vals, A, alg ⊢ NoTangent();
421+
output_tangent = diagview(ΔS), atol, rtol
422+
)
407423
for r in 1:4:minmn
408424
truncalg = TruncatedAlgorithm(alg, truncrank(r))
409425
ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc)
@@ -451,6 +467,10 @@ end
451467
output_tangent = (ΔU, ΔS2, ΔVᴴ), atol = atol, rtol = rtol,
452468
rrule_f = rrule_via_ad, check_inferred = false
453469
)
470+
test_rrule(
471+
config, svd_vals, A;
472+
output_tangent = diagview(ΔS), atol, rtol, rrule_f = rrule_via_ad, check_inferred = false
473+
)
454474
for r in 1:4:minmn
455475
trunc = truncrank(r)
456476
ind = MatrixAlgebraKit.findtruncated(diagview(S), trunc)

0 commit comments

Comments
 (0)