Skip to content

Commit 61f6766

Browse files
committed
Faster vector substitution
1 parent e5df766 commit 61f6766

File tree

2 files changed

+41
-15
lines changed

2 files changed

+41
-15
lines changed

src/subs.jl

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,10 @@ function _add_variables!(p::PolyType, q::PolyType)
136136
return p
137137
end
138138

139-
function monoeval(z::Vector{Int}, vals::AbstractVector)
140-
@assert length(z) == length(vals)
139+
function _mono_eval(z::Vector{Int}, vals::AbstractVector)
140+
if length(z) != length(vals)
141+
error("")
142+
end
141143
if isempty(z)
142144
return one(eltype(vals))^1
143145
end
@@ -154,24 +156,24 @@ function monoeval(z::Vector{Int}, vals::AbstractVector)
154156
return val
155157
end
156158

157-
_subs(st, ::Variable, vals) = monoeval([1], vals::AbstractVector)
158-
_subs(st, m::Monomial, vals) = monoeval(m.z, vals::AbstractVector)
159-
function _subs(st, t::_Term, vals)
160-
return MP.coefficient(t) * monoeval(MP.monomial(t).z, vals::AbstractVector)
159+
MP.substitute(::MP.AbstractSubstitutionType, ::Variable, vals::AbstractVector) = _mono_eval((1,), vals)
160+
MP.substitute(::MP.AbstractSubstitutionType, m::Monomial, vals::AbstractVector) = _mono_eval(m.z, vals)
161+
function MP.substitute(st::MP.AbstractSubstitutionType, t::_Term, vals::AbstractVector)
162+
return MP.coefficient(t) * MP.substitute(st, MP.monomial(t), vals)
161163
end
162-
function _subs(
164+
function MP.substitute(
163165
::MP.Eval,
164166
p::Polynomial{V,M,T},
165167
vals::AbstractVector{S},
166168
) where {V,M,T,S}
167169
# I need to check for iszero otherwise I get : ArgumentError: reducing over an empty collection is not allowed
168170
if iszero(p)
169-
zero(Base.promote_op(*, S, T))
171+
zero(MA.promote_operation(*, S, T))
170172
else
171-
sum(i -> p.a[i] * monoeval(p.x.Z[i], vals), eachindex(p.a))
173+
sum(i -> p.a[i] * _mono_eval(p.x.Z[i], vals), eachindex(p.a))
172174
end
173175
end
174-
function _subs(
176+
function MP.substitute(
175177
::MP.Subs,
176178
p::Polynomial{V,M,T},
177179
vals::AbstractVector{S},
@@ -182,7 +184,7 @@ function _subs(
182184
mergevars_of(Variable{V,M}, vals)[1],
183185
)
184186
for i in eachindex(p.a)
185-
MA.operate!(+, q, p.a[i] * monoeval(p.x.Z[i], vals))
187+
MA.operate!(+, q, p.a[i] * _mono_eval(p.x.Z[i], vals))
186188
end
187189
return q
188190
end
@@ -197,12 +199,20 @@ function MA.promote_operation(
197199
return MA.promote_operation(*, U, Monomial{V,M})
198200
end
199201

202+
function MP.substitute(
203+
st::MP.AbstractSubstitutionType,
204+
p::PolyType,
205+
s::MP.AbstractSubstitution...,
206+
)
207+
return MP.substitute(st, p, subsmap(st, MP.variables(p), s))
208+
end
209+
200210
function MP.substitute(
201211
st::MP.AbstractSubstitutionType,
202212
p::PolyType,
203213
s::MP.Substitutions,
204214
)
205-
return _subs(st, p, subsmap(st, MP.variables(p), s))
215+
return MP.substitute(st, p, subsmap(st, MP.variables(p), s))
206216
end
207217

208218
(v::Variable)(s::MP.AbstractSubstitution...) = MP.substitute(MP.Eval(), v, s)
@@ -215,20 +225,20 @@ function (p::Monomial)(x::NTuple{N,<:Number}) where {N}
215225
return MP.substitute(MP.Eval(), p, variables(p) => x)
216226
end
217227
function (p::Monomial)(x::AbstractVector{<:Number})
218-
return MP.substitute(MP.Eval(), p, variables(p) => x)
228+
return MP.substitute(MP.Eval(), p, x)
219229
end
220230
(p::Monomial)(x::Number...) = MP.substitute(MP.Eval(), p, variables(p) => x)
221231
function (p::_Term)(x::NTuple{N,<:Number}) where {N}
222232
return MP.substitute(MP.Eval(), p, variables(p) => x)
223233
end
224234
function (p::_Term)(x::AbstractVector{<:Number})
225-
return MP.substitute(MP.Eval(), p, variables(p) => x)
235+
return MP.substitute(MP.Eval(), p, x)
226236
end
227237
(p::_Term)(x::Number...) = MP.substitute(MP.Eval(), p, variables(p) => x)
228238
function (p::Polynomial)(x::NTuple{N,<:Number}) where {N}
229239
return MP.substitute(MP.Eval(), p, variables(p) => x)
230240
end
231241
function (p::Polynomial)(x::AbstractVector{<:Number})
232-
return MP.substitute(MP.Eval(), p, variables(p) => x)
242+
return MP.substitute(MP.Eval(), p, x)
233243
end
234244
(p::Polynomial)(x::Number...) = MP.substitute(MP.Eval(), p, variables(p) => x)

test/runtests.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,23 @@ using MultivariatePolynomials
33
using Test
44
using LinearAlgebra
55

6+
function alloc_test_lt(f, n)
7+
f() # compile
8+
@test n >= @allocated f()
9+
end
10+
611
# TODO move to MP
12+
@testset "See https://github.com/jump-dev/SumOfSquares.jl/issues/388" begin
13+
@polyvar x[1:3]
14+
p = sum(x)
15+
v = map(_ -> 1, x)
16+
# I get 208 but let's give some margin
17+
alloc_test_lt(() -> substitute(Eval(), p, x => v), 300)
18+
alloc_test_lt(() -> p(x => v), 300)
19+
alloc_test_lt(() -> substitute(Eval(), p, v), 0)
20+
alloc_test_lt(() -> p(v), 0)
21+
end
22+
723
@testset "Issue #70" begin
824
@ncpolyvar y0 y1 x0 x1
925
p = x1 * x0 * x1

0 commit comments

Comments
 (0)