Skip to content

Commit c7537c0

Browse files
authored
Add SVD factorization rrule (#31)
This adds `rrule`s for the SVD factorization as well as an accompanying `rrule` for `getproperty` on `SVD` objects. The definitions are ported from Nabla.
1 parent 45c79e9 commit c7537c0

File tree

5 files changed

+108
-1
lines changed

5 files changed

+108
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
99

1010
[compat]
1111
Cassette = "^0.2"
12-
FDM = "^0.4"
12+
FDM = "^0.5"
1313
julia = "^1.0"
1414

1515
[extras]

src/ChainRules.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ include("rules/broadcast.jl")
1515
include("rules/linalg/dense.jl")
1616
include("rules/linalg/diagonal.jl")
1717
include("rules/linalg/symmetric.jl")
18+
include("rules/linalg/factorization.jl")
1819
include("rules/blas.jl")
1920
include("rules/nanmath.jl")
2021
include("rules/specialfunctions.jl")

src/rules/linalg/factorization.jl

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
#####
2+
##### `svd`
3+
#####
4+
5+
# Reverse-mode rule for `svd` of a real matrix.
#
# Returns the factorization `F = svd(X)` together with a `Rule` that maps a cotangent
# `Ȳ` (a named tuple with fields `U`, `S` and `V`) to the cotangent of `X` by
# delegating to `svd_rev`.
function rrule(::typeof(svd), X::AbstractMatrix{<:Real})
    F = svd(X)
    # The `do` block needs a named argument: the body reads `Ȳ.U`, `Ȳ.S` and `Ȳ.V`.
    ∂X = Rule() do Ȳ::NamedTuple{(:U,:S,:V)}
        svd_rev(F, Ȳ.U, Ȳ.S, Ȳ.V)
    end
    return F, ∂X
end
12+
13+
# Reverse-mode rule for `getproperty` on an `SVD` object.
#
# For `x` in `(:U, :S, :V, :Vt)` this returns the requested property together with a
# pair of rules: one mapping the property's cotangent `Ȳ` to a named tuple
# `(U=…, S=…, V=…)` that is zero everywhere except in the requested slot, and a
# `DNERule()` for the symbol argument.
function rrule(::typeof(getproperty), F::SVD, x::Symbol)
    if x === :U
        return F.U, (Rule(Ȳ->(U=Ȳ, S=zero(F.S), V=zero(F.V))), DNERule())
    elseif x === :S
        return F.S, (Rule(Ȳ->(U=zero(F.U), S=Ȳ, V=zero(F.V))), DNERule())
    elseif x === :V
        return F.V, (Rule(Ȳ->(U=zero(F.U), S=zero(F.S), V=Ȳ)), DNERule())
    elseif x === :Vt
        # `Vt` is the adjoint of `V`, so its cotangent is adjointed back into the `V` slot.
        return F.Vt, (Rule(Ȳ->(U=zero(F.U), S=zero(F.S), V=Ȳ')), DNERule())
    end
    # Any other symbol yields `nothing`, exactly as the implicit fall-through did.
    # NOTE(review): presumably this signals "no rule available" to callers — confirm.
    return nothing
end
24+
25+
# Map cotangents of the SVD factors back to a cotangent of the factorized matrix.
#
# Arguments:
# - `USV`: the factorization whose `U`, `S`, `V` and `Vt` properties are read.
# - `Ū`, `s̄`, `V̄`: cotangents associated with `USV.U`, `USV.S` and `USV.V`.
#
# Returns the accumulated matrix `Ā`, the sum of the contributions through `U`, `S`
# and `V`.
function svd_rev(USV::SVD, Ū::AbstractMatrix, s̄::AbstractVector, V̄::AbstractMatrix)
    # Note: assuming a thin factorization, i.e. svd(A, full=false), which is the default
    U = USV.U
    s = USV.S
    V = USV.V
    Vt = USV.Vt

    k = length(s)
    T = eltype(s)
    # F[i,j] = 1 / (s[j]^2 - s[i]^2) off the diagonal. The diagonal value is a
    # placeholder: `_mulsubtrans!` zeroes the diagonal of its result.
    F = T[i == j ? 1 : inv(@inbounds s[j]^2 - s[i]^2) for i = 1:k, j = 1:k]

    # We do a lot of matrix operations here, so we'll try to be memory-friendly and do
    # as many of the computations in-place as possible. Benchmarking shows that the in-
    # place functions here are significantly faster than their out-of-place, naively
    # implemented counterparts, and allocate no additional memory.
    Ut = U'
    FUᵀŪ = _mulsubtrans!(Ut*Ū, F)  # F .* (UᵀŪ - ŪᵀU)
    FVᵀV̄ = _mulsubtrans!(Vt*V̄, F)  # F .* (VᵀV̄ - V̄ᵀV)
    ImUUᵀ = _eyesubx!(U*Ut)  # I - UUᵀ
    ImVVᵀ = _eyesubx!(V*Vt)  # I - VVᵀ

    S = Diagonal(s)
    S̄ = Diagonal(s̄)

    # Contribution through U, then through S, then through V, accumulated into Ā.
    Ā = _add!(U * FUᵀŪ * S, ImUUᵀ * (Ū / S)) * Vt
    _add!(Ā, U * S̄ * Vt)
    _add!(Ā, U * _add!(S * FVᵀV̄ * Vt, (S \ V̄') * ImVVᵀ))

    return Ā
end
55+
56+
"""
    _mulsubtrans!(X, F)

Overwrite the square matrix `X` with `F .* (X - X')` and return `X`.

The diagonal of the result is set to zero without reading the diagonal of `F`.

# Throws
- `DimensionMismatch`: if `X` is not square or `F` does not have the same size as `X`.
"""
function _mulsubtrans!(X::AbstractMatrix{T}, F::AbstractMatrix{T}) where T<:Real
    k = size(X, 1)
    # The loop below indexes both matrices under `@inbounds`, so validate shapes first.
    if size(X, 2) != k || size(F) != (k, k)
        throw(DimensionMismatch(
            "expected square matrices of equal size, got $(size(X)) and $(size(F))"
        ))
    end
    @inbounds for j = 1:k, i = 1:j  # Iterate the upper triangle
        if i == j
            X[i,i] = zero(T)
        else
            # Both mirrored entries are computed from the old values before either is
            # overwritten (the right-hand tuple is evaluated first).
            X[i,j], X[j,i] = F[i,j] * (X[i,j] - X[j,i]), F[j,i] * (X[j,i] - X[i,j])
        end
    end
    return X
end
67+
68+
"""
    _eyesubx!(X)

Overwrite `X` with `I - X` and return `X`.

Every entry is negated, and entries whose row and column indices coincide additionally
have one added. `X` does not need to be square.
"""
function _eyesubx!(X::AbstractMatrix{T}) where T<:Real
    # Iterating over `axes` keeps every index valid for `X`, which is what makes the
    # `@inbounds` annotation safe.
    @inbounds for j in axes(X, 2), i in axes(X, 1)
        X[i,j] = (i == j) - X[i,j]
    end
    return X
end
75+
76+
"""
    _add!(X, Y)

Add `Y` to `X` element by element, storing the result in `X`, and return `X`.
"""
function _add!(X::AbstractMatrix{T}, Y::AbstractMatrix{T}) where T<:Real
    # Indices come from `eachindex(X, Y)`, i.e. they are shared by both arrays.
    @inbounds for i in eachindex(X, Y)
        X[i] += Y[i]
    end
    return X
end

test/rules/linalg/factorization.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
@testset "Factorizations" begin
    @testset "svd" begin
        rng = MersenneTwister(2)
        # Exercise tall, square and wide inputs.
        for n in [4, 6, 10], m in [3, 5, 10]
            X = randn(rng, n, m)
            F, dX = rrule(svd, X)
            for p in [:U, :S, :V, :Vt]
                Y, (dF, dp) = rrule(getproperty, F, p)
                @test dp isa ChainRules.DNERule
                # Random cotangent for the selected property.
                Ȳ = randn(rng, size(Y)...)
                # Chain the two rules and compare against finite differencing.
                X̄_ad = dX(dF(Ȳ))
                X̄_fd = j′vp(central_fdm(5, 1), X->getproperty(svd(X), p), Ȳ, X)
                @test X̄_ad ≈ X̄_fd rtol=1e-6 atol=1e-6
            end
        end
        @testset "Helper functions" begin
            X = randn(rng, 10, 10)
            Y = randn(rng, 10, 10)
            # The helpers mutate their first argument, hence the copies.
            @test ChainRules._mulsubtrans!(copy(X), Y) ≈ Y .* (X - X')
            @test ChainRules._eyesubx!(copy(X)) ≈ I - X
            @test ChainRules._add!(copy(X), Y) ≈ X + Y
        end
    end
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ include("test_util.jl")
1818
include(joinpath("rules", "linalg", "dense.jl"))
1919
include(joinpath("rules", "linalg", "diagonal.jl"))
2020
include(joinpath("rules", "linalg", "symmetric.jl"))
21+
include(joinpath("rules", "linalg", "factorization.jl"))
2122
end
2223
include(joinpath("rules", "broadcast.jl"))
2324
include(joinpath("rules", "blas.jl"))

0 commit comments

Comments
 (0)