Skip to content

Commit 0d78d08

Browse files
authored
Pullbacks for Diagonal inputs (#156)
* naively specialize diagonal pullback * add eig diagonal pullback * add svd diagonal pullback
1 parent 083749e commit 0d78d08

File tree

3 files changed

+62
-0
lines changed

3 files changed

+62
-0
lines changed

src/pullbacks/eig.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,16 @@ function eig_pullback!(
7878
end
7979
return ΔA
8080
end
81+
function eig_pullback!(
82+
ΔA::Diagonal, A, DV, ΔDV, ind = Colon();
83+
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
84+
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
85+
)
86+
ΔA_full = zero!(similar(ΔA, size(ΔA)))
87+
ΔA_full = eig_pullback!(ΔA_full, A, DV, ΔDV, ind; degeneracy_atol, gauge_atol)
88+
diagview(ΔA) .+= diagview(ΔA_full)
89+
return ΔA
90+
end
8191
8292
"""
8393
eig_trunc_pullback!(
@@ -151,6 +161,16 @@ function eig_trunc_pullback!(
151161
end
152162
return ΔA
153163
end
164+
function eig_trunc_pullback!(
165+
ΔA::Diagonal, A, DV, ΔDV;
166+
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
167+
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
168+
)
169+
ΔA_full = zero!(similar(ΔA, size(ΔA)))
170+
ΔA_full = eig_trunc_pullback!(ΔA_full, A, DV, ΔDV; degeneracy_atol, gauge_atol)
171+
diagview(ΔA) .+= diagview(ΔA_full)
172+
return ΔA
173+
end
154174
155175
"""
156176
eig_vals_pullback!(

src/pullbacks/eigh.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,16 @@ function eigh_pullback!(
6868
end
6969
return ΔA
7070
end
71+
function eigh_pullback!(
72+
ΔA::Diagonal, A, DV, ΔDV, ind = Colon();
73+
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
74+
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
75+
)
76+
ΔA_full = zero!(similar(ΔA, size(ΔA)))
77+
ΔA_full = eigh_pullback!(ΔA_full, A, DV, ΔDV, ind; degeneracy_atol, gauge_atol)
78+
diagview(ΔA) .+= diagview(ΔA_full)
79+
return ΔA
80+
end
7181
7282
"""
7383
eigh_trunc_pullback!(
@@ -141,6 +151,16 @@ function eigh_trunc_pullback!(
141151
end
142152
return ΔA
143153
end
154+
function eigh_trunc_pullback!(
155+
ΔA::Diagonal, A, DV, ΔDV;
156+
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
157+
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
158+
)
159+
ΔA_full = zero!(similar(ΔA, size(ΔA)))
160+
ΔA_full = eigh_trunc_pullback!(ΔA_full, A, DV, ΔDV; degeneracy_atol, gauge_atol)
161+
diagview(ΔA) .+= diagview(ΔA_full)
162+
return ΔA
163+
end
144164
145165
"""
146166
eigh_vals_pullback!(

src/pullbacks/svd.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,17 @@ function svd_pullback!(
9999
end
100100
return ΔA
101101
end
102+
function svd_pullback!(
103+
ΔA::Diagonal, A, USVᴴ, ΔUSVᴴ, ind = Colon();
104+
rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
105+
degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
106+
gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3])
107+
)
108+
ΔA_full = zero!(similar(ΔA, size(ΔA)))
109+
ΔA_full = svd_pullback!(ΔA_full, A, USVᴴ, ΔUSVᴴ, ind; rank_atol, degeneracy_atol, gauge_atol)
110+
diagview(ΔA) .+= diagview(ΔA_full)
111+
return ΔA
112+
end
102113
103114
"""
104115
svd_trunc_pullback!(
@@ -201,6 +212,17 @@ function svd_trunc_pullback!(
201212
ΔA = mul!(ΔA, U, Y' * Ṽᴴ, 1, 1)
202213
return ΔA
203214
end
215+
function svd_trunc_pullback!(
216+
ΔA::Diagonal, A, USVᴴ, ΔUSVᴴ;
217+
rank_atol::Real = 0,
218+
degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
219+
gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3])
220+
)
221+
ΔA_full = zero!(similar(ΔA, size(ΔA)))
222+
ΔA_full = svd_trunc_pullback!(ΔA_full, A, USVᴴ, ΔUSVᴴ; rank_atol, degeneracy_atol, gauge_atol)
223+
diagview(ΔA) .+= diagview(ΔA_full)
224+
return ΔA
225+
end
204226

205227
"""
206228
svd_vals_pullback!(

0 commit comments

Comments
 (0)