Skip to content

Commit 879dc88

Browse files
committed
type stability of matrix functions
1 parent 91ef00e commit 879dc88

File tree

5 files changed

+104
-48
lines changed

5 files changed

+104
-48
lines changed

src/dense.jl

Lines changed: 39 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -592,7 +592,7 @@ function schurpow(A::AbstractMatrix, p)
592592
end
593593

594594
# if A has nonpositive real eigenvalues, retmat is a nonprincipal matrix power.
595-
if isreal(retmat)
595+
if eltype(A) <: Real && isreal(retmat)
596596
return real(retmat)
597597
else
598598
return retmat
@@ -602,20 +602,19 @@ function (^)(A::AbstractMatrix{T}, p::Real) where T
602602
checksquare(A)
603603
# Quicker return if A is diagonal
604604
if isdiag(A)
605-
TT = promote_op(^, T, typeof(p))
606-
retmat = copymutable_oftype(A, TT)
607-
for i in diagind(retmat, IndexStyle(retmat))
608-
retmat[i] = retmat[i] ^ p
605+
if T <: Real && any(<(0), diagview(A))
606+
return applydiagonal(x -> complex(x)^p, A)
607+
else
608+
return applydiagonal(x -> x^p, A)
609609
end
610-
return retmat
611610
end
612611

613612
# For integer powers, use power_by_squaring
614613
isinteger(p) && return integerpow(A, p)
615614

616615
# If possible, use diagonalization
617616
if ishermitian(A)
618-
return (Hermitian(A)^p)
617+
return parent(Hermitian(A)^p)
619618
end
620619

621620
# Otherwise, use Schur decomposition
@@ -745,7 +744,7 @@ function exp!(A::StridedMatrix{T}) where T<:BlasFloat
745744
end
746745
return A
747746
elseif ishermitian(A)
748-
return copytri!(parent(exp(Hermitian(A))), 'U', true)
747+
return parent(exp(Hermitian(A)))
749748
end
750749
ilo, ihi, scale = LAPACK.gebal!('B', A) # modifies A
751750
nA = opnorm(A, 1)
@@ -918,10 +917,10 @@ julia> log(A)
918917
function log(A::AbstractMatrix)
919918
# If possible, use diagonalization
920919
if isdiag(A) && eltype(A) <: Union{Real,Complex}
921-
if eltype(A) <: Real && all(>=(0), diagview(A))
922-
return applydiagonal(log, A)
920+
if eltype(A) <: Real && any(<(0), diagview(A))
921+
return applydiagonal(log complex, A)
923922
else
924-
return applydiagonal(logcomplex, A)
923+
return applydiagonal(log, A)
925924
end
926925
elseif ishermitian(A)
927926
logHermA = log(Hermitian(A))
@@ -1004,13 +1003,14 @@ sqrt(::AbstractMatrix)
10041003
function sqrt(A::AbstractMatrix{T}) where {T<:Union{Real,Complex}}
10051004
if checksquare(A) == 0
10061005
return copy(float(A))
1007-
elseif isdiag(A) && (T <: Complex || all(x -> x zero(x), diagview(A)))
1008-
# Real Diagonal sqrt requires each diagonal element to be positive
1009-
return applydiagonal(sqrt, A)
1006+
elseif isdiag(A)
1007+
if T <: Real && any(<(0), diagview(A))
1008+
return applydiagonal(sqrt complex, A)
1009+
else
1010+
return applydiagonal(sqrt, A)
1011+
end
10101012
elseif ishermitian(A)
1011-
sqrtHermA = sqrt(Hermitian(A))
1012-
PS = parent(sqrtHermA)
1013-
return ishermitian(sqrtHermA) ? copytri_maybe_inplace(PS, 'U', true) : PS
1013+
return parent(sqrt(Hermitian(A)))
10141014
elseif istriu(A)
10151015
return triu!(parent(sqrt(UpperTriangular(A))))
10161016
elseif isreal(A)
@@ -1044,7 +1044,7 @@ sqrt(A::TransposeAbsMat) = transpose(sqrt(parent(A)))
10441044
Computes the real-valued cube root of a real-valued matrix `A`. If `T = cbrt(A)`, then
10451045
we have `T*T*T ≈ A`, see example given below.
10461046
1047-
If `A` is symmetric, i.e., of type `HermOrSym{<:Real}`, then ([`eigen`](@ref)) is used to
1047+
If `A` is real-symmetric or Hermitian, its eigendecomposition ([`eigen`](@ref)) is used to
10481048
find the cube root. Otherwise, a specialized version of the p-th root algorithm [^S03] is
10491049
utilized, which exploits the real-valued Schur decomposition ([`schur`](@ref))
10501050
to compute the cube root.
@@ -1077,7 +1077,7 @@ function cbrt(A::AbstractMatrix{<:Real})
10771077
elseif isdiag(A)
10781078
return applydiagonal(cbrt, A)
10791079
elseif issymmetric(A)
1080-
return cbrt(Symmetric(A, :U))
1080+
return copytri!(parent(cbrt(Symmetric(A))), 'U')
10811081
else
10821082
S = schur(A)
10831083
return S.Z * _cbrt_quasi_triu!(S.T) * S.Z'
@@ -1118,7 +1118,7 @@ end
11181118
11191119
Compute the matrix cosine of a square matrix `A`.
11201120
1121-
If `A` is symmetric or Hermitian, its eigendecomposition ([`eigen`](@ref)) is used to
1121+
If `A` is real-symmetric or Hermitian, its eigendecomposition ([`eigen`](@ref)) is used to
11221122
compute the cosine. Otherwise, the cosine is determined by calling [`exp`](@ref).
11231123
11241124
# Examples
@@ -1160,7 +1160,7 @@ end
11601160
11611161
Compute the matrix sine of a square matrix `A`.
11621162
1163-
If `A` is symmetric or Hermitian, its eigendecomposition ([`eigen`](@ref)) is used to
1163+
If `A` is real-symmetric or Hermitian, its eigendecomposition ([`eigen`](@ref)) is used to
11641164
compute the sine. Otherwise, the sine is determined by calling [`exp`](@ref).
11651165
11661166
# Examples
@@ -1265,7 +1265,7 @@ end
12651265
12661266
Compute the matrix tangent of a square matrix `A`.
12671267
1268-
If `A` is symmetric or Hermitian, its eigendecomposition ([`eigen`](@ref)) is used to
1268+
If `A` is real-symmetric or Hermitian, its eigendecomposition ([`eigen`](@ref)) is used to
12691269
compute the tangent. Otherwise, the tangent is determined by calling [`exp`](@ref).
12701270
12711271
# Examples
@@ -1357,7 +1357,7 @@ _subadd!!(X, Y) = X - Y, X + Y
13571357
13581358
Compute the inverse matrix cosine of a square matrix `A`.
13591359
1360-
If `A` is symmetric or Hermitian, its eigendecomposition ([`eigen`](@ref)) is used to
1360+
If `A` is real-symmetric or Hermitian, its eigendecomposition ([`eigen`](@ref)) is used to
13611361
compute the inverse cosine. Otherwise, the inverse cosine is determined by using
13621362
[`log`](@ref) and [`sqrt`](@ref). For the theory and logarithmic formulas used to compute
13631363
this function, see [^AH16_1].
@@ -1391,7 +1391,7 @@ end
13911391
13921392
Compute the inverse matrix sine of a square matrix `A`.
13931393
1394-
If `A` is symmetric or Hermitian, its eigendecomposition ([`eigen`](@ref)) is used to
1394+
If `A` is real-symmetric or Hermitian, its eigendecomposition ([`eigen`](@ref)) is used to
13951395
compute the inverse sine. Otherwise, the inverse sine is determined by using [`log`](@ref)
13961396
and [`sqrt`](@ref). For the theory and logarithmic formulas used to compute this function,
13971397
see [^AH16_2].
@@ -1425,7 +1425,7 @@ end
14251425
14261426
Compute the inverse matrix tangent of a square matrix `A`.
14271427
1428-
If `A` is symmetric or Hermitian, its eigendecomposition ([`eigen`](@ref)) is used to
1428+
If `A` is real-symmetric or Hermitian, its eigendecomposition ([`eigen`](@ref)) is used to
14291429
compute the inverse tangent. Otherwise, the inverse tangent is determined by using
14301430
[`log`](@ref). For the theory and logarithmic formulas used to compute this function, see
14311431
[^AH16_3].
@@ -1436,8 +1436,8 @@ compute the inverse tangent. Otherwise, the inverse tangent is determined by usi
14361436
```julia-repl
14371437
julia> atan(tan([0.5 0.1; -0.2 0.3]))
14381438
2×2 Matrix{ComplexF64}:
1439-
0.5+1.38778e-17im 0.1-2.77556e-17im
1440-
-0.2+6.93889e-17im 0.3-4.16334e-17im
1439+
0.5 0.1
1440+
-0.2 0.3
14411441
```
14421442
"""
14431443
function atan(A::AbstractMatrix)
@@ -1450,7 +1450,12 @@ function atan(A::AbstractMatrix)
14501450
SchurF = Schur{Complex}(schur(A))
14511451
U = im * UpperTriangular(SchurF.T)
14521452
R = triu!(parent(log((I + U) / (I - U)) / 2im))
1453-
return SchurF.Z * R * SchurF.Z'
1453+
retmat = SchurF.Z * R * SchurF.Z'
1454+
if eltype(A) <: Real
1455+
return real(retmat)
1456+
else
1457+
return retmat
1458+
end
14541459
end
14551460

14561461
"""
@@ -1493,7 +1498,12 @@ function asinh(A::AbstractMatrix)
14931498
SchurF = Schur{Complex}(schur(A))
14941499
U = UpperTriangular(SchurF.T)
14951500
R = triu!(parent(log(U + sqrt(I + U^2))))
1496-
return SchurF.Z * R * SchurF.Z'
1501+
retmat = SchurF.Z * R * SchurF.Z'
1502+
if eltype(A) <: Real
1503+
return real(retmat)
1504+
else
1505+
return retmat
1506+
end
14971507
end
14981508

14991509
"""

src/symmetric.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -867,7 +867,7 @@ function ^(A::SymSymTri{<:Complex}, p::Real)
867867
return Symmetric(schurpow(A, p))
868868
end
869869

870-
for func in (:cos, :sin, :tan, :cosh, :sinh, :tanh, :atan, :asinh, :atanh, :cbrt)
870+
for func in (:cos, :sin, :tan, :cosh, :sinh, :tanh, :atan, :asinh, :cbrt)
871871
@eval begin
872872
function ($func)(A::SelfAdjoint)
873873
F = eigen(A)
@@ -890,7 +890,7 @@ function cis(A::SelfAdjoint)
890890
return nonhermitianwrappertype(A)(retmat)
891891
end
892892

893-
for func in (:acos, :asin)
893+
for func in (:acos, :asin, :atanh)
894894
@eval begin
895895
function ($func)(A::SelfAdjoint)
896896
F = eigen(A)

src/triangular.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1996,7 +1996,7 @@ function powm!(A0::UpperTriangular, p::Real)
19961996
A[i, i] = -A[i, i]
19971997
end
19981998
# Compute the Padé approximant
1999-
c = 0.5 * (p - m) / (2 * m - 1)
1999+
c = (p - m) / (4 * m - 2)
20002000
triu!(A)
20012001
S = c * A
20022002
Stmp = similar(S)

test/dense.jl

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -639,8 +639,8 @@ end
639639
sinA1 = convert(Matrix{elty}, [0.2865568596627417 -1.107751980582015 -0.13772915374386513;
640640
-0.6227405671629401 0.2176922827908092 -0.5538759902910078;
641641
-0.6227405671629398 -0.6916051440348725 0.3554214365346742])
642-
@test @inferred(cos(A1)) cosA1
643-
@test @inferred(sin(A1)) sinA1
642+
@test cos(A1) cosA1
643+
@test sin(A1) sinA1
644644

645645
cosA2 = convert(Matrix{elty}, [-0.6331745163802187 0.12878366262380136 -0.17304181968301532;
646646
0.12878366262380136 -0.5596234510748788 0.5210483146041339;
@@ -663,22 +663,22 @@ end
663663

664664
# Identities
665665
for (i, A) in enumerate((A1, A2, A3, A4, A5))
666-
@test @inferred(sincos(A)) == (sin(A), cos(A))
666+
@test sincos(A) == (sin(A), cos(A))
667667
@test cos(A)^2 + sin(A)^2 Matrix(I, size(A))
668668
@test cos(A) cos(-A)
669669
@test sin(A) -sin(-A)
670-
@test @inferred(tan(A)) sin(A) / cos(A)
670+
@test tan(A) sin(A) / cos(A)
671671

672672
@test cos(A) real(exp(im*A))
673673
@test sin(A) imag(exp(im*A))
674674
@test cos(A) real(cis(A))
675675
@test sin(A) imag(cis(A))
676-
@test @inferred(cis(A)) cos(A) + im * sin(A)
676+
@test cis(A) cos(A) + im * sin(A)
677677

678-
@test @inferred(cosh(A)) 0.5 * (exp(A) + exp(-A))
679-
@test @inferred(sinh(A)) 0.5 * (exp(A) - exp(-A))
680-
@test @inferred(cosh(A)) cosh(-A)
681-
@test @inferred(sinh(A)) -sinh(-A)
678+
@test cosh(A) 0.5 * (exp(A) + exp(-A))
679+
@test sinh(A) 0.5 * (exp(A) - exp(-A))
680+
@test cosh(A) cosh(-A)
681+
@test sinh(A) -sinh(-A)
682682

683683
# Some of the following identities fail for A3, A4 because the matrices are singular
684684
if i in (1, 2, 5)
@@ -687,7 +687,7 @@ end
687687
@test @inferred(cot(A)) inv(tan(A))
688688
@test @inferred(sech(A)) inv(cosh(A))
689689
@test @inferred(csch(A)) inv(sinh(A))
690-
@test @inferred(coth(A)) inv(@inferred tanh(A))
690+
@test @inferred(coth(A)) inv(tanh(A))
691691
end
692692
# The following identities fail for A1, A2 due to rounding errors;
693693
# probably needs better algorithm for the general case
@@ -904,11 +904,6 @@ end
904904
end
905905
end
906906

907-
@testset "matrix logarithm is type-inferable" for elty in (Float32,Float64,ComplexF32,ComplexF64)
908-
A1 = randn(elty, 4, 4)
909-
@inferred Union{Matrix{elty},Matrix{complex(elty)}} log(A1)
910-
end
911-
912907
@testset "Additional matrix square root tests" for elty in (Float64, ComplexF64)
913908
A11 = convert(Matrix{elty}, [3 2; -5 -3])
914909
@test sqrt(A11)^2 A11

test/symmetric.jl

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1199,13 +1199,64 @@ end
11991199
end
12001200
end
12011201

1202-
@testset "asin/acos/acosh for matrix outside the real domain" begin
1202+
@testset "asin/acos/acosh/tanh for matrix outside the real domain" begin
12031203
M = [0 2;2 0] #eigenvalues are ±2
12041204
for T (Float32, Float64, ComplexF32, ComplexF64)
12051205
M2 = Hermitian(T.(M))
12061206
@test sin(asin(M2)) M2
12071207
@test cos(acos(M2)) M2
12081208
@test cosh(acosh(M2)) M2
1209+
@test tanh(atanh(M2)) M2
1210+
end
1211+
end
1212+
1213+
@testset "type inference of matrix functions" begin
1214+
for T (Float32, Float64, ComplexF32, ComplexF64)
1215+
a = randn(T, 2, 2)
1216+
syma = Symmetric(a)
1217+
symtria = SymTridiagonal(syma)
1218+
herma = Hermitian(a)
1219+
#nasty functions
1220+
for f in (x->x^real(T)(0.3), sqrt, log, asin, acos, acosh, atanh)
1221+
if T <: Real
1222+
@test @inferred Matrix{Complex{T}} f(a) isa Matrix
1223+
@test @inferred Symmetric{Complex{T}} f(syma) isa Symmetric
1224+
@test @inferred Symmetric{Complex{T}} f(symtria) isa Symmetric
1225+
@test @inferred Symmetric{Complex{T}} f(herma) isa Union{Symmetric{Complex{T}}, Hermitian{T}}
1226+
else
1227+
@test @inferred f(a) isa Matrix{T}
1228+
@test @inferred Matrix{T} f(herma) isa Union{Matrix{T}, Hermitian{T}}
1229+
end
1230+
end
1231+
#nice functions
1232+
for f in (x->x^2, exp, cos, sin, tan, cosh, sinh, tanh, atan, asinh, cbrt)
1233+
if T <: Real
1234+
@test @inferred f(a) isa Matrix{T}
1235+
@test @inferred f(syma) isa Symmetric{T}
1236+
@test @inferred f(symtria) isa Symmetric{T}
1237+
@test @inferred f(herma) isa Hermitian{T}
1238+
else
1239+
f != cbrt && @test @inferred f(a) isa Matrix{T}
1240+
@test @inferred f(herma) isa Hermitian{T}
1241+
end
1242+
end
1243+
#special case cis
1244+
if T <: Real
1245+
@test @inferred cis(a) isa Matrix{Complex{T}}
1246+
@test @inferred cis(syma) isa Symmetric{Complex{T}}
1247+
@test @inferred cis(symtria) isa Symmetric{Complex{T}}
1248+
@test @inferred cis(herma) isa Symmetric{Complex{T}}
1249+
else
1250+
@test @inferred cis(a) isa Matrix{T}
1251+
@test @inferred cis(herma) isa Matrix{T}
1252+
end
1253+
#special case sincos
1254+
if T <: Real
1255+
@test @inferred sincos(syma) isa Tuple{Symmetric{T}, Symmetric{T}}
1256+
@test @inferred sincos(symtria) isa Tuple{Symmetric{T}, Symmetric{T}}
1257+
end
1258+
@test @inferred sincos(a) isa Tuple{Matrix{T}, Matrix{T}}
1259+
@test @inferred sincos(herma) isa Tuple{Hermitian{T}, Hermitian{T}}
12091260
end
12101261
end
12111262

0 commit comments

Comments
 (0)