Skip to content

Commit 4cf443b

Browse files
committed
add pullback, rrule and test for svd_vals
1 parent f0a0045 commit 4cf443b

File tree

3 files changed

+46
-2
lines changed

3 files changed

+46
-2
lines changed

ext/MatrixAlgebraKitChainRulesCoreExt.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,19 @@ function _make_svd_trunc_pullback(A, USVᴴ, ind)
176176
return svd_trunc_pullback
177177
end
178178

179+
function ChainRulesCore.rrule(::typeof(svd_vals!), A, S, alg)
180+
USVᴴ = svd_compact(A, alg)
181+
function svd_vals_pullback(ΔS)
182+
ΔA = zero(A)
183+
MatrixAlgebraKit.svd_vals_pullback!(ΔA, A, USVᴴ, unthunk(ΔS))
184+
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
185+
end
186+
function svd_pullback(::ZeroTangent) # is this extra definition useful?
187+
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
188+
end
189+
return diagview(USVᴴ[2]), svd_vals_pullback
190+
end
191+
179192
function ChainRulesCore.rrule(::typeof(left_polar!), A, WP, alg)
180193
Ac = copy_input(left_polar, A)
181194
WP = left_polar!(Ac, WP, alg)

src/pullbacks/svd.jl

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3])
77
)
88
9-
Adds the pullback from the SVD of `A` to `ΔA` given the output USVᴴ of `svd_compact` or
9+
Adds the pullback from the SVD of `A` to `ΔA` given the output `USVᴴ` of `svd_compact` or
1010
`svd_full` and the cotangent `ΔUSVᴴ` of `svd_compact`, `svd_full` or `svd_trunc`.
1111
1212
In particular, it is assumed that `A ≈ U * S * Vᴴ`, or thus, that no singular values with
@@ -201,3 +201,29 @@ function svd_trunc_pullback!(
201201
ΔA = mul!(ΔA, U, Y' * Ṽᴴ, 1, 1)
202202
return ΔA
203203
end
204+
205+
"""
206+
svd_vals_pullback!(
207+
ΔA, A, USVᴴ, ΔS, [ind];
208+
rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
209+
degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2])
210+
)
211+
212+
213+
Adds the pullback from the singular values of `A` to `ΔA`, given the output
214+
`USVᴴ` of `svd_compact`, and the cotangent `ΔS` of `svd_vals`.
215+
216+
In particular, it is assumed that `A ≈ U * S * Vᴴ`, or thus, that no singular values with
217+
magnitude less than `rank_atol` are missing from `S`. For the cotangents, an arbitrary
218+
number of singular vectors or singular values can be missing, i.e. for a matrix `A` with
219+
size `(m, n)`, `diagview(ΔS)` can have length `pS`. In those cases, additionally `ind` is required to
220+
specify which singular vectors and values are present in `ΔS`.
221+
"""
222+
function svd_vals_pullback!(
223+
ΔA, A, USVᴴ, ΔS, ind = Colon();
224+
rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
225+
degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2])
226+
)
227+
ΔUSVᴴ = (nothing, diagonal(ΔS), nothing)
228+
return svd_pullback!(ΔA, A, USVᴴ, ΔUSVᴴ, ind; rank_atol, degeneracy_atol)
229+
end

test/chainrules.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ include("ad_utils.jl")
1111
for f in
1212
(
1313
:qr_compact, :qr_full, :qr_null, :lq_compact, :lq_full, :lq_null,
14-
:eig_full, :eig_trunc, :eigh_full, :eigh_trunc, :svd_compact, :svd_trunc,
14+
:eig_full, :eig_trunc, :eigh_full, :eigh_trunc,
15+
:svd_compact, :svd_trunc, :svd_vals,
1516
:left_polar, :right_polar,
1617
)
1718
copy_f = Symbol(:copy_, f)
@@ -404,6 +405,10 @@ end
404405
copy_svd_compact, A, alg ⊢ NoTangent();
405406
output_tangent = (ΔU, ΔS2, ΔVᴴ), atol = atol, rtol = rtol
406407
)
408+
test_rrule(
409+
copy_svd_vals, A, alg ⊢ NoTangent();
410+
output_tangent = diagview(ΔS), atol, rtol
411+
)
407412
for r in 1:4:minmn
408413
truncalg = TruncatedAlgorithm(alg, truncrank(r))
409414
ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc)

0 commit comments

Comments
 (0)