Skip to content

Commit b6a2cc1

Browse files
jishnubdkarrasch
andauthored
Check isdiag in dense trig functions (JuliaLang#56483)
This improves performance for dense diagonal matrices, as we may apply the function only to the diagonal elements. ```julia julia> A = diagm(0=>rand(100)); julia> @Btime cos($A); 349.211 μs (22 allocations: 401.58 KiB) # nightly v"1.12.0-DEV.1571" 16.215 μs (7 allocations: 80.02 KiB) # this PR ``` --------- Co-authored-by: Daniel Karrasch <[email protected]>
1 parent 88201cf commit b6a2cc1

File tree

2 files changed

+75
-34
lines changed

2 files changed

+75
-34
lines changed

stdlib/LinearAlgebra/src/dense.jl

Lines changed: 56 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -683,7 +683,12 @@ Base.:^(::Irrational{:ℯ}, A::AbstractMatrix) = exp(A)
683683
## "Functions of Matrices: Theory and Computation", SIAM
684684
function exp!(A::StridedMatrix{T}) where T<:BlasFloat
685685
n = checksquare(A)
686-
if ishermitian(A)
686+
if isdiag(A)
687+
for i in diagind(A, IndexStyle(A))
688+
A[i] = exp(A[i])
689+
end
690+
return A
691+
elseif ishermitian(A)
687692
return copytri!(parent(exp(Hermitian(A))), 'U', true)
688693
end
689694
ilo, ihi, scale = LAPACK.gebal!('B', A) # modifies A
@@ -1014,9 +1019,16 @@ end
10141019
cbrt(A::AdjointAbsMat) = adjoint(cbrt(parent(A)))
10151020
cbrt(A::TransposeAbsMat) = transpose(cbrt(parent(A)))
10161021

1022+
function applydiagonal(f, A)
1023+
dinv = f(Diagonal(A))
1024+
copyto!(similar(A, eltype(dinv)), dinv)
1025+
end
1026+
10171027
function inv(A::StridedMatrix{T}) where T
10181028
checksquare(A)
1019-
if istriu(A)
1029+
if isdiag(A)
1030+
Ai = applydiagonal(inv, A)
1031+
elseif istriu(A)
10201032
Ai = triu!(parent(inv(UpperTriangular(A))))
10211033
elseif istril(A)
10221034
Ai = tril!(parent(inv(LowerTriangular(A))))
@@ -1044,14 +1056,18 @@ julia> cos(fill(1.0, (2,2)))
10441056
```
10451057
"""
10461058
function cos(A::AbstractMatrix{<:Real})
1047-
if issymmetric(A)
1059+
if isdiag(A)
1060+
return applydiagonal(cos, A)
1061+
elseif issymmetric(A)
10481062
return copytri!(parent(cos(Symmetric(A))), 'U')
10491063
end
10501064
T = complex(float(eltype(A)))
10511065
return real(exp!(T.(im .* A)))
10521066
end
10531067
function cos(A::AbstractMatrix{<:Complex})
1054-
if ishermitian(A)
1068+
if isdiag(A)
1069+
return applydiagonal(cos, A)
1070+
elseif ishermitian(A)
10551071
return copytri!(parent(cos(Hermitian(A))), 'U', true)
10561072
end
10571073
T = complex(float(eltype(A)))
@@ -1077,14 +1093,18 @@ julia> sin(fill(1.0, (2,2)))
10771093
```
10781094
"""
10791095
function sin(A::AbstractMatrix{<:Real})
1080-
if issymmetric(A)
1096+
if isdiag(A)
1097+
return applydiagonal(sin, A)
1098+
elseif issymmetric(A)
10811099
return copytri!(parent(sin(Symmetric(A))), 'U')
10821100
end
10831101
T = complex(float(eltype(A)))
10841102
return imag(exp!(T.(im .* A)))
10851103
end
10861104
function sin(A::AbstractMatrix{<:Complex})
1087-
if ishermitian(A)
1105+
if isdiag(A)
1106+
return applydiagonal(sin, A)
1107+
elseif ishermitian(A)
10881108
return copytri!(parent(sin(Hermitian(A))), 'U', true)
10891109
end
10901110
T = complex(float(eltype(A)))
@@ -1163,7 +1183,9 @@ julia> tan(fill(1.0, (2,2)))
11631183
```
11641184
"""
11651185
function tan(A::AbstractMatrix)
1166-
if ishermitian(A)
1186+
if isdiag(A)
1187+
return applydiagonal(tan, A)
1188+
elseif ishermitian(A)
11671189
return copytri!(parent(tan(Hermitian(A))), 'U', true)
11681190
end
11691191
S, C = sincos(A)
@@ -1177,7 +1199,9 @@ end
11771199
Compute the matrix hyperbolic cosine of a square matrix `A`.
11781200
"""
11791201
function cosh(A::AbstractMatrix)
1180-
if ishermitian(A)
1202+
if isdiag(A)
1203+
return applydiagonal(cosh, A)
1204+
elseif ishermitian(A)
11811205
return copytri!(parent(cosh(Hermitian(A))), 'U', true)
11821206
end
11831207
X = exp(A)
@@ -1191,7 +1215,9 @@ end
11911215
Compute the matrix hyperbolic sine of a square matrix `A`.
11921216
"""
11931217
function sinh(A::AbstractMatrix)
1194-
if ishermitian(A)
1218+
if isdiag(A)
1219+
return applydiagonal(sinh, A)
1220+
elseif ishermitian(A)
11951221
return copytri!(parent(sinh(Hermitian(A))), 'U', true)
11961222
end
11971223
X = exp(A)
@@ -1205,7 +1231,9 @@ end
12051231
Compute the matrix hyperbolic tangent of a square matrix `A`.
12061232
"""
12071233
function tanh(A::AbstractMatrix)
1208-
if ishermitian(A)
1234+
if isdiag(A)
1235+
return applydiagonal(tanh, A)
1236+
elseif ishermitian(A)
12091237
return copytri!(parent(tanh(Hermitian(A))), 'U', true)
12101238
end
12111239
X = exp(A)
@@ -1240,7 +1268,9 @@ julia> acos(cos([0.5 0.1; -0.2 0.3]))
12401268
```
12411269
"""
12421270
function acos(A::AbstractMatrix)
1243-
if ishermitian(A)
1271+
if isdiag(A)
1272+
return applydiagonal(acos, A)
1273+
elseif ishermitian(A)
12441274
acosHermA = acos(Hermitian(A))
12451275
return isa(acosHermA, Hermitian) ? copytri!(parent(acosHermA), 'U', true) : parent(acosHermA)
12461276
end
@@ -1271,7 +1301,9 @@ julia> asin(sin([0.5 0.1; -0.2 0.3]))
12711301
```
12721302
"""
12731303
function asin(A::AbstractMatrix)
1274-
if ishermitian(A)
1304+
if isdiag(A)
1305+
return applydiagonal(asin, A)
1306+
elseif ishermitian(A)
12751307
asinHermA = asin(Hermitian(A))
12761308
return isa(asinHermA, Hermitian) ? copytri!(parent(asinHermA), 'U', true) : parent(asinHermA)
12771309
end
@@ -1302,7 +1334,9 @@ julia> atan(tan([0.5 0.1; -0.2 0.3]))
13021334
```
13031335
"""
13041336
function atan(A::AbstractMatrix)
1305-
if ishermitian(A)
1337+
if isdiag(A)
1338+
return applydiagonal(atan, A)
1339+
elseif ishermitian(A)
13061340
return copytri!(parent(atan(Hermitian(A))), 'U', true)
13071341
end
13081342
SchurF = Schur{Complex}(schur(A))
@@ -1320,7 +1354,9 @@ logarithmic formulas used to compute this function, see [^AH16_4].
13201354
[^AH16_4]: Mary Aprahamian and Nicholas J. Higham, "Matrix Inverse Trigonometric and Inverse Hyperbolic Functions: Theory and Algorithms", MIMS EPrint: 2016.4. [https://doi.org/10.1137/16M1057577](https://doi.org/10.1137/16M1057577)
13211355
"""
13221356
function acosh(A::AbstractMatrix)
1323-
if ishermitian(A)
1357+
if isdiag(A)
1358+
return applydiagonal(acosh, A)
1359+
elseif ishermitian(A)
13241360
acoshHermA = acosh(Hermitian(A))
13251361
return isa(acoshHermA, Hermitian) ? copytri!(parent(acoshHermA), 'U', true) : parent(acoshHermA)
13261362
end
@@ -1339,7 +1375,9 @@ logarithmic formulas used to compute this function, see [^AH16_5].
13391375
[^AH16_5]: Mary Aprahamian and Nicholas J. Higham, "Matrix Inverse Trigonometric and Inverse Hyperbolic Functions: Theory and Algorithms", MIMS EPrint: 2016.4. [https://doi.org/10.1137/16M1057577](https://doi.org/10.1137/16M1057577)
13401376
"""
13411377
function asinh(A::AbstractMatrix)
1342-
if ishermitian(A)
1378+
if isdiag(A)
1379+
return applydiagonal(asinh, A)
1380+
elseif ishermitian(A)
13431381
return copytri!(parent(asinh(Hermitian(A))), 'U', true)
13441382
end
13451383
SchurF = Schur{Complex}(schur(A))
@@ -1357,7 +1395,9 @@ logarithmic formulas used to compute this function, see [^AH16_6].
13571395
[^AH16_6]: Mary Aprahamian and Nicholas J. Higham, "Matrix Inverse Trigonometric and Inverse Hyperbolic Functions: Theory and Algorithms", MIMS EPrint: 2016.4. [https://doi.org/10.1137/16M1057577](https://doi.org/10.1137/16M1057577)
13581396
"""
13591397
function atanh(A::AbstractMatrix)
1360-
if ishermitian(A)
1398+
if isdiag(A)
1399+
return applydiagonal(atanh, A)
1400+
elseif ishermitian(A)
13611401
return copytri!(parent(atanh(Hermitian(A))), 'U', true)
13621402
end
13631403
SchurF = Schur{Complex}(schur(A))

stdlib/LinearAlgebra/test/dense.jl

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -607,15 +607,16 @@ end
607607
-0.4579038628067864 1.7361475641080275 6.478801851038108])
608608
A3 = convert(Matrix{elty}, [0.25 0.25; 0 0])
609609
A4 = convert(Matrix{elty}, [0 0.02; 0 0])
610+
A5 = convert(Matrix{elty}, [2.0 0; 0 3.0])
610611

611612
cosA1 = convert(Matrix{elty},[-0.18287716254368605 -0.29517205254584633 0.761711400552759;
612613
0.23326967400345625 0.19797853773269333 -0.14758602627292305;
613614
0.23326967400345636 0.6141253742798355 -0.5637328628200653])
614615
sinA1 = convert(Matrix{elty}, [0.2865568596627417 -1.107751980582015 -0.13772915374386513;
615616
-0.6227405671629401 0.2176922827908092 -0.5538759902910078;
616617
-0.6227405671629398 -0.6916051440348725 0.3554214365346742])
617-
@test cos(A1) cosA1
618-
@test sin(A1) sinA1
618+
@test @inferred(cos(A1)) cosA1
619+
@test @inferred(sin(A1)) sinA1
619620

620621
cosA2 = convert(Matrix{elty}, [-0.6331745163802187 0.12878366262380136 -0.17304181968301532;
621622
0.12878366262380136 -0.5596234510748788 0.5210483146041339;
@@ -637,36 +638,36 @@ end
637638
@test sin(A4) sinA4
638639

639640
# Identities
640-
for (i, A) in enumerate((A1, A2, A3, A4))
641-
@test sincos(A) == (sin(A), cos(A))
641+
for (i, A) in enumerate((A1, A2, A3, A4, A5))
642+
@test @inferred(sincos(A)) == (sin(A), cos(A))
642643
@test cos(A)^2 + sin(A)^2 Matrix(I, size(A))
643644
@test cos(A) cos(-A)
644645
@test sin(A) -sin(-A)
645-
@test tan(A) sin(A) / cos(A)
646+
@test @inferred(tan(A)) sin(A) / cos(A)
646647

647648
@test cos(A) real(exp(im*A))
648649
@test sin(A) imag(exp(im*A))
649650
@test cos(A) real(cis(A))
650651
@test sin(A) imag(cis(A))
651-
@test cis(A) cos(A) + im * sin(A)
652+
@test @inferred(cis(A)) cos(A) + im * sin(A)
652653

653-
@test cosh(A) 0.5 * (exp(A) + exp(-A))
654-
@test sinh(A) 0.5 * (exp(A) - exp(-A))
655-
@test cosh(A) cosh(-A)
656-
@test sinh(A) -sinh(-A)
654+
@test @inferred(cosh(A)) 0.5 * (exp(A) + exp(-A))
655+
@test @inferred(sinh(A)) 0.5 * (exp(A) - exp(-A))
656+
@test @inferred(cosh(A)) cosh(-A)
657+
@test @inferred(sinh(A)) -sinh(-A)
657658

658659
# Some of the following identities fail for A3, A4 because the matrices are singular
659-
if i in (1, 2)
660-
@test sec(A) inv(cos(A))
661-
@test csc(A) inv(sin(A))
662-
@test cot(A) inv(tan(A))
663-
@test sech(A) inv(cosh(A))
664-
@test csch(A) inv(sinh(A))
665-
@test coth(A) inv(tanh(A))
660+
if i in (1, 2, 5)
661+
@test @inferred(sec(A)) inv(cos(A))
662+
@test @inferred(csc(A)) inv(sin(A))
663+
@test @inferred(cot(A)) inv(tan(A))
664+
@test @inferred(sech(A)) inv(cosh(A))
665+
@test @inferred(csch(A)) inv(sinh(A))
666+
@test @inferred(coth(A)) inv(@inferred tanh(A))
666667
end
667668
# The following identities fail for A1, A2 due to rounding errors;
668669
# probably needs better algorithm for the general case
669-
if i in (3, 4)
670+
if i in (3, 4, 5)
670671
@test cosh(A)^2 - sinh(A)^2 Matrix(I, size(A))
671672
@test tanh(A) sinh(A) / cosh(A)
672673
end

0 commit comments

Comments
 (0)