Skip to content

Commit a1ffedc

Browse files
authored
Mooncake reverse rules (#85)
* Mooncake reverse rules
1 parent 58846bb commit a1ffedc

File tree

8 files changed

+944
-45
lines changed

8 files changed

+944
-45
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,15 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1212
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
1313
GenericLinearAlgebra = "14197337-ba66-59df-a3e3-ca00e7dcff7a"
1414
GenericSchur = "c145ed77-6b09-5dd9-b285-bf645a82121e"
15+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
1516

1617
[extensions]
1718
MatrixAlgebraKitChainRulesCoreExt = "ChainRulesCore"
1819
MatrixAlgebraKitAMDGPUExt = "AMDGPU"
1920
MatrixAlgebraKitCUDAExt = "CUDA"
2021
MatrixAlgebraKitGenericLinearAlgebraExt = "GenericLinearAlgebra"
2122
MatrixAlgebraKitGenericSchurExt = "GenericSchur"
23+
MatrixAlgebraKitMooncakeExt = "Mooncake"
2224

2325
[compat]
2426
AMDGPU = "2"
@@ -30,6 +32,7 @@ GenericLinearAlgebra = "0.3.19"
3032
GenericSchur = "0.5.6"
3133
JET = "0.9, 0.10"
3234
LinearAlgebra = "1"
35+
Mooncake = "0.4.174"
3336
SafeTestsets = "0.1"
3437
StableRNGs = "1"
3538
Test = "1"
@@ -43,11 +46,12 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
4346
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
4447
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
4548
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
49+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
4650
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
4751
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
4852
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4953
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
5054
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
5155

5256
[targets]
53-
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur"]
57+
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur", "Mooncake"]

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 333 additions & 0 deletions
Large diffs are not rendered by default.

src/pullbacks/polar.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
Adds the pullback from the left polar decomposition of `A` to `ΔA` given the output `WP` and
55
cotangent `ΔWP` of `left_polar(A)`.
66
"""
7-
function left_polar_pullback!(ΔA::AbstractMatrix, A, WP, ΔWP)
7+
function left_polar_pullback!(ΔA::AbstractMatrix, A, WP, ΔWP; kwargs...)
88
# Extract the Polar components
99
W, P = WP
1010

@@ -34,7 +34,7 @@ end
3434
Adds the pullback from the left polar decomposition of `A` to `ΔA` given the output `PWᴴ`
3535
and cotangent `ΔPWᴴ` of `right_polar(A)`.
3636
"""
37-
function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ)
37+
function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ; kwargs...)
3838
# Extract the Polar components
3939
P, Wᴴ = PWᴴ
4040

src/pullbacks/svd.jl

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,13 @@ function svd_pullback!(
2626
degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
2727
gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3])
2828
)
29-
3029
# Extract the SVD components
3130
U, Smat, Vᴴ = USVᴴ
3231
m, n = size(U, 1), size(Vᴴ, 2)
33-
(m, n) == size(ΔA) || throw(DimensionMismatch())
32+
(m, n) == size(ΔA) || throw(DimensionMismatch("size of ΔA ($(size(ΔA))) does not match size of U*S*Vᴴ ($m, $n)"))
3433
minmn = min(m, n)
3534
S = diagview(Smat)
36-
length(S) == minmn || throw(DimensionMismatch())
35+
length(S) == minmn || throw(DimensionMismatch("length of S ($(length(S))) does not matrix minimum dimension of U, Vᴴ ($minmn)"))
3736
r = searchsortedlast(S, rank_atol; rev = true) # rank
3837
Ur = view(U, :, 1:r)
3938
Vᴴr = view(Vᴴ, 1:r, :)
@@ -44,22 +43,22 @@ function svd_pullback!(
4443
UΔU = fill!(similar(U, (r, r)), 0)
4544
VΔV = fill!(similar(Vᴴ, (r, r)), 0)
4645
if !iszerotangent(ΔU)
47-
m == size(ΔU, 1) || throw(DimensionMismatch())
46+
m == size(ΔU, 1) || throw(DimensionMismatch("first dimension of ΔU ($(size(ΔU, 1))) does not match first dimension of U ($m)"))
4847
pU = size(ΔU, 2)
49-
pU > r && throw(DimensionMismatch())
48+
pU > r && throw(DimensionMismatch("second dimension of ΔU ($(size(ΔU, 2))) does not match rank of S ($r)"))
5049
indU = axes(U, 2)[ind]
51-
length(indU) == pU || throw(DimensionMismatch())
50+
length(indU) == pU || throw(DimensionMismatch("length of selected U columns ($(length(indU))) does not match second dimension of ΔU ($(size(ΔU, 2)))"))
5251
UΔUp = view(UΔU, :, indU)
5352
mul!(UΔUp, Ur', ΔU)
5453
# ΔU -= Ur * UΔUp but one less allocation without overwriting ΔU
5554
ΔU = mul!(copy(ΔU), Ur, UΔUp, -1, 1)
5655
end
5756
if !iszerotangent(ΔVᴴ)
58-
n == size(ΔVᴴ, 2) || throw(DimensionMismatch())
57+
n == size(ΔVᴴ, 2) || throw(DimensionMismatch("second dimension of ΔVᴴ ($(size(ΔVᴴ, 2))) does not match second dimension of Vᴴ ($n)"))
5958
pV = size(ΔVᴴ, 1)
60-
pV > r && throw(DimensionMismatch())
59+
pV > r && throw(DimensionMismatch("first dimension of ΔVᴴ ($(size(ΔVᴴ, 1))) does not match rank of S ($r)"))
6160
indV = axes(Vᴴ, 1)[ind]
62-
length(indV) == pV || throw(DimensionMismatch())
61+
length(indV) == pV || throw(DimensionMismatch("length of selected Vᴴ rows ($(length(indV))) does not match first dimension of ΔVᴴ ($(size(ΔVᴴ, 1)))"))
6362
VΔVp = view(VΔV, :, indV)
6463
mul!(VΔVp, Vᴴr, ΔVᴴ')
6564
# ΔVᴴ -= VΔVp' * Vᴴr but one less allocation without overwriting ΔVᴴ
@@ -82,7 +81,7 @@ function svd_pullback!(
8281
ΔS = diagview(ΔSmat)
8382
pS = length(ΔS)
8483
indS = axes(S, 1)[ind]
85-
length(indS) == pS || throw(DimensionMismatch())
84+
length(indS) == pS || throw(DimensionMismatch("length of selected S diagonals ($(length(indS))) does not match length of ΔS diagonal ($(length(ΔS)))"))
8685
view(diagview(UdΔAV), indS) .+= real.(ΔS)
8786
end
8887
ΔA = mul!(ΔA, Ur, UdΔAV * Vᴴr, 1, 1) # add the contribution to ΔA

test/ad_utils.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
function remove_svdgauge_dependence!(
2+
ΔU, ΔVᴴ, U, S, Vᴴ;
3+
degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(S)
4+
)
5+
gaugepart = mul!(U' * ΔU, Vᴴ, ΔVᴴ', true, true)
6+
gaugepart = project_antihermitian!(gaugepart)
7+
gaugepart[abs.(transpose(diagview(S)) .- diagview(S)) .>= degeneracy_atol] .= 0
8+
mul!(ΔU, U, gaugepart, -1, 1)
9+
return ΔU, ΔVᴴ
10+
end
11+
function remove_eiggauge_dependence!(
12+
ΔV, D, V;
13+
degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(S)
14+
)
15+
gaugepart = V' * ΔV
16+
gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0
17+
mul!(ΔV, V / (V' * V), gaugepart, -1, 1)
18+
return ΔV
19+
end
20+
function remove_eighgauge_dependence!(
21+
ΔV, D, V;
22+
degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(S)
23+
)
24+
gaugepart = V' * ΔV
25+
gaugepart = project_antihermitian!(gaugepart)
26+
gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0
27+
mul!(ΔV, V, gaugepart, -1, 1)
28+
return ΔV
29+
end
30+
31+
precision(::Type{T}) where {T <: Number} = sqrt(eps(real(T)))

test/chainrules.jl

Lines changed: 1 addition & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,38 +6,7 @@ using ChainRulesCore, ChainRulesTestUtils, Zygote
66
using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD
77
using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul!
88

9-
function remove_svdgauge_dependence!(
10-
ΔU, ΔVᴴ, U, S, Vᴴ;
11-
degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(S)
12-
)
13-
gaugepart = U' * ΔU + Vᴴ * ΔVᴴ'
14-
gaugepart = (gaugepart - gaugepart') / 2
15-
gaugepart[abs.(transpose(diagview(S)) .- diagview(S)) .>= degeneracy_atol] .= 0
16-
mul!(ΔU, U, gaugepart, -1, 1)
17-
return ΔU, ΔVᴴ
18-
end
19-
function remove_eiggauge_dependence!(
20-
ΔV, D, V;
21-
degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(D)
22-
)
23-
gaugepart = V' * ΔV
24-
gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0
25-
mul!(ΔV, V / (V' * V), gaugepart, -1, 1)
26-
return ΔV
27-
end
28-
function remove_eighgauge_dependence!(
29-
ΔV, D, V;
30-
degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(D)
31-
)
32-
gaugepart = V' * ΔV
33-
gaugepart = (gaugepart - gaugepart') / 2
34-
gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0
35-
mul!(ΔV, V, gaugepart, -1, 1)
36-
return ΔV
37-
end
38-
39-
precision(::Type{<:Union{Float32, Complex{Float32}}}) = sqrt(eps(Float32))
40-
precision(::Type{<:Union{Float64, Complex{Float64}}}) = sqrt(eps(Float64))
9+
include("ad_utils.jl")
4110

4211
for f in
4312
(

0 commit comments

Comments
 (0)