You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: src/pullbacks/svd.jl
+9-10Lines changed: 9 additions & 10 deletions
Original file line number
Diff line number
Diff line change
@@ -28,14 +28,13 @@ function svd_pullback!(
28
28
degeneracy_atol::Real= tol,
29
29
gauge_atol::Real= tol
30
30
)
31
-
32
31
# Extract the SVD components
33
32
U, Smat, Vᴴ = USVᴴ
34
33
m, n = size(U, 1), size(Vᴴ, 2)
35
-
(m, n) == size(ΔA) || throw(DimensionMismatch())
34
+
(m, n) == size(ΔA) || throw(DimensionMismatch("size of ΔA ($(size(ΔA))) does not match size of U*S*Vᴴ ($m, $n)"))
36
35
minmn = min(m, n)
37
36
S = diagview(Smat)
38
-
length(S) == minmn || throw(DimensionMismatch())
37
+
length(S) == minmn || throw(DimensionMismatch("length of S ($(length(S))) does not matrix minimum dimension of U, Vᴴ ($minmn)"))
39
38
r = searchsortedlast(S, rank_atol; rev =true) # rank
40
39
Ur = view(U, :, 1:r)
41
40
Vᴴr = view(Vᴴ, 1:r, :)
@@ -46,22 +45,22 @@ function svd_pullback!(
46
45
UΔU = fill!(similar(U, (r, r)), 0)
47
46
VΔV = fill!(similar(Vᴴ, (r, r)), 0)
48
47
if!iszerotangent(ΔU)
49
-
m == size(ΔU, 1) || throw(DimensionMismatch())
48
+
m == size(ΔU, 1) || throw(DimensionMismatch("first dimension of ΔU ($(size(ΔU, 1))) does not match first dimension of U ($m)"))
50
49
pU = size(ΔU, 2)
51
-
pU > r && throw(DimensionMismatch())
50
+
pU > r && throw(DimensionMismatch("second dimension of ΔU ($(size(ΔU, 2))) does not match rank of S ($r)"))
52
51
indU = axes(U, 2)[ind]
53
-
length(indU) == pU || throw(DimensionMismatch())
52
+
length(indU) == pU || throw(DimensionMismatch("length of selected U columns ($(length(indU))) does not match second dimension of ΔU ($(size(ΔU, 2)))"))
54
53
UΔUp = view(UΔU, :, indU)
55
54
mul!(UΔUp, Ur', ΔU)
56
55
# ΔU -= Ur * UΔUp but one less allocation without overwriting ΔU
57
56
ΔU = mul!(copy(ΔU), Ur, UΔUp, -1, 1)
58
57
end
59
58
if !iszerotangent(ΔVᴴ)
60
-
n == size(ΔVᴴ, 2) || throw(DimensionMismatch())
59
+
n == size(ΔVᴴ, 2) || throw(DimensionMismatch("second dimension of ΔVᴴ ($(size(ΔVᴴ, 2))) does not match second dimension of Vᴴ ($n)"))
61
60
pV = size(ΔVᴴ, 1)
62
-
pV > r && throw(DimensionMismatch())
61
+
pV > r && throw(DimensionMismatch("first dimension of ΔVᴴ ($(size(ΔVᴴ, 1))) does not match rank of S ($r)"))
63
62
indV = axes(Vᴴ, 1)[ind]
64
-
length(indV) == pV || throw(DimensionMismatch())
63
+
length(indV) == pV || throw(DimensionMismatch("length of selected Vᴴ rows ($(length(indV))) does not match first dimension of ΔVᴴ ($(size(ΔVᴴ, 1)))"))
65
64
VΔVp = view(VΔV, :, indV)
66
65
mul!(VΔVp, Vᴴr, ΔVᴴ')
67
66
# ΔVᴴ -= VΔVp' * Vᴴr but one less allocation without overwriting ΔVᴴ
@@ -84,7 +83,7 @@ function svd_pullback!(
84
83
ΔS = diagview(ΔSmat)
85
84
pS = length(ΔS)
86
85
indS = axes(S, 1)[ind]
87
-
length(indS) == pS || throw(DimensionMismatch())
86
+
length(indS) == pS || throw(DimensionMismatch("length of selected S diagonals ($(length(indS))) does not match length of ΔS diagonal ($(length(ΔS)))"))
88
87
view(diagview(UdΔAV), indS) .+= real.(ΔS)
89
88
end
90
89
ΔA = mul!(ΔA, Ur, UdΔAV * Vᴴr, 1, 1) # add the contribution to ΔA
0 commit comments