Skip to content

Commit 78cfd67

Browse files
committed
Some QR progress
1 parent d8e6d6f commit 78cfd67

File tree

4 files changed

+88
-83
lines changed

4 files changed

+88
-83
lines changed

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,12 @@ for (f, pb, pf, adj) in ((qr_full!, qr_pullback!, qr_pushforward!, :dqr_adjoi
2525
@eval begin
2626
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), AbstractMatrix, Tuple{<:AbstractMatrix, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm}
2727
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual{<:AbstractMatrix}, args_dargs::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...)
28-
A, dA = arrayify(A_dA)
29-
dA .= zero(eltype(A))
30-
args = Mooncake.primal(args_dargs)
31-
dargs = Mooncake.tangent(args_dargs)
32-
arg1, darg1 = arrayify(args[1], dargs[1])
33-
arg2, darg2 = arrayify(args[2], dargs[2])
28+
A, dA = arrayify(A_dA)
29+
dA .= zero(eltype(A))
30+
args = Mooncake.primal(args_dargs)
31+
dargs = Mooncake.tangent(args_dargs)
32+
arg1, darg1 = arrayify(args[1], dargs[1])
33+
arg2, darg2 = arrayify(args[2], dargs[2])
3434
function $adj(::Mooncake.NoRData)
3535
dA = $pb(dA, A, (arg1, arg2), (darg1, darg2); kwargs...)
3636
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
@@ -42,10 +42,10 @@ for (f, pb, pf, adj) in ((qr_full!, qr_pullback!, qr_pushforward!, :dqr_adjoi
4242
end
4343
@is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof($f), AbstractMatrix, Tuple{<:AbstractMatrix, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm}
4444
function Mooncake.frule!!(::Dual{typeof($f)}, A_dA::Dual{<:AbstractMatrix}, args_dargs::Dual, alg_dalg::Dual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...)
45-
A, dA = arrayify(A_dA)
46-
args = Mooncake.primal(args_dargs)
47-
args = $f(A, args, Mooncake.primal(alg_dalg); kwargs...)
48-
dargs = Mooncake.tangent(args_dargs)
45+
A, dA = arrayify(A_dA)
46+
args = Mooncake.primal(args_dargs)
47+
args = $f(A, args, Mooncake.primal(alg_dalg); kwargs...)
48+
dargs = Mooncake.tangent(args_dargs)
4949
arg1, darg1 = arrayify(args[1], dargs[1])
5050
arg2, darg2 = arrayify(args[2], dargs[2])
5151
darg1, darg2 = $pf(dA, A, (arg1, arg2), (darg1, darg2))

src/pushforwards/lq.jl

Lines changed: 41 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,68 @@
1-
function lq_pushforward!(dA, A, LQ, dLQ; tol::Real=MatrixAlgebraKit.default_pullback_gaugetol(LQ[1]), rank_atol::Real=tol, gauge_atol::Real=tol)
1+
#=function lq_pushforward!(dA, A, LQ, dLQ; tol::Real=MatrixAlgebraKit.default_pullback_gaugetol(LQ[1]), rank_atol::Real=tol, gauge_atol::Real=tol)
22
L, Q = LQ
3+
dL, dQ = dLQ
34
m = size(L, 1)
45
n = size(Q, 2)
56
minmn = min(m, n)
67
Ld = diagview(L)
78
p = findlast(>=(rank_atol) ∘ abs, Ld)
89
10+
if p == minmn && size(L,1) == size(L,2) # full-rank
11+
invL = inv(L)
12+
dQ .= invL * (dA - dL * Q)
13+
dL = invL * dA * Q'
14+
return (dL, dQ)
15+
end
16+
917
n1 = p
1018
n2 = minmn - p
1119
n3 = n - minmn
1220
m1 = p
1321
m2 = m - p
1422
1523
#####
16-
Q1 = view(Q, 1:n1, 1:n) # full rank portion
17-
Q2 = view(Q, 1:n1+1:n2+n1, 1:n)
18-
L11 = view(L, 1:m, 1:n1)
19-
L12 = view(L, 1:m1, n1+1:n)
24+
Q1 = view(Q, 1:m1, 1:n) # full rank portion
25+
Q2 = view(Q, n1+1:n1+n2, 1:n)
26+
L11 = view(L, 1:m1, 1:n1)
27+
L21 = view(L, (m1+1):m, 1:n1)
2028
21-
dA1 = view(dA, 1:m, 1:n1)
22-
dA2 = view(dA, 1:m, (n1 + 1):n)
29+
dA1 = view(dA, 1:m1, 1:n)
30+
dA2 = view(dA, (m1+1):m, 1:n)
2331
24-
dQ, dR = dQR
25-
dQ1 = view(dQ, 1:m, 1:m1)
26-
dQ2 = view(dQ, 1:m, m1+1:m2+m1)
27-
dR11 = view(dR, 1:m1, 1:n1)
28-
dR12 = view(dR, 1:m1, n1+1:n)
29-
dR22 = view(dR, m1+1:m1+m2, n1+1:n)
32+
dQ1 = view(dQ, 1:n1, 1:n)
33+
dQ2 = view(dQ, n1+1:n1+n2, 1:n)
34+
dL11 = view(dL, 1:m1, 1:n1)
35+
dL21 = view(dL, (m1+1):m, 1:n1)
36+
dL22 = view(dL, (m1+1):m, n1+1:(n1+n2) )
3037
3138
# fwd rule for Q1 and R11 -- for a non-rank-deficient QR, this is all we need
32-
invR11 = inv(R11)
33-
tmp = Q1' * dA1 * invR11
34-
Rtmp = tmp + tmp'
35-
diagview(Rtmp) ./= 2
36-
ltRtmp = view(Rtmp, MatrixAlgebraKit.lowertriangularind(Rtmp))
37-
#ltRtmp .= zero(eltype(Rtmp))
38-
dR11 .= Rtmp * R11
39-
dQ1 .= dA1 * invR11 - Q1 * dR11 * invR11
40-
41-
dR12 .= adjoint(Q1) * (dA2 - dQ1 * R12)
42-
dQ2 .= Q1 * (Q1' * dQ2)
43-
if size(Q2, 2) > 0
39+
invL11 = inv(L11)
40+
tmp = invL11 * dA1 * Q1'
41+
Ltmp = tmp + tmp'
42+
diagview(Ltmp) ./= 2
43+
utLtmp = view(Ltmp, MatrixAlgebraKit.uppertriangularind(Ltmp))
44+
dL11 .= L11 * Ltmp
45+
dQ1 .= invL11 * dA1 - invL11 * dL11 * Q1
46+
47+
dL21 .= (dA2 - L21 * dQ1) * adjoint(Q1)
48+
dQ2 .= -(dQ2 * Q1') * Q1
49+
if size(Q2, 1) > 0
4450
dQ2 .+= Q2 * (Q2' * dQ2)
4551
end
46-
if m3 > 0 && size(dQ2, 2) > 0
52+
if n3 > 0 && size(dQ2, 1) > 0
4753
# only present for qr_full or rank-deficient qr_compact
48-
Q3 = view(Q, 1:m, m1+m2+1:size(Q, 2))
54+
Q3 = view(Q, (n1+n2+1):n, 1:n)
4955
dQ2 .+= Q3 * (Q3' * dQ2)
5056
end
51-
if !isempty(dR22)
52-
_, r22 = qr_full(dA2 - dQ1*R12 - Q1*dR12, MatrixAlgebraKit.LAPACK_HouseholderQR(; positive=true))
53-
dR22 .= view(r22, 1:size(dR22, 1), 1:size(dR22, 2))
57+
if !isempty(dL22)
58+
_, l22 = qr_full(dA2 - L21 * dQ1 - dL12 * Q1, MatrixAlgebraKit.LAPACK_HouseholderQR(; positive=true))
59+
dL22 .= view(l22, 1:size(dL22, 1), 1:size(dL22, 2))
5460
end
55-
return (dQ, dR)
61+
return (dL, dQ)
62+
end=#
63+
64+
"""
    lq_pushforward!(dA, A, LQ, dLQ; kwargs...)

Forward-mode rule for the LQ decomposition `A = L * Q`, implemented by delegating to
`qr_pushforward!` on the adjoint problem `A' = Q' * L'`.

# Arguments
- `dA`: tangent of the input matrix `A`.
- `A`: the decomposed matrix.
- `LQ`: tuple `(L, Q)` of primal factors.
- `dLQ`: tuple `(dL, dQ)` of factor tangents; these are overwritten in place.

# Keywords
All keyword arguments are forwarded unchanged to `qr_pushforward!`.

# Returns
The tuple `(dL, dQ)`, i.e. the tangents in the same order as `LQ`.
"""
function lq_pushforward!(dA, A, LQ, dLQ; kwargs...)
    L, Q = LQ
    dL, dQ = dLQ
    # A = L * Q  <=>  A' = Q' * L', so (Q', L') play the role of the QR factors of A'.
    # `qr_pushforward!` takes the problem size from its `A` argument, so the input and
    # its tangent have to be adjointed together with the factors.
    qr_pushforward!(
        adjoint(dA), adjoint(A), (adjoint(Q), adjoint(L)), (adjoint(dQ), adjoint(dL));
        kwargs...,
    )
    # `qr_pushforward!` fills its tangent arguments in place; for matrices `adjoint` is a
    # lazy wrapper, so those writes land in `dL` and `dQ`. Return them in LQ order
    # rather than the QR-ordered adjoint tuple that `qr_pushforward!` hands back.
    return (dL, dQ)
end
5767

5868
"""
    lq_null_pushforward!(dA, A, LQ, dLQ; tol, rank_atol, gauge_atol)

Stub with an empty body: no argument is read or modified and `nothing` is returned.

# Keywords
- `tol::Real`: defaults to `MatrixAlgebraKit.default_pullback_gaugetol(LQ[1])`.
- `rank_atol::Real`: defaults to `tol`.
- `gauge_atol::Real`: defaults to `tol`.
"""
function lq_null_pushforward!(
        dA, A, LQ, dLQ;
        tol::Real=MatrixAlgebraKit.default_pullback_gaugetol(LQ[1]),
        rank_atol::Real=tol,
        gauge_atol::Real=tol,
    )
    # NOTE(review): presumably a placeholder until the null-space rule exists -- confirm.
    return nothing
end

src/pushforwards/qr.jl

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
function qr_pushforward!(dA, A, QR, dQR; tol::Real=MatrixAlgebraKit.default_pullback_gaugetol(QR[2]), rank_atol::Real=tol, gauge_atol::Real=tol)
22
Q, R = QR
3-
m = size(Q, 1)
4-
n = size(R, 2)
3+
m = size(A, 1)
4+
n = size(A, 2)
55
minmn = min(m, n)
66
Rd = diagview(R)
77
p = findlast(>=(rank_atol) abs, Rd)
@@ -23,7 +23,7 @@ function qr_pushforward!(dA, A, QR, dQR; tol::Real=MatrixAlgebraKit.default_pull
2323
dQ, dR = dQR
2424
dQ1 = view(dQ, 1:m, 1:m1)
2525
dQ2 = view(dQ, 1:m, m1+1:m2+m1)
26-
dQ3 = m1+m2+1 < size(dQ, 2) ? view(dQ, 1:m, m1+m2+1:size(dQ,2)) : similar(dQ, eltype(dQ), (0, 0))
26+
dQ3 = minmn+1 < size(dQ, 2) ? view(dQ, :, minmn+1:size(dQ,2)) : similar(dQ, eltype(dQ), (0, 0))
2727
dR11 = view(dR, 1:m1, 1:n1)
2828
dR12 = view(dR, 1:m1, n1+1:n)
2929
dR22 = view(dR, m1+1:m1+m2, n1+1:n)
@@ -38,32 +38,23 @@ function qr_pushforward!(dA, A, QR, dQR; tol::Real=MatrixAlgebraKit.default_pull
3838
dR11 .= Rtmp * R11
3939
dQ1 .= dA1 * invR11 - Q1 * dR11 * invR11
4040
dR12 .= adjoint(Q1) * (dA2 - dQ1 * R12)
41-
dQ2 .= Q1 * (Q1' * dQ2)
41+
dQ2 .= -Q1 * (Q1' * dQ2)
4242
if size(Q2, 2) > 0
4343
dQ2 .+= Q2 * (Q2' * dQ2)
4444
end
45-
if m3 > 0 && size(dQ2, 2) > 0
45+
if m3 > 0 && size(Q, 2) > minmn
4646
# only present for qr_full or rank-deficient qr_compact
47-
Q3 = view(Q, 1:m, m1+m2+1:size(Q, 2))
48-
dQ2 .+= Q3 * (Q3' * dQ2)
47+
Q′ = view(Q, :, 1:minmn)
48+
println("minmn $minmn m $m")
49+
Q3 = view(Q, :, minmn+1:m)
50+
#dQ3 .= Q′ * (Q′' * Q3)
51+
dQ3 .= Q3
4952
end
50-
if !isempty(dR22)
53+
#=if !isempty(dR22)
5154
_, r22 = qr_full(dA2 - dQ1*R12 - Q1*dR12, MatrixAlgebraKit.LAPACK_HouseholderQR(; positive=true))
5255
dR22 .= view(r22, 1:size(dR22, 1), 1:size(dR22, 2))
53-
end
56+
end=#
5457
return (dQ, dR)
5558
end
56-
#=Ac = MatrixAlgebraKit.copy_input(qr_full, Aval)
57-
QR = MatrixAlgebraKit.initialize_output(qr_full!, Aval, alg.val)
58-
Q, R = qr_full!(Ac, QR, alg.val)
59-
Nval = N.val
60-
copy!(Nval, view(Q, 1:size(Aval, 1), (size(Aval, 2) + 1):size(Aval, 1)))
61-
(m, n) = size(Aval)
62-
minmn = min(m, n)
63-
dQ = zeros(eltype(Aval), (m, m))
64-
view(dQ, 1:m, (minmn + 1):m) .= dN
65-
MatrixAlgebraKit.qr_fwd(dA, A.val, (Q, R), (dQ, zeros(eltype(R), size(R))))
66-
dN .= view(dQ, 1:m, (minmn + 1):m)
67-
dA .= zero(eltype(A.val))=#
6859

6960
"""
    qr_null_pushforward!(dA, A, QR, dQR; tol, rank_atol, gauge_atol)

Stub with an empty body: no argument is read or modified and `nothing` is returned.

# Keywords
- `tol::Real`: defaults to `MatrixAlgebraKit.default_pullback_gaugetol(QR[2])`.
- `rank_atol::Real`: defaults to `tol`.
- `gauge_atol::Real`: defaults to `tol`.
"""
function qr_null_pushforward!(
        dA, A, QR, dQR;
        tol::Real=MatrixAlgebraKit.default_pullback_gaugetol(QR[2]),
        rank_atol::Real=tol,
        gauge_atol::Real=tol,
    )
    # NOTE(review): presumably a placeholder until the null-space rule exists -- confirm.
    return nothing
end

test/mooncake.jl

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -30,27 +30,27 @@ function make_mooncake_tangent(ΔD::Diagonal{T}) where {T<:Complex}
3030
return Mooncake.build_tangent(typeof(ΔD), diag_tangent)
3131
end
3232

33-
ETs = (Float64, Float32,)# ComplexF64, ComplexF32)
33+
ETs = (Float64,)# Float32,)# ComplexF64, ComplexF32)
3434

3535
@timedtestset "QR AD Rules with eltype $T" for T in ETs
3636
rng = StableRNG(12345)
3737
m = 19
38-
@testset "size ($m, $n)" for n in (17, m, 23)
38+
@testset "size ($m, $n)" for n in (17,)# m, 23)
3939
atol = rtol = m * n * precision(T)
4040
A = randn(rng, T, m, n)
4141
minmn = min(m, n)
4242
@testset for alg in (LAPACK_HouseholderQR(),
43-
LAPACK_HouseholderQR(; positive=true),
43+
#LAPACK_HouseholderQR(; positive=true),
4444
)
45-
@testset "qr_compact" begin
45+
#=@testset "qr_compact" begin
4646
Mooncake.TestUtils.test_rule(rng, qr_compact, A, alg; is_primitive=false, atol=atol, rtol=rtol)
47-
end
48-
@testset "qr_null" begin
47+
end=#
48+
#=@testset "qr_null" begin
4949
Q, R = qr_compact(A, alg)
5050
ΔN = Q * randn(rng, T, minmn, max(0, m - minmn))
5151
dN = make_mooncake_tangent(ΔN)
5252
Mooncake.TestUtils.test_rule(rng, qr_null, A, alg; output_tangent = dN, is_primitive=false, atol=atol, rtol=rtol)
53-
end
53+
end=#
5454
@testset "qr_full" begin
5555
Q, R = qr_full(A, alg)
5656
Q1 = view(Q, 1:m, 1:minmn)
@@ -61,9 +61,11 @@ ETs = (Float64, Float32,)# ComplexF64, ComplexF32)
6161
dQ = make_mooncake_tangent(ΔQ)
6262
dR = make_mooncake_tangent(ΔR)
6363
dQR = Mooncake.build_tangent(typeof((ΔQ,ΔR)), dQ, dR)
64-
Mooncake.TestUtils.test_rule(rng, qr_full, A, alg; output_tangent = dQR, is_primitive=false, atol=atol, rtol=rtol)
64+
#Mooncake.TestUtils.test_rule(rng, ((A, alg)->qr_full(A, alg)[2]), A, alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol)
65+
#Mooncake.TestUtils.test_rule(rng, ((A, alg)->qr_full(A, alg)[1][1:m, 1:minmn]), A, alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol)
66+
Mooncake.TestUtils.test_rule(rng, ((A, alg)->qr_full(A, alg)[1][1:m, minmn+1:m]), A, alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol)
6567
end
66-
@testset "qr_compact - rank-deficient A" begin
68+
#=@testset "qr_compact - rank-deficient A" begin
6769
r = minmn - 5
6870
Ard = randn(rng, T, m, r) * randn(rng, T, r, n)
6971
Q, R = qr_compact(Ard, alg)
@@ -77,12 +79,13 @@ ETs = (Float64, Float32,)# ComplexF64, ComplexF32)
7779
dQ = make_mooncake_tangent(ΔQ)
7880
dR = make_mooncake_tangent(ΔR)
7981
dQR = Mooncake.build_tangent(typeof((ΔQ,ΔR)), dQ, dR)
80-
Mooncake.TestUtils.test_rule(rng, qr_compact, copy(Ard), alg; output_tangent = dQR, is_primitive=false, atol=atol, rtol=rtol)
81-
end
82+
Mooncake.TestUtils.test_rule(rng, qr_compact, copy(Ard), alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol)
83+
end=#
8284
end
8385
end
8486
end
8587
88+
#=
8689
@timedtestset "LQ AD Rules with eltype $T" for T in ETs
8790
rng = StableRNG(12345)
8891
m = 19
@@ -99,14 +102,14 @@ end
99102
dL = make_mooncake_tangent(ΔL)
100103
dQ = make_mooncake_tangent(ΔQ)
101104
dLQ = Mooncake.build_tangent(typeof((ΔL,ΔQ)), dL, dQ)
102-
Mooncake.TestUtils.test_rule(rng, lq_compact, A, alg; mode=Mooncake.ReverseMode, is_primitive=false, atol=atol, rtol=rtol, output_tangent = dLQ)
105+
Mooncake.TestUtils.test_rule(rng, lq_compact, A, alg; is_primitive=false, atol=atol, rtol=rtol, output_tangent = dLQ)
103106
end
104-
@testset "lq_null" begin
107+
#=@testset "lq_null" begin
105108
L, Q = lq_compact(A, alg)
106109
ΔNᴴ = randn(rng, T, max(0, n - minmn), minmn) * Q
107110
dNᴴ = make_mooncake_tangent(ΔNᴴ)
108-
Mooncake.TestUtils.test_rule(rng, lq_null, A, alg; mode=Mooncake.ReverseMode, output_tangent = dNᴴ, is_primitive=false, atol=atol, rtol=rtol)
109-
end
111+
Mooncake.TestUtils.test_rule(rng, lq_null, A, alg; output_tangent = dNᴴ, is_primitive=false, atol=atol, rtol=rtol)
112+
end=#
110113
@testset "lq_full" begin
111114
L, Q = lq_full(A, alg)
112115
Q1 = view(Q, 1:minmn, 1:n)
@@ -117,9 +120,9 @@ end
117120
dL = make_mooncake_tangent(ΔL)
118121
dQ = make_mooncake_tangent(ΔQ)
119122
dLQ = Mooncake.build_tangent(typeof((ΔL,ΔQ)), dL, dQ)
120-
Mooncake.TestUtils.test_rule(rng, lq_full, A, alg; mode=Mooncake.ReverseMode, output_tangent = dLQ, is_primitive=false, atol=atol, rtol=rtol)
123+
Mooncake.TestUtils.test_rule(rng, lq_full, A, alg; output_tangent = dLQ, is_primitive=false, atol=atol, rtol=rtol)
121124
end
122-
@testset "lq_compact - rank-deficient A" begin
125+
#=@testset "lq_compact - rank-deficient A" begin
123126
r = minmn - 5
124127
Ard = randn(rng, T, m, r) * randn(rng, T, r, n)
125128
L, Q = lq_compact(Ard, alg)
@@ -133,12 +136,13 @@ end
133136
dL = make_mooncake_tangent(ΔL)
134137
dQ = make_mooncake_tangent(ΔQ)
135138
dLQ = Mooncake.build_tangent(typeof((ΔL,ΔQ)), dL, dQ)
136-
Mooncake.TestUtils.test_rule(rng, lq_compact, Ard, alg; mode=Mooncake.ReverseMode, output_tangent = dLQ, is_primitive=false, atol=atol, rtol=rtol)
137-
end
139+
Mooncake.TestUtils.test_rule(rng, lq_compact, Ard, alg; output_tangent = dLQ, is_primitive=false, atol=atol, rtol=rtol)
140+
end=#
138141
end
139142
end
140143
end
141-
144+
=#
145+
#=
142146
@timedtestset "EIG AD Rules with eltype $T" for T in ETs
143147
rng = StableRNG(12345)
144148
m = 19
@@ -418,4 +422,4 @@ end
418422
Mooncake.TestUtils.test_rule(rng, (X->right_null(X; kind=:lq)), A; atol=atol, rtol=rtol, is_primitive=false, output_tangent = dNᴴ)
419423
end
420424
end
421-
425+
=#

0 commit comments

Comments
 (0)