Skip to content

Commit b43e5e6

Browse files
committed
fix heterogeneous and abstract cases
1 parent dbe1787 commit b43e5e6

File tree

2 files changed

+26
-13
lines changed

2 files changed

+26
-13
lines changed

src/evalpoly.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,15 +67,19 @@ function evalpoly!(Y::AbstractMatrix, X::AbstractMatrix, p::Union{AbstractVector
6767
return Y
6868
end
6969

70+
# fallback cases: call out-of-place _evalpoly
7071
Base.evalpoly(X::AbstractMatrix, p::Tuple) = _evalpoly(X, p)
7172
Base.evalpoly(X::AbstractMatrix, ::Tuple{}) = zero(one(X)) # dimensionless zero, i.e. 0 * X^0
7273
Base.evalpoly(X::AbstractMatrix, p::AbstractVector) = _evalpoly(X, p)
7374

74-
Base.evalpoly(X::StridedMatrix{<:Number}, p::Tuple{Union{Number, UniformScaling}, Vararg{Union{Number, UniformScaling}}}) =
75-
evalpoly!(similar(X, Base.promote_op(*, typeof(one(eltype(X))), typeof(_scalarval(p[begin])))), X, p)
76-
Base.evalpoly(X::StridedMatrix{<:Number}, p::Tuple{AbstractMatrix{<:Number}, Vararg{AbstractMatrix{<:Number}}}) =
77-
evalpoly!(similar(X, Base.promote_op(*, typeof(one(eltype(X))), eltype(p[begin]))), X, p)
75+
# optimized in-place cases, limited to types like homogeneous tuples with length > 1
76+
# where we can reliably deduce the output type (= type of X * p[2]),
77+
# and restricted to StridedMatrix (for now) so that we can be more confident that this is a performance win:
78+
Base.evalpoly(X::StridedMatrix{<:Number}, p::Tuple{T, T, Vararg{T}}) where {T<:Union{Number, UniformScaling}} =
79+
evalpoly!(similar(X, Base.promote_op(*, eltype(X), typeof(_scalarval(p[2])))), X, p)
80+
Base.evalpoly(X::StridedMatrix{<:Number}, p::Tuple{AbstractMatrix{T}, AbstractMatrix{T}, Vararg{AbstractMatrix{T}}}) where {T<:Number} =
81+
evalpoly!(similar(X, Base.promote_op(*, eltype(X), T)), X, p)
7882
Base.evalpoly(X::StridedMatrix{<:Number}, p::AbstractVector{<:Union{Number, UniformScaling}}) =
79-
isempty(p) ? _evalpoly(X, p) : evalpoly!(similar(X, Base.promote_op(*, typeof(one(eltype(X))), typeof(_scalarval(p[begin])))), X, p)
83+
length(p) < 2 ? _evalpoly(X, p) : evalpoly!(similar(X, Base.promote_op(*, eltype(X), typeof(_scalarval(p[begin+1])))), X, p)
8084
Base.evalpoly(X::StridedMatrix{<:Number}, p::AbstractVector{<:AbstractMatrix{<:Number}}) =
81-
isempty(p) ? _evalpoly(X, p) : evalpoly!(similar(X, Base.promote_op(*, typeof(one(eltype(X))), eltype(p[begin]))), X, p)
85+
length(p) < 2 ? _evalpoly(X, p) : evalpoly!(similar(X, Base.promote_op(*, eltype(X), eltype(p[begin+1]))), X, p)

test/generic.jl

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -837,23 +837,32 @@ end
837837
end
838838
end
839839

840-
using LinearAlgebra: _evalpoly
841-
naive_evalpoly(X, p) = sum(X^(i-1) * p[i] for i=1:length(p))
840+
using LinearAlgebra: _evalpoly # fallback routine, which we'll test explicitly
841+
842+
# naive sum, a little complicated since X^0 fails if eltype(X) is abstract:
843+
naive_evalpoly(X, p) = length(p) == 1 ? one(X) * p[1] : one(X) * p[1] + sum(X^(i-1) * p[i] for i=2:length(p))
842844

843845
@testset "evalpoly" begin
844-
for X in ([1 2 3;4 5 6;7 8 9], UpperTriangular([1 2 3;0 5 6;0 0 9]), SymTridiagonal([1,2,3],[4,5]))
846+
for X in ([1 2 3;4 5 6;7 8 9], UpperTriangular([1 2 3;0 5 6;0 0 9]),
847+
SymTridiagonal([1,2,3],[4,5]), Real[1 2 3;4 5 6;7 8 9])
845848
@test @inferred(evalpoly(X, ())) == zero(X) == evalpoly(X, Int[])
846849
@test @inferred(evalpoly(X, (17,))) == one(X) * 17
847-
@test @inferred(_evalpoly(X, [1,2,3,4])) == @inferred(evalpoly(X, [1,2,3,4])) ==
848-
@inferred(evalpoly(X, (1,2,3,4))) ==
850+
@test _evalpoly(X, [1,2,3,4]) == evalpoly(X, [1,2,3,4]) == @inferred(evalpoly(X, (1,2,3,4))) ==
849851
naive_evalpoly(X, [1,2,3,4]) == 1*one(X) + 2*X + 3X^2 + 4X^3
850852
@test typeof(evalpoly(X, [1,2,3])) == typeof(evalpoly(X, (1,2,3))) == typeof(_evalpoly(X, [1,2,3])) ==
851853
typeof(X * X)
852854

853-
for N in (1,2,4), p in (rand(-10:10, N), UniformScaling.(rand(-10:10, N)), [rand(-5:5,3,3) for _ = 1:N])
855+
# _evalpoly is not type-stable if eltype(X) is abstract
856+
# because one(Real[...]) returns a Matrix{Int}
857+
if isconcretetype(eltype(X))
858+
@inferred evalpoly(X, [1,2,3,4])
859+
@inferred _evalpoly(X, [1,2,3,4])
860+
end
861+
862+
for N in (1,2,4), p in (Real[1,2], rand(-10:10, N), UniformScaling.(rand(-10:10, N)), [rand(-5:5,3,3) for _ = 1:N])
854863
@test _evalpoly(X, p) == evalpoly(X, p) == evalpoly(X, Tuple(p)) == naive_evalpoly(X, p)
855864
end
856-
for N in (1,2,4), p in (rand(N), UniformScaling.(rand(N)), [rand(3,3) for _ = 1:N])
865+
for N in (1,2,4), p in ((5, 6.7), rand(N), UniformScaling.(rand(N)), [rand(3,3) for _ = 1:N])
857866
@test _evalpoly(X, p) evalpoly(X, p) evalpoly(X, Tuple(p)) naive_evalpoly(X, p)
858867
end
859868

0 commit comments

Comments
 (0)