Skip to content

Commit 6f72754

Browse files
author
Katharine Hyatt
committed
Refactor checks into functions, rd -> rank_deficient
1 parent c4627bc commit 6f72754

File tree

10 files changed

+149
-114
lines changed

10 files changed

+149
-114
lines changed

src/common/defaults.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ default_pullback_degeneracy_atol(A) = eps(norm(A, Inf))^(3 / 4)
3434
Default tolerance for deciding what values should be considered equal to 0.
3535
"""
3636
default_pullback_rank_atol(A) = eps(norm(A, Inf))^(3 / 4)
37+
default_pullback_rank_atol(A::Diagonal) = default_pullback_rank_atol(diagview(A))
3738

3839
"""
3940
default_hermitian_tol(A)

src/pullbacks/eig.jl

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,12 @@
1+
function check_eig_cotangents(D, VᴴΔV; degeneracy_atol::Real = default_pullback_rank_atol(D), gauge_atol::Real = default_pullback_gauge_atol(VᴴΔV))
2+
mask = abs.(transpose(D) .- D) .< degeneracy_atol
3+
# not GPU friendly...
4+
Δgauge = norm(view(VᴴΔV, mask))
5+
Δgauge gauge_atol ||
6+
@warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
7+
return
8+
end
9+
110
"""
211
eig_pullback!(
312
ΔA::AbstractMatrix, A, DV, ΔDV, [ind];
@@ -40,14 +49,7 @@ function eig_pullback!(
4049
indV = axes(V, 2)[ind]
4150
length(indV) == pV || throw(DimensionMismatch())
4251
mul!(view(VᴴΔV, :, indV), V', ΔV)
43-
44-
mask = abs.(transpose(D) .- D) .< degeneracy_atol
45-
if isa(ΔA, Array)
46-
# not GPU friendly...
47-
Δgauge = norm(view(VᴴΔV, mask), Inf)
48-
Δgauge ≤ gauge_atol ||
49-
@warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
50-
end
52+
check_eig_cotangents(D, VᴴΔV; degeneracy_atol, gauge_atol)
5153
5254
VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol))
5355
@@ -132,10 +134,7 @@ function eig_trunc_pullback!(
132134
if !iszerotangent(ΔV)
133135
(n, p) == size(ΔV) || throw(DimensionMismatch())
134136
VᴴΔV = V' * ΔV
135-
mask = abs.(transpose(D) .- D) .< degeneracy_atol
136-
Δgauge = norm(view(VᴴΔV, mask), Inf)
137-
Δgauge ≤ gauge_atol ||
138-
@warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
137+
check_eig_cotangents(D, VᴴΔV; degeneracy_atol, gauge_atol)
139138
140139
ΔVperp = ΔV - V * inv(G) * VᴴΔV
141140
VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol))
@@ -194,7 +193,6 @@ function eig_vals_pullback!(
194193
ΔA, A, DV, ΔD, ind = Colon();
195194
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
196195
)
197-
198196
ΔDV = (diagonal(ΔD), nothing)
199197
return eig_pullback!(ΔA, A, DV, ΔDV, ind; degeneracy_atol)
200198
end

src/pullbacks/eigh.jl

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,15 @@
1+
function check_eigh_cotangents(
2+
D, aVᴴΔV;
3+
degeneracy_atol::Real = default_pullback_rank_atol(D),
4+
gauge_atol::Real = default_pullback_gauge_atol(aVᴴΔV)
5+
)
6+
mask = abs.(D' .- D) .< degeneracy_atol
7+
Δgauge = norm(view(aVᴴΔV, mask))
8+
Δgauge ≤ gauge_atol ||
9+
@warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
10+
return
11+
end
12+
113
"""
214
eigh_pullback!(
315
ΔA::AbstractMatrix, A, DV, ΔDV, [ind];
@@ -41,12 +53,7 @@ function eigh_pullback!(
4153
length(indV) == pV || throw(DimensionMismatch())
4254
mul!(view(VᴴΔV, :, indV), V', ΔV)
4355
aVᴴΔV = project_antihermitian(VᴴΔV) # can't use in-place or recycling doesn't work
44-
45-
mask = abs.(D' .- D) .< degeneracy_atol
46-
Δgauge = norm(view(aVᴴΔV, mask))
47-
Δgauge gauge_atol ||
48-
@warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
49-
56+
check_eigh_cotangents(D, aVᴴΔV; degeneracy_atol, gauge_atol)
5057
aVᴴΔV .*= inv_safe.(D' .- D, degeneracy_atol)
5158
5259
if !iszerotangent(ΔDmat)
@@ -120,10 +127,7 @@ function eigh_trunc_pullback!(
120127
VᴴΔV = V' * ΔV
121128
aVᴴΔV = project_antihermitian!(VᴴΔV)
122129
123-
mask = abs.(D' .- D) .< degeneracy_atol
124-
Δgauge = norm(view(aVᴴΔV, mask))
125-
Δgauge ≤ gauge_atol ||
126-
@warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
130+
check_eigh_cotangents(D, aVᴴΔV; degeneracy_atol, gauge_atol)
127131
128132
aVᴴΔV .*= inv_safe.(D' .- D, degeneracy_atol)
129133

src/pullbacks/lq.jl

Lines changed: 50 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,42 @@
1+
function check_lq_cotangents(
2+
L, Q, ΔL, ΔQ, minmn::Int, p::Int;
3+
gauge_atol::Real = default_pullback_gauge_atol(ΔQ)
4+
)
5+
if minmn > p # case where A is rank-deficient
6+
Δgauge = abs(zero(eltype(Q)))
7+
if !iszerotangent(ΔQ)
8+
# in this case the number Householder reflections will
9+
# change upon small variations, and all of the remaining
10+
# columns of ΔQ should be zero for a gauge-invariant
11+
# cost function
12+
ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :)
13+
Δgauge = max(Δgauge, norm(ΔQ2))
14+
end
15+
if !iszerotangent(ΔL)
16+
ΔL22 = view(ΔL, (p + 1):m, (p + 1):minmn)
17+
Δgauge = max(Δgauge, norm(ΔL22))
18+
end
19+
Δgauge gauge_atol ||
20+
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
21+
end
22+
return
23+
end
24+
25+
function check_lq_full_cotangents(Q1, ΔQ2, ΔQ2Q1ᴴ; gauge_atol::Real = default_pullback_gauge_atol(Q1))
26+
# in the case where A is full rank, but there are more columns in Q than in A
27+
# (the case of `lq_full`), there is gauge-invariant information in the
28+
# projection of ΔQ2 onto the column space of Q1, by virtue of Q being a unitary
29+
# matrix. As the number of Householder reflections is in fixed in the full rank
30+
# case, Q is expected to rotate smoothly (we might even be able to predict) also
31+
# how the full Q2 will change, but this we omit for now, and we consider
32+
# Q2' * ΔQ2 as a gauge dependent quantity.
33+
Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1ᴴ, Q1, -1, 1), Inf)
34+
Δgauge gauge_atol ||
35+
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
36+
return
37+
end
38+
39+
140
"""
241
lq_pullback!(
342
ΔA, A, LQ, ΔLQ;
@@ -36,25 +75,7 @@ function lq_pullback!(
3675
ΔA1 = view(ΔA, 1:p, :)
3776
ΔA2 = view(ΔA, (p + 1):m, :)
3877

39-
if isa(ΔA, Array) # not GPU friendly
40-
if minmn > p # case where A is rank-deficient
41-
Δgauge = abs(zero(eltype(Q)))
42-
if !iszerotangent(ΔQ)
43-
# in this case the number Householder reflections will
44-
# change upon small variations, and all of the remaining
45-
# columns of ΔQ should be zero for a gauge-invariant
46-
# cost function
47-
ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :)
48-
Δgauge = max(Δgauge, norm(ΔQ2, Inf))
49-
end
50-
if !iszerotangent(ΔL)
51-
ΔL22 = view(ΔL, (p + 1):m, (p + 1):minmn)
52-
Δgauge = max(Δgauge, norm(ΔL22, Inf))
53-
end
54-
Δgauge gauge_atol ||
55-
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
56-
end
57-
end
78+
check_lq_cotangents(L, Q, ΔL, ΔQ, minmn, p; gauge_atol)
5879

5980
ΔQ̃ = zero!(similar(Q, (p, n)))
6081
if !iszerotangent(ΔQ)
@@ -71,11 +92,7 @@ function lq_pullback!(
7192
# how the full Q2 will change, but this we omit for now, and we consider
7293
# Q2' * ΔQ2 as a gauge dependent quantity.
7394
ΔQ2Q1ᴴ = ΔQ2 * Q1'
74-
if isa(ΔA, Array) # not GPU friendly
75-
Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1ᴴ, Q1, -1, 1), Inf)
76-
Δgauge ≤ gauge_atol ||
77-
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
78-
end
95+
check_lq_full_cotangents(Q1, ΔQ2, ΔQ2Q1ᴴ; gauge_atol)
7996
ΔQ̃ = mul!(ΔQ̃, ΔQ2Q1ᴴ', Q2, -1, 1)
8097
end
8198
end
@@ -108,6 +125,14 @@ function lq_pullback!(
108125
return ΔA
109126
end
110127

128+
function check_lq_null_cotangents(Nᴴ, ΔNᴴ; gauge_atol::Real = default_pullback_gauge_atol(ΔNᴴ))
129+
aNᴴΔN = project_antihermitian!(Nᴴ * ΔNᴴ')
130+
Δgauge = norm(aNᴴΔN)
131+
Δgauge ≤ gauge_atol ||
132+
@warn "`lq_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)"
133+
return
134+
end
135+
111136
"""
112137
lq_null_pullback!(
113138
ΔA::AbstractMatrix, A, Nᴴ, ΔNᴴ;
@@ -124,10 +149,7 @@ function lq_null_pullback!(
124149
gauge_atol::Real = default_pullback_gauge_atol(ΔNᴴ)
125150
)
126151
if !iszerotangent(ΔNᴴ) && size(Nᴴ, 1) > 0
127-
aNᴴΔN = project_antihermitian!(Nᴴ * ΔNᴴ')
128-
Δgauge = norm(aNᴴΔN)
129-
Δgauge ≤ gauge_atol ||
130-
@warn "`lq_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)"
152+
check_lq_null_cotangents(Nᴴ, ΔNᴴ; gauge_atol)
131153
L, Q = lq_compact(A; positive = true) # should we be able to provide algorithm here?
132154
X = ldiv!(LowerTriangular(L)', Q * ΔNᴴ')
133155
ΔA = mul!(ΔA, X, Nᴴ, -1, 1)

src/pullbacks/polar.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ function left_polar_pullback!(ΔA::AbstractMatrix, A, WP, ΔWP; kwargs...)
2222
if !iszerotangent(ΔW)
2323
ΔWP = ΔW / P
2424
WdΔWP = W' * ΔWP
25-
ΔWP .-= W * WdΔWP
25+
ΔWP = mul!(ΔWP, W, WdΔWP, -1, 1)
2626
ΔA .+= ΔWP
2727
end
2828
return ΔA
@@ -48,11 +48,11 @@ function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ; kwargs...
4848
!iszerotangent(ΔP) && mul!(M, P, ΔP, -1, 1)
4949
C = sylvester(P, P, M' - M)
5050
C .+= ΔP
51-
ΔA .+= C * Wᴴ
51+
ΔA = mul!(ΔA, C, Wᴴ, 1, 1)
5252
if !iszerotangent(ΔWᴴ)
5353
PΔWᴴ = P \ ΔWᴴ
5454
PΔWᴴW = PΔWᴴ * Wᴴ'
55-
PΔWᴴ .-= PΔWᴴW * Wᴴ
55+
PΔWᴴ = mul!(PΔWᴴ, PΔWᴴW, Wᴴ, -1, 1)
5656
ΔA .+= PΔWᴴ
5757
end
5858
return ΔA

src/pullbacks/qr.jl

Lines changed: 48 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,38 @@
1+
function check_qr_cotangents(Q, R, ΔQ, ΔR, minmn::Int, p::Int; gauge_atol::Real = default_pullback_gauge_atol(ΔQ))
2+
if minmn > p # case where A is rank-deficient
3+
Δgauge = abs(zero(eltype(Q)))
4+
if !iszerotangent(ΔQ)
5+
# in this case the number Householder reflections will
6+
# change upon small variations, and all of the remaining
7+
# columns of ΔQ should be zero for a gauge-invariant
8+
# cost function
9+
ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2))
10+
Δgauge = max(Δgauge, norm(ΔQ2, Inf))
11+
end
12+
if !iszerotangent(ΔR)
13+
ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):n)
14+
Δgauge = max(Δgauge, norm(ΔR22, Inf))
15+
end
16+
Δgauge gauge_atol ||
17+
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
18+
end
19+
return
20+
end
21+
22+
function check_qr_full_cotangents(Q1, ΔQ2, ΔR, Q1dΔQ2, ; gauge_atol::Real = default_pullback_gauge_atol(ΔQ2))
23+
# in the case where A is full rank, but there are more columns in Q than in A
24+
# (the case of `qr_full`), there is gauge-invariant information in the
25+
# projection of ΔQ2 onto the column space of Q1, by virtue of Q being a unitary
26+
# matrix. As the number of Householder reflections is in fixed in the full rank
27+
# case, Q is expected to rotate smoothly (we might even be able to predict) also
28+
# how the full Q2 will change, but this we omit for now, and we consider
29+
# Q2' * ΔQ2 as a gauge dependent quantity.
30+
Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf)
31+
Δgauge gauge_atol ||
32+
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
33+
return
34+
end
35+
136
"""
237
qr_pullback!(
338
ΔA, A, QR, ΔQR;
@@ -37,45 +72,16 @@ function qr_pullback!(
3772
ΔA1 = view(ΔA, :, 1:p)
3873
ΔA2 = view(ΔA, :, (p + 1):n)
3974

40-
if isa(ΔA, Array) # not GPU friendly
41-
if minmn > p # case where A is rank-deficient
42-
Δgauge = abs(zero(eltype(Q)))
43-
if !iszerotangent(ΔQ)
44-
# in this case the number Householder reflections will
45-
# change upon small variations, and all of the remaining
46-
# columns of ΔQ should be zero for a gauge-invariant
47-
# cost function
48-
ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2))
49-
Δgauge = max(Δgauge, norm(ΔQ2, Inf))
50-
end
51-
if !iszerotangent(ΔR)
52-
ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):n)
53-
Δgauge = max(Δgauge, norm(ΔR22, Inf))
54-
end
55-
Δgauge gauge_atol ||
56-
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
57-
end
58-
end
75+
check_qr_cotangents(Q, R, ΔQ, ΔR, minmn, p; gauge_atol)
5976

6077
ΔQ̃ = zero!(similar(Q, (m, p)))
6178
if !iszerotangent(ΔQ)
6279
ΔQ̃ .= view(ΔQ, :, 1:p)
6380
if p < size(Q, 2)
6481
Q2 = view(Q, :, (p + 1):size(Q, 2))
6582
ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2))
66-
# in the case where A is full rank, but there are more columns in Q than in A
67-
# (the case of `qr_full`), there is gauge-invariant information in the
68-
# projection of ΔQ2 onto the column space of Q1, by virtue of Q being a unitary
69-
# matrix. As the number of Householder reflections is in fixed in the full rank
70-
# case, Q is expected to rotate smoothly (we might even be able to predict) also
71-
# how the full Q2 will change, but this we omit for now, and we consider
72-
# Q2' * ΔQ2 as a gauge dependent quantity.
7383
Q1dΔQ2 = Q1' * ΔQ2
74-
if isa(ΔA, Array) # not GPU friendly
75-
Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf)
76-
Δgauge ≤ gauge_atol ||
77-
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
78-
end
84+
check_qr_full_cotangents(Q1, ΔQ2, Q1dΔQ2; gauge_atol)
7985
ΔQ̃ = mul!(ΔQ̃, Q2, Q1dΔQ2', -1, 1)
8086
end
8187
end
@@ -91,9 +97,9 @@ function qr_pullback!(
9197
M = zero!(similar(R, (p, p)))
9298
if !iszerotangent(ΔR)
9399
ΔR11 = view(ΔR, 1:p, 1:p)
94-
M += ΔR11 * R11'
100+
M = mul!(M, ΔR11, R11', 1, 1)
95101
end
96-
M -= Q1' * ΔQ̃
102+
M = mul!(M, Q1', ΔQ̃, -1, 1)
97103
view(M, lowertriangularind(M)) .= conj.(view(M, uppertriangularind(M)))
98104
if eltype(M) <: Complex
99105
Md = diagview(M)
@@ -108,6 +114,14 @@ function qr_pullback!(
108114
return ΔA
109115
end
110116

117+
function check_qr_null_cotangents(N, ΔN; gauge_atol::Real = default_pullback_gauge_atol(ΔN))
118+
aNᴴΔN = project_antihermitian!(N' * ΔN)
119+
Δgauge = norm(aNᴴΔN)
120+
Δgauge ≤ gauge_atol ||
121+
@warn "`qr_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)"
122+
return
123+
end
124+
111125
"""
112126
qr_null_pullback!(
113127
ΔA::AbstractMatrix, A, N, ΔN;
@@ -124,11 +138,7 @@ function qr_null_pullback!(
124138
gauge_atol::Real = default_pullback_gauge_atol(ΔN)
125139
)
126140
if !iszerotangent(ΔN) && size(N, 2) > 0
127-
aNᴴΔN = project_antihermitian!(N' * ΔN)
128-
Δgauge = norm(aNᴴΔN)
129-
Δgauge ≤ gauge_atol ||
130-
@warn "`qr_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)"
131-
141+
check_qr_null_cotangents(N, ΔN; gauge_atol)
132142
Q, R = qr_compact(A; positive = true)
133143
X = rdiv!(ΔN' * Q, UpperTriangular(R)')
134144
ΔA = mul!(ΔA, N, X, -1, 1)

0 commit comments

Comments
 (0)