Skip to content

Commit 5095cdb

Browse files
committed
comments
change some input tests remove redundant comment include ComplexF16 in tests fix unchanged test names and docs improve allocations
1 parent d9fb748 commit 5095cdb

File tree

4 files changed

+29
-15
lines changed

4 files changed

+29
-15
lines changed

src/implementations/exponential.jl

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,20 @@ copy_input(::typeof(exponential), A::Diagonal) = copy(A)
88

99
function check_input(::typeof(exponential!), A::AbstractMatrix, expA::AbstractMatrix, alg::AbstractAlgorithm)
1010
m, n = size(A)
11-
m == n || throw(DimensionMismatch("square input matrix expected"))
12-
@assert expA isa AbstractMatrix
11+
m == n || throw(DimensionMismatch("square input matrix expected. Got ($m,$n)"))
1312
@check_size(expA, (m, m))
1413
return @check_scalar(expA, A)
1514
end
1615

16+
function check_input(::typeof(exponential!), A::AbstractMatrix, expA::AbstractMatrix, ::DiagonalAlgorithm)
17+
m, n = size(A)
18+
@assert m == n && isdiag(A)
19+
@assert expA isa Diagonal
20+
@check_size(expA, (m, m))
21+
@check_scalar(expA, A)
22+
return nothing
23+
end
24+
1725
# Outputs
1826
# -------
1927
function initialize_output(::typeof(exponential!), A::AbstractMatrix, ::AbstractAlgorithm)
@@ -22,6 +30,10 @@ function initialize_output(::typeof(exponential!), A::AbstractMatrix, ::Abstract
2230
return expA
2331
end
2432

33+
function initialize_output(::typeof(exponential!), A::Diagonal, ::DiagonalAlgorithm)
34+
return similar(A)
35+
end
36+
2537
# Implementation
2638
# --------------
2739
function exponential!(A::AbstractMatrix, expA::AbstractMatrix, alg::ExponentialViaLA)
@@ -31,20 +43,24 @@ end
3143

3244
function exponential!(A::AbstractMatrix, expA::AbstractMatrix, alg::ExponentialViaEigh)
3345
D, V = eigh_full(A, alg.eigh_alg)
34-
copyto!(expA, V * Diagonal(exp.(diagview(D))) * inv(V))
46+
iV = inv(V)
47+
map!(exp, diagview(D), diagview(D))
48+
mul!(expA, rmul!(V, D), iV)
3549
return expA
3650
end
3751

3852
function exponential!(A::AbstractMatrix, expA::AbstractMatrix, alg::ExponentialViaEig)
3953
D, V = eig_full(A, alg.eig_alg)
40-
copyto!(expA, V * Diagonal(exp.(diagview(D))) * inv(V))
54+
iV = inv(V)
55+
map!(exp, diagview(D), diagview(D))
56+
mul!(expA, rmul!(V, D), iV)
4157
return expA
4258
end
4359

4460
# Diagonal logic
4561
# --------------
4662
function exponential!(A::Diagonal, expA, alg::DiagonalAlgorithm)
4763
check_input(exponential!, A, expA, alg)
48-
copyto!(expA, Diagonal(LinearAlgebra.exp.(diagview(A))))
64+
map!(exp, diagview(expA), diagview(A))
4965
return expA
5066
end

src/interface/decompositions.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -325,16 +325,15 @@ const ROCSOLVER_SVDAlgorithm = Union{ROCSOLVER_QRIteration, ROCSOLVER_Jacobi}
325325
"""
326326
ExponentialViaLA()
327327
328-
Algorithm type to denote finding the LQ decomposition of `A` by computing the QR decomposition of `Aᵀ`.
329-
The `qr_alg` specifies which QR-decomposition implementation to use.
328+
Algorithm type to denote finding the exponential of `A` via the implementation of `LinearAlgebra`.
330329
"""
331330
@algdef ExponentialViaLA
332331

333332
"""
334333
ExponentialViaEigh()
335334
336-
Algorithm type to denote finding the LQ decomposition of `A` by computing the QR decomposition of `Aᵀ`.
337-
The `qr_alg` specifies which QR-decomposition implementation to use.
335+
Algorithm type to denote finding the exponential `A` by computing the hermitian eigendecomposition of `A`.
336+
The `eigh_alg` specifies which hermitian eigendecomposition implementation to use.
338337
"""
339338
struct ExponentialViaEigh{A <: AbstractAlgorithm} <: AbstractAlgorithm
340339
eigh_alg::A
@@ -348,8 +347,8 @@ end
348347
"""
349348
ExponentialViaEig()
350349
351-
Algorithm type to denote finding the LQ decomposition of `A` by computing the QR decomposition of `Aᵀ`.
352-
The `qr_alg` specifies which QR-decomposition implementation to use.
350+
Algorithm type to denote finding the exponential `A` by computing the eigendecomposition of `A`.
351+
The `eig_alg` specifies which eigendecomposition implementation to use.
353352
"""
354353
struct ExponentialViaEig{A <: AbstractAlgorithm} <: AbstractAlgorithm
355354
eig_alg::A

src/interface/exponential.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# Exponential functions
22
# --------------
33
@functiondef exponential
4-
# @algdef exponential!
54

65
# Algorithm selection
76
# -------------------

test/exponential.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@ using MatrixAlgebraKit: diagview
66
using LinearAlgebra
77

88
BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
9-
GenericFloats = (Float16, BigFloat, Complex{BigFloat})
9+
GenericFloats = (Float16, ComplexF16, BigFloat, Complex{BigFloat})
1010

11-
@testset "exp! for T = $T" for T in BLASFloats
11+
@testset "exponential! for T = $T" for T in BLASFloats
1212
rng = StableRNG(123)
1313
m = 2
1414

@@ -30,7 +30,7 @@ GenericFloats = (Float16, BigFloat, Complex{BigFloat})
3030
end
3131
end
3232
33-
@testset "svd for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...)
33+
@testset "exponential! for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...)
3434
rng = StableRNG(123)
3535
atol = sqrt(eps(real(T)))
3636
m = 54

0 commit comments

Comments
 (0)