@@ -99,6 +99,17 @@ function svd_pullback!(
9999 end
100100 return ΔA
101101end
102+ function svd_pullback!(
103+ ΔA::Diagonal, A, USVᴴ, ΔUSVᴴ, ind = Colon();
104+ rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
105+ degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
106+ gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3])
107+ )
108+ ΔA_full = zero!(similar(ΔA, size(ΔA)))
109+ ΔA_full = svd_pullback!(ΔA_full, A, USVᴴ, ΔUSVᴴ, ind; rank_atol, degeneracy_atol, gauge_atol)
110+ diagview(ΔA) .+= diagview(ΔA_full)
111+ return ΔA
112+ end
102113
103114"""
104115 svd_trunc_pullback!(
@@ -201,6 +212,17 @@ function svd_trunc_pullback!(
201212 ΔA = mul!(ΔA, U, Y' * Ṽᴴ, 1 , 1 )
202213 return ΔA
203214end
215+ function svd_trunc_pullback!(
216+ ΔA:: Diagonal , A, USVᴴ, ΔUSVᴴ;
217+ rank_atol:: Real = 0 ,
218+ degeneracy_atol:: Real = default_pullback_rank_atol(USVᴴ[2 ]),
219+ gauge_atol:: Real = default_pullback_gauge_atol(ΔUSVᴴ[1 ], ΔUSVᴴ[3 ])
220+ )
221+ ΔA_full = zero!(similar(ΔA, size(ΔA)))
222+ ΔA_full = svd_trunc_pullback!(ΔA_full, A, USVᴴ, ΔUSVᴴ; rank_atol, degeneracy_atol, gauge_atol)
223+ diagview(ΔA) .+ = diagview(ΔA_full)
224+ return ΔA
225+ end
204226
205227"""
206228 svd_vals_pullback!(
0 commit comments