Skip to content

Commit af22843

Browse files
authored
stabilize eig(h) tests (#122)
* stabilize eig(h) tests * simplify/clarify code * some random fix
1 parent 8bded8b commit af22843

File tree

4 files changed

+37
-8
lines changed

4 files changed

+37
-8
lines changed

src/common/matrixproperties.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ function is_left_isometric(A::AbstractMatrix; atol::Real = 0, rtol::Real = defau
5050
P = A' * A
5151
nP = norm(P) # isapprox would use `rtol * max(norm(P), norm(I))`
5252
diagview(P) .-= 1
53-
return norm(P) <= max(atol, rtol * nP) # assume that the norm of I is `sqrt(n)`
53+
return norm(P) max(atol, rtol * nP) # assume that the norm of I is `sqrt(n)`
5454
end
5555
5656
@doc """

test/ad_utils.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,36 @@ function remove_eighgauge_dependence!(
2727
mul!(ΔV, V, gaugepart, -1, 1)
2828
return ΔV
2929
end
30+
function stabilize_eigvals!(D::AbstractVector)
31+
absD = abs.(D)
32+
p = invperm(sortperm(absD)) # rank of abs(D)
33+
# account for exact degeneracies in absolute value when having complex conjugate pairs
34+
for i in 1:(length(D) - 1)
35+
if absD[i] == absD[i + 1] # conjugate pairs will appear sequentially
36+
p[p .>= p[i + 1]] .-= 1 # lower the rank of all higher ones
37+
end
38+
end
39+
n = maximum(p)
40+
# rescale eigenvalues so that they lie on distinct radii in the complex plane
41+
# that are chosen randomly in non-overlapping intervals [k/n, (k+0.5)/n)] for k=1,...,n
42+
radii = ((1:n) .+ rand(real(eltype(D)), n) ./ 2) ./ n
43+
for i in 1:length(D)
44+
D[i] = sign(D[i]) * radii[p[i]]
45+
end
46+
return D
47+
end
48+
function make_eig_matrix(rng, T, n)
49+
A = randn(rng, T, n, n)
50+
D, V = eig_full(A)
51+
stabilize_eigvals!(diagview(D))
52+
Ac = V * D * inv(V)
53+
return (T <: Real) ? real(Ac) : Ac
54+
end
55+
function make_eigh_matrix(rng, T, n)
56+
A = project_hermitian!(randn(rng, T, n, n))
57+
D, V = eigh_full(A)
58+
stabilize_eigvals!(diagview(D))
59+
return project_hermitian!(V * D * V')
60+
end
3061

3162
precision(::Type{T}) where {T <: Number} = sqrt(eps(real(T)))

test/chainrules.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ end
221221
rng = StableRNG(12345)
222222
m = 19
223223
atol = rtol = m * m * precision(T)
224-
A = randn(rng, T, m, m)
224+
A = make_eig_matrix(rng, T, m)
225225
D, V = eig_full(A)
226226
Ddiag = diagview(D)
227227
ΔV = randn(rng, complex(T), m, m)
@@ -297,8 +297,7 @@ end
297297
rng = StableRNG(12345)
298298
m = 19
299299
atol = rtol = m * m * precision(T)
300-
A = randn(rng, T, m, m)
301-
A = A + A'
300+
A = make_eigh_matrix(rng, T, m)
302301
D, V = eigh_full(A)
303302
Ddiag = diagview(D)
304303
ΔV = randn(rng, T, m, m)

test/mooncake.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ end
238238
rng = StableRNG(12345)
239239
m = 19
240240
atol = rtol = m * m * precision(T)
241-
A = randn(rng, T, m, m)
241+
A = make_eig_matrix(rng, T, m)
242242
DV = eig_full(A)
243243
D, V = DV
244244
Ddiag = diagview(D)
@@ -347,17 +347,16 @@ MatrixAlgebraKit.copy_input(::typeof(copy_eigh_trunc_no_error), A) = MatrixAlgeb
347347
rng = StableRNG(12345)
348348
m = 19
349349
atol = rtol = m * m * precision(T)
350-
A = randn(rng, T, m, m)
351-
A = A + A'
350+
A = make_eigh_matrix(rng, T, m)
352351
D, V = eigh_full(A)
352+
Ddiag = diagview(D)
353353
ΔV = randn(rng, T, m, m)
354354
ΔV = remove_eighgauge_dependence!(ΔV, D, V; degeneracy_atol = atol)
355355
ΔD = randn(rng, real(T), m, m)
356356
ΔD2 = Diagonal(randn(rng, real(T), m))
357357
dD = make_mooncake_tangent(ΔD2)
358358
dV = make_mooncake_tangent(ΔV)
359359
dDV = Mooncake.build_tangent(typeof((ΔD2, ΔV)), dD, dV)
360-
Ddiag = diagview(D)
361360
@testset for alg in (
362361
LAPACK_QRIteration(),
363362
#LAPACK_DivideAndConquer(),

0 commit comments

Comments
 (0)