Skip to content

Commit c99084b

Browse files
authored
update pullback tolerances (#92)
* update pullback tolerances * update docstrings * make gauge tolerances relative * make hermitian tolerances relative
1 parent 5d36c98 commit c99084b

File tree

6 files changed

+74
-73
lines changed

6 files changed

+74
-73
lines changed

src/common/defaults.jl

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,22 +10,34 @@ quantity needs to be computed.
1010
defaulttol(x::Any) = eps(real(float(one(eltype(x)))))^(2 / 3)
1111

1212
"""
13-
default_pullback_gaugetol(a)
13+
default_pullback_gauge_atol(ΔA...)
1414
1515
Default tolerance for deciding to warn if incoming adjoints of a pullback rule
1616
have components that are not gauge-invariant.
1717
"""
18-
function default_pullback_gaugetol(a)
19-
n = norm(a, Inf)
20-
return eps(eltype(n))^(3 / 4) * max(n, one(n))
18+
default_pullback_gauge_atol(A) = iszerotangent(A) ? 0 : eps(norm(A, Inf))^(3 / 4)
19+
function default_pullback_gauge_atol(A, As...)
20+
As′ = filter(!iszerotangent, (A, As...))
21+
return isempty(As′) ? 0 : eps(norm(As′, Inf))^(3 / 4)
2122
end
2223

24+
"""
25+
default_pullback_degeneracy_atol(A)
26+
27+
Default tolerance for deciding when values should be considered as degenerate.
28+
"""
29+
default_pullback_degeneracy_atol(A) = eps(norm(A, Inf))^(3 / 4)
30+
31+
"""
32+
default_pullback_rank_atol(A)
33+
34+
Default tolerance for deciding what values should be considered equal to 0.
35+
"""
36+
default_pullback_rank_atol(A) = eps(norm(A, Inf))^(3 / 4)
37+
2338
"""
2439
default_hermitian_tol(A)
2540
2641
Default tolerance for deciding to warn if the provided `A` is not hermitian.
2742
"""
28-
function default_hermitian_tol(A)
29-
n = norm(A, Inf)
30-
return eps(eltype(n))^(3 / 4) * max(n, one(n))
31-
end
43+
default_hermitian_tol(A) = eps(norm(A, Inf))^(3 / 4)

src/pullbacks/eig.jl

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
"""
22
eig_pullback!(
33
ΔA::AbstractMatrix, A, DV, ΔDV, [ind];
4-
tol = default_pullback_gaugetol(DV[1]),
5-
degeneracy_atol = tol,
6-
gauge_atol = tol
4+
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
5+
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
76
)
87
98
Adds the pullback from the full eigenvalue decomposition of `A` to `ΔA`, given the output
@@ -22,9 +21,8 @@ not small compared to `gauge_atol`.
2221
"""
2322
function eig_pullback!(
2423
ΔA::AbstractMatrix, A, DV, ΔDV, ind = Colon();
25-
tol::Real = default_pullback_gaugetol(DV[1]),
26-
degeneracy_atol::Real = tol,
27-
gauge_atol::Real = tol
24+
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
25+
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
2826
)
2927

3028
# Basic size checks and determination
@@ -84,9 +82,8 @@ end
8482
"""
8583
eig_trunc_pullback!(
8684
ΔA::AbstractMatrix, A, DV, ΔDV;
87-
tol = default_pullback_gaugetol(DV[1]),
88-
degeneracy_atol = tol,
89-
gauge_atol = tol
85+
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
86+
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
9087
)
9188
9289
Adds the pullback from the truncated eigenvalue decomposition of `A` to `ΔA`, given the
@@ -106,9 +103,8 @@ not small compared to `gauge_atol`.
106103
"""
107104
function eig_trunc_pullback!(
108105
ΔA::AbstractMatrix, A, DV, ΔDV;
109-
tol::Real = default_pullback_gaugetol(DV[1]),
110-
degeneracy_atol::Real = tol,
111-
gauge_atol::Real = tol
106+
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
107+
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
112108
)
113109
114110
# Basic size checks and determination

src/pullbacks/eigh.jl

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
"""
22
eigh_pullback!(
33
ΔA::AbstractMatrix, A, DV, ΔDV, [ind];
4-
tol = default_pullback_gaugetol(DV[1]),
5-
degeneracy_atol = tol,
6-
gauge_atol = tol
4+
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
5+
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
76
)
87
98
Adds the pullback from the Hermitian eigenvalue decomposition of `A` to `ΔA`, given the
@@ -22,9 +21,8 @@ anti-hermitian part of `V' * ΔV`, restricted to rows `i` and columns `j` for wh
2221
"""
2322
function eigh_pullback!(
2423
ΔA::AbstractMatrix, A, DV, ΔDV, ind = Colon();
25-
tol::Real = default_pullback_gaugetol(DV[1]),
26-
degeneracy_atol::Real = tol,
27-
gauge_atol::Real = tol
24+
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
25+
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
2826
)
2927

3028
# Basic size checks and determination
@@ -49,7 +47,7 @@ function eigh_pullback!(
4947
Δgauge < gauge_atol ||
5048
@warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
5149

52-
aVᴴΔV .*= inv_safe.(D' .- D, tol)
50+
aVᴴΔV .*= inv_safe.(D' .- D, degeneracy_atol)
5351
5452
if !iszerotangent(ΔDmat)
5553
ΔDvec = diagview(ΔDmat)
@@ -74,9 +72,8 @@ end
7472
"""
7573
eigh_trunc_pullback!(
7674
ΔA::AbstractMatrix, A, DV, ΔDV;
77-
tol=default_pullback_gaugetol(DV[1]),
78-
degeneracy_atol=tol,
79-
gauge_atol=tol
75+
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
76+
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
8077
)
8178
8279
Adds the pullback from the truncated Hermitian eigenvalue decomposition of `A` to `ΔA`,
@@ -96,9 +93,8 @@ not small compared to `gauge_atol`.
9693
"""
9794
function eigh_trunc_pullback!(
9895
ΔA::AbstractMatrix, A, DV, ΔDV;
99-
tol::Real = default_pullback_gaugetol(DV[1]),
100-
degeneracy_atol::Real = tol,
101-
gauge_atol::Real = tol
96+
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
97+
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
10298
)
10399
104100
# Basic size checks and determination
@@ -119,7 +115,7 @@ function eigh_trunc_pullback!(
119115
Δgauge < gauge_atol ||
120116
@warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
121117
122-
aVᴴΔV .*= inv_safe.(D' .- D, tol)
118+
aVᴴΔV .*= inv_safe.(D' .- D, degeneracy_atol)
123119
124120
if !iszerotangent(ΔDmat)
125121
ΔDvec = diagview(ΔDmat)

src/pullbacks/lq.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
"""
22
lq_pullback!(
33
ΔA, A, LQ, ΔLQ;
4-
tol::Real = default_pullback_gaugetol(LQ[1]),
5-
rank_atol::Real = tol,
6-
gauge_atol::Real = tol
4+
rank_atol::Real = default_pullback_rank_atol(LQ[1]),
5+
gauge_atol::Real = default_pullback_gauge_atol(ΔLQ[2])
76
)
87
98
Adds the pullback from the LQ decomposition of `A` to `ΔA` given the output `LQ` and
@@ -18,17 +17,16 @@ or rows exceed `gauge_atol`, a warning will be printed.
1817
"""
1918
function lq_pullback!(
2019
ΔA::AbstractMatrix, A, LQ, ΔLQ;
21-
tol::Real = default_pullback_gaugetol(LQ[1]),
22-
rank_atol::Real = tol,
23-
gauge_atol::Real = tol
20+
rank_atol::Real = default_pullback_rank_atol(LQ[1]),
21+
gauge_atol::Real = default_pullback_gauge_atol(ΔLQ[2])
2422
)
2523
# process
2624
L, Q = LQ
2725
m = size(L, 1)
2826
n = size(Q, 2)
2927
minmn = min(m, n)
3028
Ld = diagview(L)
31-
p = findlast(>=(rank_atol) ∘ abs, Ld)
29+
p = @something findlast(>=(rank_atol) ∘ abs, Ld) 0
3230

3331
ΔL, ΔQ = ΔLQ
3432

@@ -72,7 +70,7 @@ function lq_pullback!(
7270
# Q2' * ΔQ2 as a gauge dependent quantity.
7371
ΔQ2Q1ᴴ = ΔQ2 * Q1'
7472
Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1ᴴ, Q1, -1, 1), Inf)
75-
Δgauge < tol ||
73+
Δgauge < gauge_atol ||
7674
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
7775
ΔQ̃ = mul!(ΔQ̃, ΔQ2Q1ᴴ', Q2, -1, 1)
7876
end
@@ -105,7 +103,10 @@ function lq_pullback!(
105103
end
106104

107105
"""
108-
lq_null_pullback(ΔA, A, Nᴴ, ΔNᴴ)
106+
lq_null_pullback!(
107+
ΔA::AbstractMatrix, A, Nᴴ, ΔNᴴ;
108+
gauge_atol::Real = default_pullback_gauge_atol(ΔNᴴ)
109+
)
109110
110111
Adds the pullback from the left nullspace of `A` to `ΔA`, given the nullspace basis
111112
`Nᴴ` and its cotangent `ΔNᴴ` of `lq_null(A)`.
@@ -114,13 +115,12 @@ See also [`lq_pullback!`](@ref).
114115
"""
115116
function lq_null_pullback!(
116117
ΔA::AbstractMatrix, A, Nᴴ, ΔNᴴ;
117-
tol::Real = default_pullback_gaugetol(A),
118-
gauge_atol::Real = tol
118+
gauge_atol::Real = default_pullback_gauge_atol(ΔNᴴ)
119119
)
120120
if !iszerotangent(ΔNᴴ) && size(Nᴴ, 1) > 0
121121
aNᴴΔN = project_antihermitian!(Nᴴ * ΔNᴴ')
122122
Δgauge = norm(aNᴴΔN)
123-
Δgauge < tol ||
123+
Δgauge < gauge_atol ||
124124
@warn "`lq_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)"
125125
L, Q = lq_compact(A; positive = true) # should we be able to provide algorithm here?
126126
X = ldiv!(LowerTriangular(L)', Q * ΔNᴴ')

src/pullbacks/qr.jl

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
qr_pullback!(
33
ΔA, A, QR, ΔQR;
44
tol::Real = default_pullback_gaugetol(QR[2]),
5-
rank_atol::Real = tol,
6-
gauge_atol::Real = tol
5+
rank_atol::Real = default_pullback_rank_atol(QR[2]),
6+
gauge_atol::Real = default_pullback_gauge_atol(ΔQR[1])
77
)
88
99
Adds the pullback from the QR decomposition of `A` to `ΔA` given the output `QR` and
@@ -18,17 +18,16 @@ and also the adjoint variables `ΔQ` and `ΔR` should have nonzero values only i
1818
"""
1919
function qr_pullback!(
2020
ΔA::AbstractMatrix, A, QR, ΔQR;
21-
tol::Real = default_pullback_gaugetol(QR[2]),
22-
rank_atol::Real = tol,
23-
gauge_atol::Real = tol
21+
rank_atol::Real = default_pullback_rank_atol(QR[2]),
22+
gauge_atol::Real = default_pullback_gauge_atol(ΔQR[1])
2423
)
2524
# process
2625
Q, R = QR
2726
m = size(Q, 1)
2827
n = size(R, 2)
2928
minmn = min(m, n)
3029
Rd = diagview(R)
31-
p = findlast(>=(rank_atol) ∘ abs, Rd)
30+
p = @something findlast(>=(rank_atol) ∘ abs, Rd) 0
3231

3332
ΔQ, ΔR = ΔQR
3433

@@ -71,7 +70,7 @@ function qr_pullback!(
7170
# Q2' * ΔQ2 as a gauge dependent quantity.
7271
Q1dΔQ2 = Q1' * ΔQ2
7372
Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf)
74-
Δgauge < tol ||
73+
Δgauge < gauge_atol ||
7574
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
7675
ΔQ̃ = mul!(ΔQ̃, Q2, Q1dΔQ2', -1, 1)
7776
end
@@ -104,7 +103,10 @@ function qr_pullback!(
104103
end
105104

106105
"""
107-
qr_null_pullback(ΔA, A, N, ΔN)
106+
qr_null_pullback!(
107+
ΔA::AbstractMatrix, A, N, ΔN;
108+
gauge_atol::Real = default_pullback_gauge_atol(ΔN)
109+
)
108110
109111
Adds the pullback from the right nullspace of `A` to `ΔA`, given the nullspace basis
110112
`N` and its cotangent `ΔN` of `qr_null(A)`.
@@ -113,13 +115,12 @@ See also [`qr_pullback!`](@ref).
113115
"""
114116
function qr_null_pullback!(
115117
ΔA::AbstractMatrix, A, N, ΔN;
116-
tol::Real = default_pullback_gaugetol(A),
117-
gauge_atol::Real = tol
118+
gauge_atol::Real = default_pullback_gauge_atol(ΔN)
118119
)
119120
if !iszerotangent(ΔN) && size(N, 2) > 0
120121
aNᴴΔN = project_antihermitian!(N' * ΔN)
121122
Δgauge = norm(aNᴴΔN)
122-
Δgauge < tol ||
123+
Δgauge < gauge_atol ||
123124
@warn "`qr_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)"
124125
125126
Q, R = qr_compact(A; positive = true)

src/pullbacks/svd.jl

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
"""
22
svd_pullback!(
33
ΔA, A, USVᴴ, ΔUSVᴴ, [ind];
4-
tol::Real=default_pullback_gaugetol(USVᴴ[2]),
5-
rank_atol::Real = tol,
6-
degeneracy_atol::Real = tol,
7-
gauge_atol::Real = tol
4+
rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
5+
degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
6+
gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3])
87
)
98
109
Adds the pullback from the SVD of `A` to `ΔA` given the output USVᴴ of `svd_compact` or
@@ -23,10 +22,9 @@ which `abs(S[i] - S[j]) < degeneracy_atol`, is not small compared to `gauge_atol
2322
"""
2423
function svd_pullback!(
2524
ΔA::AbstractMatrix, A, USVᴴ, ΔUSVᴴ, ind = Colon();
26-
tol::Real = default_pullback_gaugetol(USVᴴ[2]),
27-
rank_atol::Real = tol,
28-
degeneracy_atol::Real = tol,
29-
gauge_atol::Real = tol
25+
rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
26+
degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
27+
gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3])
3028
)
3129

3230
# Extract the SVD components
@@ -106,10 +104,9 @@ end
106104
"""
107105
svd_trunc_pullback!(
108106
ΔA, A, USVᴴ, ΔUSVᴴ;
109-
tol::Real=default_pullback_gaugetol(S),
110-
rank_atol::Real = tol,
111-
degeneracy_atol::Real = tol,
112-
gauge_atol::Real = tol
107+
rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
108+
degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
109+
gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3])
113110
)
114111
115112
Adds the pullback from the truncated SVD of `A` to `ΔA`, given the output `USVᴴ` and the
@@ -128,10 +125,9 @@ which `abs(S[i] - S[j]) < degeneracy_atol`, is not small compared to `gauge_atol
128125
"""
129126
function svd_trunc_pullback!(
130127
ΔA::AbstractMatrix, A, USVᴴ, ΔUSVᴴ;
131-
tol::Real = default_pullback_gaugetol(USVᴴ[2]),
132-
rank_atol::Real = tol,
133-
degeneracy_atol::Real = tol,
134-
gauge_atol::Real = tol
128+
rank_atol::Real = 0,
129+
degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
130+
gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3])
135131
)
136132
137133
# Extract the SVD components

0 commit comments

Comments
 (0)