Skip to content

Commit 50eb537

Browse files
authored
Loosen strictness on hermitian checks (#78)
* remove hermitian checks from yalapack wrappers * accept keywords in all eigh solvers * rework hermitian checks * non-allocating `ishermitian_approx` * code suggestions * refactor `check_hermitian` * rework tests * allocating path for GPU * small test fix * add missing imports
1 parent ea14b0c commit 50eb537

File tree

8 files changed

+147
-60
lines changed

8 files changed

+147
-60
lines changed

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using MatrixAlgebraKit
44
using MatrixAlgebraKit: @algdef, Algorithm, check_input
55
using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
66
using MatrixAlgebraKit: diagview, sign_safe
7-
using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue
7+
using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue, AbstractAlgorithm
88
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eigh_algorithm
99
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj!
1010
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx!
@@ -128,10 +128,17 @@ function MatrixAlgebraKit._project_hermitian_diag!(A::StridedROCMatrix, B::Strid
128128
end
129129

130130
MatrixAlgebraKit.ishermitian_exact(A::StridedROCMatrix) = all(A .== adjoint(A))
131-
MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T} = all(A.diag .== adjoint(A.diag))
131+
MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T} =
132+
all(A.diag .== adjoint(A.diag))
133+
MatrixAlgebraKit.ishermitian_approx(A::StridedROCMatrix; kwargs...) =
134+
@invoke MatrixAlgebraKit.ishermitian_approx(A::Any; kwargs...)
132135

133-
MatrixAlgebraKit.isantihermitian_exact(A::StridedROCMatrix) = all(A .== -adjoint(A))
134-
MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T} = all(A.diag .== -adjoint(A.diag))
136+
MatrixAlgebraKit.isantihermitian_exact(A::StridedROCMatrix) =
137+
all(A .== -adjoint(A))
138+
MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T} =
139+
all(A.diag .== -adjoint(A.diag))
140+
MatrixAlgebraKit.isantihermitian_approx(A::StridedROCMatrix; kwargs...) =
141+
@invoke MatrixAlgebraKit.isantihermitian_approx(A::Any; kwargs...)
135142

136143
function MatrixAlgebraKit._avgdiff!(A::StridedROCMatrix, B::StridedROCMatrix)
137144
axes(A) == axes(B) || throw(DimensionMismatch())

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using MatrixAlgebraKit
44
using MatrixAlgebraKit: @algdef, Algorithm, check_input
55
using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
66
using MatrixAlgebraKit: diagview, sign_safe
7-
using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue
7+
using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue, AbstractAlgorithm
88
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
99
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev!
1010
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!
@@ -134,11 +134,19 @@ function MatrixAlgebraKit._project_hermitian_diag!(A::StridedCuMatrix, B::Stride
134134
return nothing
135135
end
136136

137-
MatrixAlgebraKit.ishermitian_exact(A::StridedCuMatrix) = all(A .== adjoint(A))
138-
MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T} = all(A.diag .== adjoint(A.diag))
139-
140-
MatrixAlgebraKit.isantihermitian_exact(A::StridedCuMatrix) = all(A .== -adjoint(A))
141-
MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T} = all(A.diag .== -adjoint(A.diag))
137+
MatrixAlgebraKit.ishermitian_exact(A::StridedCuMatrix) =
138+
all(A .== adjoint(A))
139+
MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T} =
140+
all(A.diag .== adjoint(A.diag))
141+
MatrixAlgebraKit.ishermitian_approx(A::StridedCuMatrix; kwargs...) =
142+
@invoke MatrixAlgebraKit.ishermitian_approx(A::Any; kwargs...)
143+
144+
MatrixAlgebraKit.isantihermitian_exact(A::StridedCuMatrix) =
145+
all(A .== -adjoint(A))
146+
MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T} =
147+
all(A.diag .== -adjoint(A.diag))
148+
MatrixAlgebraKit.isantihermitian_approx(A::StridedCuMatrix; kwargs...) =
149+
@invoke MatrixAlgebraKit.isantihermitian_approx(A::Any; kwargs...)
142150

143151
function MatrixAlgebraKit._avgdiff!(A::StridedCuMatrix, B::StridedCuMatrix)
144152
axes(A) == axes(B) || throw(DimensionMismatch())

src/common/defaults.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,13 @@ function default_pullback_gaugetol(a)
1919
n = norm(a, Inf)
2020
return eps(eltype(n))^(3 / 4) * max(n, one(n))
2121
end
22+
23+
"""
24+
default_hermitian_tol(A)
25+
26+
Default tolerance for deciding to warn if the provided `A` is not hermitian.
27+
"""
28+
function default_hermitian_tol(A)
29+
n = norm(A, Inf)
30+
return eps(eltype(n))^(3 / 4) * max(n, one(n))
31+
end

src/common/matrixproperties.jl

Lines changed: 62 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -69,31 +69,32 @@ is_right_isometric(A; kwargs...) = is_left_isometric(A'; kwargs...)
6969
Test whether a linear map is Hermitian, i.e. `A = A'`.
7070
The `isapprox_kwargs` can be used to control the tolerances of the equality.
7171
"""
72-
function ishermitian(A; atol::Real = 0, rtol::Real = 0, norm = LinearAlgebra.norm, kwargs...)
72+
function ishermitian(A; atol::Real = 0, rtol::Real = 0, kwargs...)
7373
if iszero(atol) && iszero(rtol)
7474
return ishermitian_exact(A; kwargs...)
7575
else
76-
return 2 * norm(project_antihermitian(A; kwargs...)) max(atol, rtol * norm(A))
76+
return ishermitian_approx(A; atol, rtol, kwargs...)
7777
end
7878
end
79-
function ishermitian_exact(A)
80-
return A == A'
81-
end
82-
function ishermitian_exact(A::StridedMatrix; kwargs...)
83-
return strided_ishermitian_exact(A, Val(false); kwargs...)
79+
80+
ishermitian_exact(A) = A == A'
81+
ishermitian_exact(A::StridedMatrix; kwargs...) = strided_ishermitian_exact(A, Val(false); kwargs...)
82+
function ishermitian_approx(A; atol, rtol, kwargs...)
83+
return norm(project_antihermitian(A; kwargs...)) max(atol, rtol * norm(A))
8484
end
85+
ishermitian_approx(A::StridedMatrix; kwargs...) = strided_ishermitian_approx(A, Val(false); kwargs...)
8586

8687
"""
8788
isantihermitian(A; isapprox_kwargs...)
8889
8990
Test whether a linear map is anti-Hermitian, i.e. `A = -A'`.
9091
The `isapprox_kwargs` can be used to control the tolerances of the equality.
9192
"""
92-
function isantihermitian(A; atol::Real = 0, rtol::Real = 0, norm = LinearAlgebra.norm, kwargs...)
93+
function isantihermitian(A; atol::Real = 0, rtol::Real = 0, kwargs...)
9394
if iszero(atol) && iszero(rtol)
9495
return isantihermitian_exact(A; kwargs...)
9596
else
96-
return 2 * norm(project_hermitian(A; kwargs...)) max(atol, rtol * norm(A))
97+
return isantihermitian_approx(A; atol, rtol, kwargs...)
9798
end
9899
end
99100
function isantihermitian_exact(A)
@@ -102,6 +103,10 @@ end
102103
function isantihermitian_exact(A::StridedMatrix; kwargs...)
103104
return strided_ishermitian_exact(A, Val(true); kwargs...)
104105
end
106+
function isantihermitian_approx(A; atol, rtol, kwargs...)
107+
return norm(project_hermitian(A; kwargs...)) max(atol, rtol * norm(A))
108+
end
109+
isantihermitian_approx(A::StridedMatrix; kwargs...) = strided_ishermitian_approx(A, Val(true); kwargs...)
105110

106111
# blocked implementation of exact checks for strided matrices
107112
# -----------------------------------------------------------
@@ -139,3 +144,51 @@ function _ishermitian_exact_offdiag(Al, Au, ::Val{anti}) where {anti}
139144
end
140145
return true
141146
end
147+
148+
149+
function strided_ishermitian_approx(
150+
A::AbstractMatrix, anti::Val;
151+
blocksize = 32, atol::Real = default_hermitian_tol(A), rtol::Real = 0
152+
)
153+
n = size(A, 1)
154+
ϵ² = abs2(zero(eltype(A)))
155+
ϵ²max = oftype(ϵ², rtol > 0 ? max(atol, rtol * norm(A)) : atol)^2
156+
for j in 1:blocksize:n
157+
jb = min(blocksize, n - j + 1)
158+
ϵ² += _ishermitian_approx_diag(view(A, j:(j + jb - 1), j:(j + jb - 1)), anti)
159+
ϵ² < ϵ²max || return false
160+
for i in 1:blocksize:(j - 1)
161+
ib = blocksize
162+
ϵ² += 2 * _ishermitian_approx_offdiag(
163+
view(A, i:(i + ib - 1), j:(j + jb - 1)),
164+
view(A, j:(j + jb - 1), i:(i + ib - 1)),
165+
anti
166+
)
167+
ϵ² < ϵ²max || return false
168+
end
169+
end
170+
return true
171+
end
172+
173+
function _ishermitian_approx_diag(A, ::Val{anti}) where {anti}
174+
n = size(A, 1)
175+
ϵ² = abs2(zero(eltype(A)))
176+
@inbounds for j in 1:n
177+
@simd for i in 1:j
178+
val = _project_hermitian(A[i, j], A[j, i], !anti)
179+
ϵ² += abs2(val) * (1 + Int(i < j))
180+
end
181+
end
182+
return ϵ²
183+
end
184+
function _ishermitian_approx_offdiag(Al, Au, ::Val{anti}) where {anti}
185+
m, n = size(Al) # == reverse(size(Al))
186+
ϵ² = abs2(zero(eltype(Al)))
187+
@inbounds for j in 1:n
188+
@simd for i in 1:m
189+
val = _project_hermitian(Al[i, j], Au[j, i], !anti)
190+
ϵ² += abs2(val)
191+
end
192+
end
193+
return ϵ²
194+
end

src/implementations/eigh.jl

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,30 +8,40 @@ copy_input(::typeof(eigh_trunc), A) = copy_input(eigh_full, A)
88

99
copy_input(::typeof(eigh_full), A::Diagonal) = copy(A)
1010

11-
function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, ::AbstractAlgorithm)
11+
check_hermitian(A, ::AbstractAlgorithm) = check_hermitian(A)
12+
check_hermitian(A, alg::Algorithm) = check_hermitian(A; atol = get(alg.kwargs, :hermitian_tol, default_hermitian_tol(A)))
13+
function check_hermitian(A; atol::Real = default_hermitian_tol(A), rtol::Real = 0)
1214
m, n = size(A)
1315
m == n || throw(DimensionMismatch("square input matrix expected"))
16+
ishermitian(A; atol, rtol) ||
17+
throw(DomainError(A, "Hermitian matrix was expected. Use `project_hermitian` to project onto the nearest hermitian matrix."))
18+
return nothing
19+
end
20+
21+
function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, alg::AbstractAlgorithm)
22+
check_hermitian(A, alg)
1423
D, V = DV
24+
m = size(A, 1)
1525
@assert D isa Diagonal && V isa AbstractMatrix
1626
@check_size(D, (m, m))
1727
@check_scalar(D, A, real)
1828
@check_size(V, (m, m))
1929
@check_scalar(V, A)
2030
return nothing
2131
end
22-
function check_input(::typeof(eigh_vals!), A::AbstractMatrix, D, ::AbstractAlgorithm)
23-
m, n = size(A)
24-
m == n || throw(DimensionMismatch("square input matrix expected"))
32+
function check_input(::typeof(eigh_vals!), A::AbstractMatrix, D, alg::AbstractAlgorithm)
33+
check_hermitian(A, alg)
34+
m = size(A, 1)
2535
@assert D isa AbstractVector
26-
@check_size(D, (n,))
36+
@check_size(D, (m,))
2737
@check_scalar(D, A, real)
2838
return nothing
2939
end
3040

31-
function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, ::DiagonalAlgorithm)
32-
m, n = size(A)
33-
@assert m == n && isdiag(A)
34-
@assert (eltype(A) <: Real && issymmetric(A)) || ishermitian(A)
41+
function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, alg::DiagonalAlgorithm)
42+
check_hermitian(A, alg)
43+
@assert isdiag(A)
44+
m = size(A, 1)
3545
D, V = DV
3646
@assert D isa Diagonal && V isa Diagonal
3747
@check_size(D, (m, m))
@@ -40,12 +50,12 @@ function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, ::DiagonalAlgo
4050
@check_scalar(V, A)
4151
return nothing
4252
end
43-
function check_input(::typeof(eigh_vals!), A::AbstractMatrix, D, ::DiagonalAlgorithm)
44-
m, n = size(A)
45-
@assert m == n && isdiag(A)
46-
@assert (eltype(A) <: Real && issymmetric(A)) || ishermitian(A)
53+
function check_input(::typeof(eigh_vals!), A::AbstractMatrix, D, alg::DiagonalAlgorithm)
54+
check_hermitian(A, alg)
55+
@assert isdiag(A)
56+
m = size(A, 1)
4757
@assert D isa AbstractVector
48-
@check_size(D, (n,))
58+
@check_size(D, (m,))
4959
@check_scalar(D, A, real)
5060
return nothing
5161
end

src/implementations/projections.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,14 +85,16 @@ function project_hermitian_native!(A::AbstractMatrix, B::AbstractMatrix, anti::V
8585
return B
8686
end
8787

88+
@inline function _project_hermitian(Aij::Number, Aji::Number, anti::Bool)
89+
return anti ? (Aij - Aji') / 2 : (Aij + Aji') / 2
90+
end
8891
function _project_hermitian_offdiag!(
8992
Au::AbstractMatrix, Al::AbstractMatrix, Bu::AbstractMatrix, Bl::AbstractMatrix, ::Val{anti}
9093
) where {anti}
91-
9294
m, n = size(Au) # == reverse(size(Au))
9395
return @inbounds for j in 1:n
9496
@simd for i in 1:m
95-
val = anti ? (Au[i, j] - adjoint(Al[j, i])) / 2 : (Au[i, j] + adjoint(Al[j, i])) / 2
97+
val = _project_hermitian(Au[i, j], Al[j, i], anti)
9698
Bu[i, j] = val
9799
aval = adjoint(val)
98100
Bl[j, i] = anti ? -aval : aval
@@ -104,7 +106,7 @@ function _project_hermitian_diag!(A::AbstractMatrix, B::AbstractMatrix, ::Val{an
104106
n = size(A, 1)
105107
@inbounds for j in 1:n
106108
@simd for i in 1:(j - 1)
107-
val = anti ? (A[i, j] - adjoint(A[j, i])) / 2 : (A[i, j] + adjoint(A[j, i])) / 2
109+
val = _project_hermitian(A[i, j], A[j, i], anti)
108110
B[i, j] = val
109111
aval = adjoint(val)
110112
B[j, i] = anti ? -aval : aval

src/yalapack.jl

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ module YALAPACK # Yet another lapack wrapper
1010

1111
using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt, Char, LAPACK,
1212
LAPACKException, SingularException, PosDefException, checksquare, chkstride1,
13-
require_one_based_indexing, triu!, issymmetric, ishermitian, isposdef, adjoint!
13+
require_one_based_indexing, triu!, isposdef, adjoint!
1414

1515
using LinearAlgebra.BLAS: @blasfunc, libblastrampoline
1616
using LinearAlgebra.LAPACK: chkfinite, chktrans, chkside, chkuplofinite, chklapackerror
@@ -984,16 +984,12 @@ for (heev, heevx, heevr, heevd, hegvd, elty, relty) in
984984
A::AbstractMatrix{$elty},
985985
W::AbstractVector{$relty} = similar(A, $relty, size(A, 1)),
986986
V::AbstractMatrix{$elty} = A;
987-
uplo::AbstractChar = 'U'
987+
uplo::AbstractChar = 'U',
988+
kwargs...
988989
) # shouldn't matter but 'U' seems slightly faster than 'L'
989990
require_one_based_indexing(A, V, W)
990991
chkstride1(A, V, W)
991992
n = checksquare(A)
992-
if $elty <: Real
993-
issymmetric(A) || throw(ArgumentError("A must be symmetric"))
994-
else
995-
ishermitian(A) || throw(ArgumentError("A must be Hermitian"))
996-
end
997993
chkuplofinite(A, uplo)
998994
n == length(W) || throw(DimensionMismatch("length mismatch between A and W"))
999995
if length(V) == 0
@@ -1063,11 +1059,6 @@ for (heev, heevx, heevr, heevd, hegvd, elty, relty) in
10631059
require_one_based_indexing(A, V, W)
10641060
chkstride1(A, V, W)
10651061
n = checksquare(A)
1066-
if $elty <: Real
1067-
issymmetric(A) || throw(ArgumentError("A must be symmetric"))
1068-
else
1069-
ishermitian(A) || throw(ArgumentError("A must be Hermitian"))
1070-
end
10711062
chkuplofinite(A, uplo)
10721063
if haskey(kwargs, :irange)
10731064
il = first(kwargs[:irange])
@@ -1175,11 +1166,6 @@ for (heev, heevx, heevr, heevd, hegvd, elty, relty) in
11751166
require_one_based_indexing(A, V, W)
11761167
chkstride1(A, V, W)
11771168
n = checksquare(A)
1178-
if $elty <: Real
1179-
issymmetric(A) || throw(ArgumentError("A must be symmetric"))
1180-
else
1181-
ishermitian(A) || throw(ArgumentError("A must be Hermitian"))
1182-
end
11831169
chkuplofinite(A, uplo)
11841170
if haskey(kwargs, :irange)
11851171
il = first(irange)
@@ -1289,16 +1275,12 @@ for (heev, heevx, heevr, heevd, hegvd, elty, relty) in
12891275
A::AbstractMatrix{$elty},
12901276
W::AbstractVector{$relty} = similar(A, $relty, size(A, 1)),
12911277
V::AbstractMatrix{$elty} = A;
1292-
uplo::AbstractChar = 'U'
1278+
uplo::AbstractChar = 'U',
1279+
kwargs...
12931280
) # shouldn't matter but 'U' seems slightly faster than 'L'
12941281
require_one_based_indexing(A, V, W)
12951282
chkstride1(A, V, W)
12961283
n = checksquare(A)
1297-
if $elty <: Real
1298-
issymmetric(A) || throw(ArgumentError("A must be symmetric"))
1299-
else
1300-
ishermitian(A) || throw(ArgumentError("A must be Hermitian"))
1301-
end
13021284
uplo = 'U' # shouldn't matter but 'U' seems slightly faster than 'L'
13031285
chkuplofinite(A, uplo)
13041286
n == length(W) || throw(DimensionMismatch("length mismatch between A and W"))

test/projections.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using MatrixAlgebraKit
22
using Test
33
using TestExtras
44
using StableRNGs
5-
using LinearAlgebra: LinearAlgebra, Diagonal, norm
5+
using LinearAlgebra: LinearAlgebra, Diagonal, norm, normalize!
66

77
const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
88

@@ -43,6 +43,21 @@ const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
4343
@test isantihermitian(Ba)
4444
@test Ba Aa
4545
end
46+
47+
# test approximate error calculation
48+
A = normalize!(randn(rng, T, m, m))
49+
Ah = project_hermitian(A)
50+
Aa = project_antihermitian(A)
51+
52+
Ah_approx = Ah + noisefactor * Aa
53+
ϵ = norm(project_antihermitian(Ah_approx))
54+
@test !ishermitian(Ah_approx; atol = (999 // 1000) * ϵ)
55+
@test ishermitian(Ah_approx; atol = (1001 // 1000) * ϵ)
56+
57+
Aa_approx = Aa + noisefactor * Ah
58+
ϵ = norm(project_hermitian(Aa_approx))
59+
@test !isantihermitian(Aa_approx; atol = (999 // 1000) * ϵ)
60+
@test isantihermitian(Aa_approx; atol = (1001 // 1000) * ϵ)
4661
end
4762

4863
@testset "project_isometric! for T = $T" for T in BLASFloats

0 commit comments

Comments
 (0)