Skip to content

Commit c587b78

Browse files
committed
Some progress on Diagonals
1 parent 98da0ec commit c587b78

File tree

3 files changed

+75
-7
lines changed

3 files changed

+75
-7
lines changed

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ Mooncake.tangent_type(::Type{<:MatrixAlgebraKit.AbstractAlgorithm}) = Mooncake.N
1818
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(copy_input), Any, Any}
1919
function Mooncake.rrule!!(::CoDual{typeof(copy_input)}, f_df::CoDual, A_dA::CoDual)
2020
Ac = copy_input(Mooncake.primal(f_df), Mooncake.primal(A_dA))
21-
dAc = Mooncake.zero_tangent(Ac)
21+
dAc = Mooncake.fdata(Mooncake.zero_tangent(Ac))
2222
function copy_input_pb(::NoRData)
2323
Mooncake.increment!!(Mooncake.tangent(A_dA), dAc)
2424
return NoRData(), NoRData(), NoRData()

test/testsuite/ad_utils.jl

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,15 @@ function ad_qr_compact_setup(A)
6969
return QR, (ΔQ, ΔR)
7070
end
7171

72+
# AD test fixture for `qr_compact` on a `Diagonal` input.
#
# Returns the result of `qr_compact(A)` unchanged, together with a pair of
# `Diagonal` cotangents `(ΔQ, ΔR)` whose diagonals are filled by `randn!` and
# have length `size(A, 1)`.
function ad_qr_compact_setup(A::Diagonal)
    nrows = size(A, 1)
    factors = qr_compact(A)
    elt = eltype(A)
    # ΔQ is drawn before ΔR so the random stream is consumed in a fixed order.
    ΔQ = Diagonal(randn!(similar(A, elt, nrows)))
    ΔR = Diagonal(randn!(similar(A, elt, nrows)))
    return factors, (ΔQ, ΔR)
end
80+
7281
function ad_qr_null_setup(A)
7382
m, n = size(A)
7483
minmn = min(m, n)
@@ -92,6 +101,15 @@ function ad_qr_full_setup(A)
92101
return (Q, R), (ΔQ, ΔR)
93102
end
94103
104+
# AD test fixture for `qr_full` on a `Diagonal` input.
#
# Returns the result of `qr_full(A)` unchanged, together with a pair of
# `Diagonal` cotangents `(ΔQ, ΔR)` whose diagonals are filled by `randn!` and
# have length `size(A, 1)`.
function ad_qr_full_setup(A::Diagonal)
    nrows = size(A, 1)
    factors = qr_full(A)
    elt = eltype(A)
    # ΔQ is drawn before ΔR so the random stream is consumed in a fixed order.
    ΔQ = Diagonal(randn!(similar(A, elt, nrows)))
    ΔR = Diagonal(randn!(similar(A, elt, nrows)))
    return factors, (ΔQ, ΔR)
end
112+
95113
function ad_qr_rd_compact_setup(A)
96114
m, n = size(A)
97115
minmn = min(m, n)
@@ -110,6 +128,22 @@ function ad_qr_rd_compact_setup(A)
110128
return (Q, R), (ΔQ, ΔR)
111129
end
112130
131+
# AD test fixture for `qr_compact` on a rank-deficient `Diagonal` matrix.
#
# A fresh diagonal matrix `Ard` is built with the same diagonal storage type as
# `A`: its first `r = min(m, n) - 5` diagonal entries are random and the
# remaining ones are zero. The function returns `(Q, R)` from `qr_compact(Ard)`
# and `Diagonal` cotangents `(ΔQ, ΔR)` whose entries past index `r` are zeroed
# as well.
function ad_qr_rd_compact_setup(A::Diagonal)
    m, n = size(A)
    minmn = min(m, n)
    T = eltype(A)
    # NOTE(review): presumes min(m, n) > 5 so that `r` is positive — confirm
    # against the sizes used by the callers.
    r = minmn - 5
    Ard = Diagonal(similar(A.diag, T, m))
    # Assign through indexed `.=` so the random values are written into
    # `Ard.diag` itself. A plain `Ard.diag[1:r]` expression is a copy, so filling
    # it with `copyto!` would leave the leading entries of `Ard` uninitialized.
    Ard.diag[1:r] .= randn!(similar(A.diag, T, r))
    Ard.diag[(r + 1):m] .= zero(T)
    Q, R = qr_compact(Ard)
    ΔQ = Diagonal(randn!(similar(A, T, m)))
    ΔR = Diagonal(randn!(similar(A, T, m)))
    # Zero the cotangent entries past index `r`, matching the zeroed tail of `Ard`.
    ΔQ.diag[(r + 1):m] .= zero(T)
    ΔR.diag[(r + 1):m] .= zero(T)
    return (Q, R), (ΔQ, ΔR)
end
146+
113147
function ad_lq_compact_setup(A)
114148
m, n = size(A)
115149
minmn = min(m, n)
@@ -120,6 +154,15 @@ function ad_lq_compact_setup(A)
120154
return LQ, (ΔL, ΔQ)
121155
end
122156
157+
# AD test fixture for `lq_compact` on a `Diagonal` input.
#
# Returns the result of `lq_compact(A)` unchanged, together with a pair of
# `Diagonal` cotangents `(ΔL, ΔQ)` whose diagonals are filled by `randn!` and
# have length `size(A, 1)`.
function ad_lq_compact_setup(A::Diagonal)
    nrows = size(A, 1)
    factors = lq_compact(A)
    elt = eltype(A)
    # ΔL is drawn before ΔQ so the random stream is consumed in a fixed order.
    ΔL = Diagonal(randn!(similar(A, elt, nrows)))
    ΔQ = Diagonal(randn!(similar(A, elt, nrows)))
    return factors, (ΔL, ΔQ)
end
165+
123166
function ad_lq_null_setup(A)
124167
m, n = size(A)
125168
minmn = min(m, n)
@@ -143,6 +186,15 @@ function ad_lq_full_setup(A)
143186
return (L, Q), (ΔL, ΔQ)
144187
end
145188

189+
# AD test fixture for `lq_full` on a `Diagonal` input.
#
# Destructures `lq_full(A)` into `(L, Q)` and returns it together with a pair
# of `Diagonal` cotangents `(ΔL, ΔQ)` whose diagonals are filled by `randn!`
# and have length `size(A, 1)`.
function ad_lq_full_setup(A::Diagonal)
    nrows = size(A, 1)
    elt = eltype(A)
    L, Q = lq_full(A)
    # ΔL is drawn before ΔQ so the random stream is consumed in a fixed order.
    ΔL = Diagonal(randn!(similar(A, elt, nrows)))
    ΔQ = Diagonal(randn!(similar(A, elt, nrows)))
    return (L, Q), (ΔL, ΔQ)
end
197+
146198
function ad_lq_rd_compact_setup(A)
147199
m, n = size(A)
148200
minmn = min(m, n)
@@ -160,6 +212,22 @@ function ad_lq_rd_compact_setup(A)
160212
return (L, Q), (ΔL, ΔQ)
161213
end
162214

215+
# AD test fixture for `lq_compact` on a rank-deficient `Diagonal` matrix.
#
# A fresh diagonal matrix `Ard` is built with the same diagonal storage type as
# `A`: its first `r = min(m, n) - 5` diagonal entries are random and the
# remaining ones are zero. The function returns `(L, Q)` from `lq_compact(Ard)`
# and `Diagonal` cotangents `(ΔL, ΔQ)` whose entries past index `r` are zeroed
# as well.
function ad_lq_rd_compact_setup(A::Diagonal)
    m, n = size(A)
    minmn = min(m, n)
    T = eltype(A)
    # NOTE(review): presumes min(m, n) > 5 so that `r` is positive — confirm
    # against the sizes used by the callers.
    r = minmn - 5
    Ard = Diagonal(similar(A.diag, T, m))
    # Assign through indexed `.=` so the random values are written into
    # `Ard.diag` itself. A plain `Ard.diag[1:r]` expression is a copy, so filling
    # it with `copyto!` would leave the leading entries of `Ard` uninitialized.
    Ard.diag[1:r] .= randn!(similar(A.diag, T, r))
    Ard.diag[(r + 1):m] .= zero(T)
    L, Q = lq_compact(Ard)
    ΔL = Diagonal(randn!(similar(A, T, m)))
    ΔQ = Diagonal(randn!(similar(A, T, m)))
    # Zero the cotangent entries past index `r`, matching the zeroed tail of `Ard`.
    ΔL.diag[(r + 1):m] .= zero(T)
    ΔQ.diag[(r + 1):m] .= zero(T)
    return (L, Q), (ΔL, ΔQ)
end
230+
163231
function ad_eig_full_setup(A)
164232
m, n = size(A)
165233
T = eltype(A)

test/testsuite/mooncake.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ make_mooncake_fdata(x::Diagonal) = Mooncake.FData((diag = make_mooncake_tangent(
6565

6666
# no `alg` argument
6767
function _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, ::Nothing, rdata)
68-
dA_copy = make_mooncake_tangent(copy(ΔA))
68+
dA_copy = make_mooncake_fdata(copy(ΔA))
6969
A_copy = copy(A)
7070
dargs_copy = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs))
7171
copy_out, copy_pb!! = rrule(Mooncake.CoDual(f_c, Mooncake.NoFData()), Mooncake.CoDual(A_copy, dA_copy), Mooncake.CoDual(args, dargs_copy))
@@ -75,7 +75,7 @@ end
7575

7676
# `alg` argument
7777
function _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, alg, rdata)
78-
dA_copy = make_mooncake_tangent(copy(ΔA))
78+
dA_copy = make_mooncake_fdata(copy(ΔA))
7979
A_copy = copy(A)
8080
dargs_copy = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs))
8181
copy_out, copy_pb!! = rrule(Mooncake.CoDual(f_c, Mooncake.NoFData()), Mooncake.CoDual(A_copy, dA_copy), Mooncake.CoDual(args, dargs_copy), Mooncake.CoDual(alg, Mooncake.NoFData()))
@@ -84,7 +84,7 @@ function _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, alg, rdata)
8484
end
8585

8686
function _get_inplace_derivative(f!, A, ΔA, args, Δargs, ::Nothing, rdata)
87-
dA_inplace = make_mooncake_tangent(copy(ΔA))
87+
dA_inplace = make_mooncake_fdata(copy(ΔA))
8888
A_inplace = copy(A)
8989
dargs_inplace = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs))
9090
# not every f! has a handwritten rrule!!
@@ -103,7 +103,7 @@ function _get_inplace_derivative(f!, A, ΔA, args, Δargs, ::Nothing, rdata)
103103
end
104104

105105
function _get_inplace_derivative(f!, A, ΔA, args, Δargs, alg, rdata)
106-
dA_inplace = make_mooncake_tangent(copy(ΔA))
106+
dA_inplace = make_mooncake_fdata(copy(ΔA))
107107
A_inplace = copy(A)
108108
dargs_inplace = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs))
109109
# not every f! has a handwritten rrule!!
@@ -143,9 +143,9 @@ function test_pullbacks_match(f!, f, A, args, Δargs, alg = nothing; rdata = Moo
143143
sig = isnothing(alg) ? Tuple{typeof(f_c), typeof(A), typeof(args)} : Tuple{typeof(f_c), typeof(A), typeof(args), typeof(alg)}
144144
rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode)
145145
rrule = Mooncake.build_rrule(rvs_interp, sig)
146-
ΔA = randn!(similar(A))
146+
ΔA = A isa Diagonal ? Diagonal(randn!(similar(A.diag))) : randn!(similar(A))
147147

148-
dA_copy = _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, alg, rdata)
148+
dA_copy = _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, alg, rdata)
149149
dA_inplace = _get_inplace_derivative(f!, A, ΔA, args, Δargs, alg, rdata)
150150

151151
dA_inplace_ = Mooncake.arrayify(A, dA_inplace)[2]

0 commit comments

Comments
 (0)