Skip to content

Commit 87e8b5d

Browse files
committed
Small fixes to svd_trunc
1 parent 590ef91 commit 87e8b5d

File tree

5 files changed

+44
-24
lines changed

5 files changed

+44
-24
lines changed

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module MatrixAlgebraKitMooncakeExt
33
using Mooncake
44
using Mooncake: DefaultCtx, CoDual, Dual, NoRData, rrule!!, frule!!, arrayify, @is_primitive
55
using MatrixAlgebraKit
6-
using MatrixAlgebraKit: inv_safe, diagview, copy_input, zero!
6+
using MatrixAlgebraKit: inv_safe, diagview, copy_input, zero!, truncate, truncation_error!
77
using MatrixAlgebraKit: qr_pullback!, qr_pushforward!, lq_pullback!, lq_pushforward!
88
using MatrixAlgebraKit: qr_null_pullback!, qr_null_pushforward!, lq_null_pullback!, lq_null_pushforward!
99
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_trunc_pullback!, eigh_trunc_pullback!
@@ -419,14 +419,25 @@ function Mooncake.frule!!(::Dual{typeof(svd_trunc)}, A_dA::Dual, alg_dalg::Dual)
419419
# compute primal
420420
A, dA = Mooncake.arrayify(A_dA)
421421
alg = Mooncake.primal(alg_dalg)
422-
output = svd_trunc(A, alg)
422+
USVᴴ = svd_compact(A, alg.alg)
423+
U, S, Vᴴ = USVᴴ
424+
dUfull = zeros(eltype(U), size(U))
425+
dSfull = Diagonal(zeros(eltype(S), length(diagview(S))))
426+
dVᴴfull = zeros(eltype(Vᴴ), size(Vᴴ))
427+
svd_pushforward!(dA, A, (U, S, Vᴴ), (dUfull, dSfull, dVᴴfull))
428+
429+
USVᴴtrunc, ind = truncate(svd_trunc!, USVᴴ, alg.trunc)
430+
ϵ = truncation_error!(diagview(S), ind)
431+
output = (USVᴴtrunc..., ϵ)
423432
output_dual = Mooncake.zero_dual(output)
424-
U, S, Vᴴ, ϵ = output
433+
Utrunc, Strunc, Vᴴtrunc, ϵ = output
425434
dU_, dS_, dVᴴ_, dϵ = Mooncake.tangent(output_dual)
426-
dU = arrayify(U, dU_)
427-
dS = arrayify(S, dS_)
428-
dVᴴ = arrayify(Vᴴ, dVᴴ_)
429-
svd_trunc_pushforward!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
435+
Utrunc, dU = arrayify(Utrunc, dU_)
436+
Strunc, dS = arrayify(Strunc, dS_)
437+
Vᴴtrunc, dVᴴ = arrayify(Vᴴtrunc, dVᴴ_)
438+
dU .= view(dUfull, :, ind)
439+
diagview(dS) .= view(diagview(dSfull), ind)
440+
dVᴴ .= view(dVᴴfull, ind, :)
430441
return output_dual
431442
end
432443

src/pushforwards/eigh.jl

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
function eigh_pushforward!(dA, A, DV, dDV; kwargs...)
2-
D, V = DV
3-
dD, dV = dDV
4-
tmpV = V \ dA
5-
∂K = tmpV * V
6-
∂Kdiag = diag(∂K)
7-
dD.diag .= real.(∂Kdiag)
8-
dDD = transpose(diagview(D)) .- diagview(D)
9-
F = one(eltype(dDD)) ./ dDD
10-
diagview(F) .= zero(eltype(F))
11-
∂K .*= F
12-
∂V = mul!(tmpV, V, ∂K)
13-
copyto!(dV, ∂V)
2+
D, V = DV
3+
dD, dV = dDV
4+
tmpV = V \ dA
5+
∂K = tmpV * V
6+
∂Kdiag = diag(∂K)
7+
diagview(dD) .= real.(∂Kdiag)
8+
if !iszerotangent(dV)
9+
dDD = transpose(diagview(D)) .- diagview(D)
10+
F = one(eltype(dDD)) ./ dDD
11+
diagview(F) .= zero(eltype(F))
12+
∂K .*= F
13+
∂V = mul!(tmpV, V, ∂K)
14+
copyto!(dV, ∂V)
15+
end
1416
return (dD, dV)
1517
end
1618

src/pushforwards/qr.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ function qr_pushforward!(dA, A, QR, dQR; tol::Real=default_pullback_gauge_atol(Q
5050
dQ3 .= Q3
5151
end
5252
if !isempty(dR22)
53-
_, r22 = qr_full(dA2 - dQ1*R12 - Q1*dR12, LAPACK_HouseholderQR(; positive=true))
53+
_, r22 = qr_compact(dA2 - dQ1*R12 - Q1*dR12; positive=true)
5454
dR22 .= view(r22, 1:size(dR22, 1), 1:size(dR22, 2))
5555
end
5656
return (dQ, dR)

src/pushforwards/svd.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,4 +77,6 @@ function svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ; rank_atol=default_pullback_r
7777
return (ΔU, ΔS, ΔVᴴ)
7878
end
7979

80-
function svd_trunc_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ; rank_atol=default_pullback_rank_atol(A), kwargs...) end
80+
function svd_trunc_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind; rank_atol=default_pullback_rank_atol(A), kwargs...)
81+
82+
end

test/mooncake.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,9 @@ end
151151
dQ = make_mooncake_tangent(copy(ΔQ))
152152
dR = make_mooncake_tangent(copy(ΔR))
153153
dQR = Mooncake.build_tangent(typeof((ΔQ, ΔR)), dQ, dR)
154-
Mooncake.TestUtils.test_rule(rng, qr_full, A, alg; output_tangent = dQR, atol=atol, rtol=rtol)
155-
#Mooncake.TestUtils.test_rule(rng, ((A, alg)->qr_full(A, alg)[2]), A, alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol)
156-
#Mooncake.TestUtils.test_rule(rng, ((A, alg)->qr_full(A, alg)[1][1:m, 1:minmn]), A, alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol)
154+
#Mooncake.TestUtils.test_rule(rng, qr_full, A, alg; output_tangent = dQR, atol=atol, rtol=rtol)
155+
Mooncake.TestUtils.test_rule(rng, ((A, alg)->qr_full(A, alg)[2]), A, alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol)
156+
Mooncake.TestUtils.test_rule(rng, ((A, alg)->qr_full(A, alg)[1][1:m, 1:minmn]), A, alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol)
157157
Mooncake.TestUtils.test_rule(rng, ((A, alg)->qr_full(A, alg)[1][1:m, minmn+1:m]), A, alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol)
158158
test_pullbacks_match(rng, qr_full!, qr_full, A, (Q, R), (ΔQ, ΔR), alg)
159159
end
@@ -173,6 +173,11 @@ end
173173
dR = make_mooncake_tangent(copy(ΔR))
174174
dQR = Mooncake.build_tangent(typeof((ΔQ, ΔR)), dQ, dR)
175175
Mooncake.TestUtils.test_rule(rng, qr_compact, copy(Ard), alg; output_tangent = dQR, atol=atol, rtol=rtol)
176+
Mooncake.TestUtils.test_rule(rng, ((A, alg)->qr_compact(A, alg)[2]), A, alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol)
177+
Mooncake.TestUtils.test_rule(rng, ((A, alg)->qr_compact(A, alg)[1][1:r, 1:r]), A, alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol)
178+
Mooncake.TestUtils.test_rule(rng, ((A, alg)->qr_compact(A, alg)[1][r+1:m, 1:r]), A, alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol)
179+
Mooncake.TestUtils.test_rule(rng, ((A, alg)->qr_compact(A, alg)[1][1:r, r+1:minmn]), A, alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol)
180+
Mooncake.TestUtils.test_rule(rng, ((A, alg)->qr_compact(A, alg)[1][r+1:m, r+1:minmn]), A, alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol)
176181
test_pullbacks_match(rng, qr_compact!, qr_compact, Ard, (Q, R), (ΔQ, ΔR), alg)
177182
end
178183
end

0 commit comments

Comments
 (0)