Skip to content

Commit c6eacfd

Browse files
committed
evalpoly for matrix polynomials
1 parent 9bc292d commit c6eacfd

File tree

2 files changed

+104
-0
lines changed

2 files changed

+104
-0
lines changed

src/generic.jl

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2091,3 +2091,83 @@ end
20912091
function copytrito!(B::StridedMatrixStride1{T}, A::StridedMatrixStride1{T}, uplo::AbstractChar) where {T<:BlasFloat}
20922092
LAPACK.lacpy!(B, A, uplo)
20932093
end
2094+
2095+
# non-inplace fallback for evalpoly(X, p)
2096+
function _evalpoly(X::AbstractMatrix, p)
2097+
Base.require_one_based_indexing(p)
2098+
p0 = isempty(p) ? Base.reduce_empty_iter(+, p) : p[end]
2099+
Xone = one(X)
2100+
S = Base.promote_op(*, typeof(Xone), typeof(Xone))(Xone) * p0
2101+
for i = length(p)-1:-1:1
2102+
S = X * S + @inbounds(p[i] isa AbstractMatrix ? p[i] : p[i] * I)
2103+
end
2104+
return S
2105+
end
2106+
2107+
_scalarval(x::Number) = x
2108+
_scalarval(x::UniformScaling) = x.λ
2109+
2110+
"""
2111+
evalpoly!(Y::AbstractMatrix, X::AbstractMatrix, p)
2112+
2113+
Evaluate the matrix polynomial ``Y = \\sum_k X^{k-1} p[k]``, storing the result
2114+
in-place in `Y`, for the coefficients `p[k]` (a vector or tuple). The coefficients
2115+
can be scalars, matrices, or [`UniformScaling`](@ref).
2116+
2117+
Similar to `evalpoly`, but may be more efficient by working more in-place. (Some
2118+
allocations may still be required, however.)
2119+
"""
2120+
function evalpoly!(Y::AbstractMatrix, X::AbstractMatrix, p::Union{AbstractVector,Tuple})
2121+
@boundscheck axes(Y,1) == axes(Y,2) == axes(X,1) == axes(X,2)
2122+
Base.require_one_based_indexing(p)
2123+
2124+
N = length(p)
2125+
pN = iszero(N) ? Base.reduce_empty_iter(+, p) : p[N]
2126+
if pN isa AbstractMatrix
2127+
Y .= pN
2128+
elseif N > 1 && p[N-1] isa Union{Number,UniformScaling}
2129+
# initialize Y to p[N-1] I + X p[N], in-place
2130+
Y .= X .* _scalarval(pN)
2131+
for i in axes(Y,1)
2132+
@inbounds Y[i,i] += p[N-1] * I
2133+
end
2134+
N -= 1
2135+
else
2136+
# initialize Y to one(Y) * pN in-place
2137+
for i in axes(Y,1)
2138+
for j in axes(Y,2)
2139+
@inbounds Y[i,j] = zero(Y[i,j])
2140+
end
2141+
@inbounds Y[i,i] += one(Y[i,i]) * pN
2142+
end
2143+
end
2144+
if N > 1
2145+
Z = similar(Y) # workspace for mul!
2146+
for i = N-1:-1:1
2147+
mul!(Z, X, Y)
2148+
if p[i] isa AbstractMatrix
2149+
Y .= p[i] .+ Z
2150+
else
2151+
# Y = p[i] * I + Z, in-place
2152+
Y .= Z
2153+
for j in axes(Y,1)
2154+
@inbounds Y[j,j] += p[i] * I
2155+
end
2156+
end
2157+
end
2158+
end
2159+
return Y
2160+
end
2161+
2162+
Base.evalpoly(X::AbstractMatrix, p::Tuple) = _evalpoly(X, p)
2163+
Base.evalpoly(X::AbstractMatrix, ::Tuple{}) = zero(one(X)) # dimensionless zero, i.e. 0 * x^0
2164+
Base.evalpoly(X::AbstractMatrix, p::AbstractVector) = _evalpoly(X, p)
2165+
2166+
Base.evalpoly(X::StridedMatrix{<:Number}, p::Tuple{Union{Number, UniformScaling}, Vararg{Union{Number, UniformScaling}}}) =
2167+
evalpoly!(similar(X, Base.promote_op(*, typeof(one(eltype(X))), typeof(_scalarval(p[begin])))), X, p)
2168+
Base.evalpoly(X::StridedMatrix{<:Number}, p::Tuple{AbstractMatrix{<:Number}, Vararg{AbstractMatrix{<:Number}}}) =
2169+
evalpoly!(similar(X, Base.promote_op(*, typeof(one(eltype(X))), eltype(p[begin]))), X, p)
2170+
Base.evalpoly(X::StridedMatrix{<:Number}, p::AbstractVector{<:Union{Number, UniformScaling}}) =
2171+
isempty(p) ? _evalpoly(X, p) : evalpoly!(similar(X, Base.promote_op(*, typeof(one(eltype(X))), typeof(_scalarval(p[begin])))), X, p)
2172+
Base.evalpoly(X::StridedMatrix{<:Number}, p::AbstractVector{<:AbstractMatrix{<:Number}}) =
2173+
isempty(p) ? _evalpoly(X, p) : evalpoly!(similar(X, Base.promote_op(*, typeof(one(eltype(X))), eltype(p[begin]))), X, p)

test/generic.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -837,4 +837,28 @@ end
837837
end
838838
end
839839

840+
using LinearAlgebra: _evalpoly
841+
naive_evalpoly(X, p) = sum(X^(i-1) * p[i] for i=1:length(p))
842+
843+
@testset "evalpoly" begin
844+
for X in ([1 2 3;4 5 6;7 8 9], UpperTriangular([1 2 3;0 5 6;0 0 9]), SymTridiagonal([1,2,3],[4,5]))
845+
@test @inferred(evalpoly(X, ())) == zero(X) == evalpoly(X, Int[])
846+
@test @inferred(evalpoly(X, (17,))) == one(X) * 17
847+
@test @inferred(_evalpoly(X, [1,2,3,4])) == @inferred(evalpoly(X, [1,2,3,4])) ==
848+
@inferred(evalpoly(X, (1,2,3,4))) ==
849+
naive_evalpoly(X, [1,2,3,4]) == 1*one(X) + 2*X + 3X^2 + 4X^3
850+
@test typeof(evalpoly(X, [1,2,3])) == typeof(evalpoly(X, (1,2,3))) == typeof(_evalpoly(X, [1,2,3])) ==
851+
typeof(X * X)
852+
853+
for N in (1,2,4), p in (rand(-10:10, N), UniformScaling.(rand(-10:10, N)), [rand(-5:5,3,3) for _ = 1:N])
854+
@test _evalpoly(X, p) == evalpoly(X, p) == evalpoly(X, Tuple(p)) == naive_evalpoly(X, p)
855+
end
856+
for N in (1,2,4), p in (rand(N), UniformScaling.(rand(N)), [rand(3,3) for _ = 1:N])
857+
@test _evalpoly(X, p) evalpoly(X, p) evalpoly(X, Tuple(p)) naive_evalpoly(X, p)
858+
end
859+
860+
@test_throws MethodError evalpoly(X, [])
861+
end
862+
end
863+
840864
end # module TestGeneric

0 commit comments

Comments
 (0)