@@ -36,7 +36,7 @@ function svd_pullback!(
3636 minmn = min(m, n)
3737 S = diagview(Smat)
3838 length(S) == minmn || throw(DimensionMismatch())
39- r = searchsortedlast(S, tol ; rev = true ) # rank
39+ r = searchsortedlast(S, rank_atol ; rev = true ) # rank
4040 Ur = view(U, :, 1 : r)
4141 Vᴴr = view(Vᴴ, 1 : r, :)
4242 Sr = view(S, 1 : r)
@@ -53,7 +53,8 @@ function svd_pullback!(
5353 length(indU) == pU || throw(DimensionMismatch())
5454 UΔUp = view(UΔU, :, indU)
5555 mul!(UΔUp, Ur' , ΔU)
56- mul!(ΔU, Ur, UΔUp, -1, 1)
56+ # ΔU -= Ur * UΔUp but one less allocation without overwriting ΔU
57+ ΔU = mul!(copy(ΔU), Ur, UΔUp, -1, 1)
5758 end
5859 if !iszerotangent(ΔVᴴ)
5960 n == size(ΔVᴴ, 2) || throw(DimensionMismatch())
@@ -63,7 +64,8 @@ function svd_pullback!(
6364 length(indV) == pV || throw(DimensionMismatch())
6465 VΔVp = view(VΔV, :, indV)
6566 mul!(VΔVp, Vᴴr, ΔVᴴ' )
66- mul!(ΔVᴴ, VΔVp' , Vᴴr, -1, 1)
67+ # ΔVᴴ -= VΔVp' * Vᴴr but one less allocation without overwriting ΔVᴴ
68+ ΔVᴴ = mul!(copy(ΔVᴴ), VΔVp' , Vᴴr, -1, 1)
6769 end
6870
6971 # Project onto antihermitian part; hermitian part outside of Grassmann tangent space
@@ -152,7 +154,8 @@ function svd_trunc_pullback!(
152154 if ! iszerotangent(ΔVᴴ)
153155 (p, n) == size(ΔVᴴ) || throw(DimensionMismatch())
154156 mul!(VΔV, Vᴴ, ΔVᴴ' )
155- mul!(ΔVᴴ, VΔV' , Vᴴ, - 1 , 1 )
157+ # ΔVᴴ -= VΔVp' * Vᴴr but one less allocation without overwriting ΔVᴴ
158+ ΔVᴴ = mul!(copy(ΔVᴴ), VΔV' , Vᴴ, -1, 1)
156159 end
157160
158161 # Project onto antihermitian part; hermitian part outside of Grassmann tangent space
0 commit comments