diff --git a/src/symmetric.jl b/src/symmetric.jl index e07975ad..9187caf1 100644 --- a/src/symmetric.jl +++ b/src/symmetric.jl @@ -833,6 +833,12 @@ function svdvals!(A::RealHermSymComplexHerm) return sort!(vals, rev = true) end +#computes U * Diagonal(abs2.(v)) * U' +function _psd_spectral_product(v, U) + Uv = U * Diagonal(v) + return Uv * Uv' # often faster than generic matmul by calling BLAS.herk +end + # Matrix functions ^(A::SymSymTri{<:Complex}, p::Integer) = sympow(A, p) ^(A::SelfAdjoint, p::Integer) = sympow(A, p) @@ -848,7 +854,8 @@ function ^(A::SelfAdjoint, p::Real) isinteger(p) && return integerpow(A, p) F = eigen(A) if all(λ -> λ ≥ 0, F.values) - retmat = (F.vectors * Diagonal((F.values).^p)) * F.vectors' + rootpower = map(λ -> λ^(p / 2), F.values) + retmat = _psd_spectral_product(rootpower, F.vectors) return wrappertype(A)(retmat) else retmat = (F.vectors * Diagonal(complex.(F.values).^p)) * F.vectors' @@ -860,7 +867,7 @@ function ^(A::SymSymTri{<:Complex}, p::Real) return Symmetric(schurpow(A, p)) end -for func in (:exp, :cos, :sin, :tan, :cosh, :sinh, :tanh, :atan, :asinh, :atanh, :cbrt) +for func in (:cos, :sin, :tan, :cosh, :sinh, :tanh, :atan, :asinh, :atanh, :cbrt) @eval begin function ($func)(A::SelfAdjoint) F = eigen(A) @@ -870,6 +877,13 @@ for func in (:exp, :cos, :sin, :tan, :cosh, :sinh, :tanh, :atan, :asinh, :atanh, end end +function exp(A::SelfAdjoint) + F = eigen(A) + rootexp = map(λ -> exp(λ / 2), F.values) + retmat = _psd_spectral_product(rootexp, F.vectors) + return wrappertype(A)(retmat) +end + function cis(A::SelfAdjoint) F = eigen(A) retmat = F.vectors .* cis.(F.values') * F.vectors' @@ -929,7 +943,8 @@ function sqrt(A::SelfAdjoint; rtol = eps(real(float(eltype(A)))) * size(A, 1)) F = eigen(A) λ₀ = -maximum(abs, F.values) * rtol # treat λ ≥ λ₀ as "zero" eigenvalues up to roundoff if all(λ -> λ ≥ λ₀, F.values) - retmat = (F.vectors * Diagonal(sqrt.(max.(0, F.values)))) * F.vectors' + rootroot = map(λ -> λ < 0 ? zero(λ) : fourthroot(λ), F.values) + retmat = _psd_spectral_product(rootroot, F.vectors) return wrappertype(A)(retmat) else retmat = (F.vectors * Diagonal(sqrt.(complex.(F.values)))) * F.vectors'