Skip to content

Commit f6ec125

Browse files
authored
Post filter for putinar (#288)
1 parent bbeabc7 commit f6ec125

File tree

2 files changed

+156
-24
lines changed

2 files changed

+156
-24
lines changed

src/Certificate/newton_polytope.jl

Lines changed: 134 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -350,11 +350,17 @@ function minus_shift(d::DegreeBounds, p::MP.AbstractPolynomialLike)
350350
)
351351
end
352352

353-
_combine_sign(a, b) = (a == b ? a : zero(a))
353+
function _combine_sign(a, b)
354+
if ismissing(a) || ismissing(b)
355+
return missing
356+
else
357+
return a == b ? a : missing
358+
end
359+
end
354360

355361
_sign(a::Number) = sign(a)
356362
# Can be for instance a JuMP or MOI function so the sign can be anything
357-
_sign(a) = 0
363+
_sign(a) = missing
358364

359365
function deg_sign(deg, p, d)
360366
sgn = nothing
@@ -415,7 +421,7 @@ function deg_range(deg, p, gs, gram_deg, truncation)
415421
d_max = d
416422
end
417423
end
418-
if !isnothing(sign) && (iszero(sign) || (iseven(d_max) && sign == 1))
424+
if !isnothing(sign) && (ismissing(sign) || (iseven(d_max) && sign == 1))
419425
return true, d_max
420426
else
421427
return false, d_max - 1
@@ -525,14 +531,131 @@ function half_newton_polytope(
525531
gs::AbstractVector{<:MP.AbstractPolynomialLike},
526532
vars,
527533
maxdegree,
528-
newton::NewtonFilter{<:NewtonDegreeBounds},
534+
::NewtonFilter{<:NewtonDegreeBounds},
529535
)
530-
# TODO
531-
return half_newton_polytope(
532-
p,
533-
gs,
534-
vars,
535-
maxdegree,
536-
newton.outer_approximation,
536+
bounds = putinar_degree_bounds(p, gs, vars, maxdegree)
537+
bases = [multiplier_basis(g, bounds).monomials for g in gs]
538+
push!(
539+
bases,
540+
maxdegree_gram_basis(MB.MonomialBasis, _half(bounds)).monomials,
537541
)
542+
gs = copy(gs)
543+
push!(gs, one(eltype(gs)))
544+
filtered_bases = post_filter(p, gs, bases)
545+
# The last one will be recomputed by the ideal certificate
546+
return MB.MonomialBasis.(filtered_bases[1:(end-1)])
547+
end
548+
struct SignCount
549+
unknown::Int
550+
positive::Int
551+
negative::Int
552+
end
553+
SignCount() = SignCount(0, 0, 0)
554+
function _sign(c::SignCount)
555+
if !iszero(c.unknown)
556+
return missing
557+
elseif iszero(c.positive)
558+
return -1
559+
elseif iszero(c.negative)
560+
return -1
561+
else
562+
return missing
563+
end
564+
end
565+
566+
function add(c::SignCount, ::Missing, Δ)
567+
@assert c.unknown >= -Δ
568+
return SignCount(c.unknown + Δ, c.positive, c.negative)
569+
end
570+
function add(c::SignCount, a::Number, Δ)
571+
if a > 0
572+
@assert c.positive >= -Δ
573+
return SignCount(c.unknown, c.positive + Δ, c.negative)
574+
elseif a < 0
575+
@assert c.negative >= -Δ
576+
return SignCount(c.unknown, c.positive, c.negative + Δ)
577+
elseif iszero(a)
578+
error(
579+
"A polynomial should never contain a term with zero coefficient but found `$a`.",
580+
)
581+
else
582+
error("Cannot determine sign of `$a`.")
583+
end
584+
end
585+
function add(counter, sign, mono, Δ)
586+
count = get(counter, mono, SignCount())
587+
return counter[mono] = add(count, sign, Δ)
588+
end
589+
590+
function increase(counter, generator_sign, monos, mult)
591+
for a in monos
592+
for b in monos
593+
sign = (a != b) ? missing : generator_sign
594+
add(counter, sign, a * b * mult, 1)
595+
end
596+
end
597+
end
598+
599+
function post_filter(poly, generators, multipliers_gram_monos)
600+
counter = Dict{MP.monomialtype(poly),SignCount}()
601+
for t in MP.terms(poly)
602+
coef = SignCount()
603+
counter[MP.monomial(t)] = add(coef, _sign(MP.coefficient(t)), 1)
604+
end
605+
for (mult, gram_monos) in zip(generators, multipliers_gram_monos)
606+
for t in MP.terms(mult)
607+
sign = -_sign(MP.coefficient(t))
608+
mono = MP.monomial(t)
609+
increase(counter, sign, gram_monos, mono)
610+
end
611+
end
612+
function decrease(sign, mono)
613+
count = add(counter, sign, mono, -1)
614+
count_sign = _sign(count)
615+
# This means the `counter` has a sign and it didn't have a sign before
616+
# so we need to delete back edges
617+
if !ismissing(count_sign) && (ismissing(count) || count != count_sign)
618+
# TODO could see later if deleting the counter improves perf
619+
if haskey(back, mono)
620+
for (i, j) in back[mono]
621+
delete(i, j)
622+
end
623+
end
624+
end
625+
end
626+
back = Dict{eltype(eltype(multipliers_gram_monos)),Vector{Tuple{Int,Int}}}()
627+
keep = [ones(Bool, length(monos)) for monos in multipliers_gram_monos]
628+
function delete(i, j)
629+
if !keep[i][j]
630+
return
631+
end
632+
keep[i][j] = false
633+
a = multipliers_gram_monos[i][j]
634+
for t in MP.terms(generators[i])
635+
sign = -_sign(MP.coefficient(t))
636+
decrease(sign, MP.monomial(t) * a^2)
637+
for (j, b) in enumerate(multipliers_gram_monos[i])
638+
if keep[i][j]
639+
decrease(missing, MP.monomial(t) * a * b)
640+
decrease(missing, MP.monomial(t) * b * a)
641+
end
642+
end
643+
end
644+
end
645+
for i in eachindex(generators)
646+
for t in MP.terms(generators[i])
647+
for (j, mono) in enumerate(multipliers_gram_monos[i])
648+
w = MP.monomial(t) * mono^2
649+
if ismissing(_sign(counter[w]))
650+
push!(get(back, w, Tuple{Int,Int}[]), (i, j))
651+
else
652+
delete(i, j)
653+
end
654+
end
655+
end
656+
end
657+
return [
658+
gram_monos[findall(keep)] for
659+
(keep, gram_monos) in zip(keep, multipliers_gram_monos)
660+
]
538661
end

test/certificate.jl

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -273,31 +273,40 @@ end
273273
end
274274
end
275275

276-
function test_putinar_ijk(i, j, k)
277-
@polyvar x y
276+
function test_putinar_ijk(i, j, k, default::Bool, post_filter::Bool = default)
277+
v = @polyvar x y
278278
poly = x^(2i) + y^(2j + 1)
279279
domain = @set y^(2k + 1) >= 0
280-
set = JuMP.moi_set(SOSCone(), monomials(poly); domain)
281-
processed = Certificate.preprocessed_domain(set.certificate, domain, poly)
282-
for idx in Certificate.preorder_indices(set.certificate, processed)
280+
if default
281+
certificate =
282+
JuMP.moi_set(SOSCone(), monomials(poly); domain).certificate
283+
else
284+
newton = Certificate.NewtonDegreeBounds(tuple())
285+
if post_filter
286+
newton = Certificate.NewtonFilter(newton)
287+
end
288+
cert = Certificate.Newton(SOSCone(), MB.MonomialBasis, newton)
289+
certificate = Certificate.Putinar(cert, cert, max(2i, 2j + 1, 2k + 1))
290+
end
291+
processed = Certificate.preprocessed_domain(certificate, domain, poly)
292+
for idx in Certificate.preorder_indices(certificate, processed)
283293
monos =
284-
Certificate.multiplier_basis(
285-
set.certificate,
286-
idx,
287-
processed,
288-
).monomials
294+
Certificate.multiplier_basis(certificate, idx, processed).monomials
289295
if k > j
290296
@test isempty(monos)
291297
else
292-
@test monos == MP.monomials([x, y], max(0, min(i, j) - k):(j-k))
298+
w = post_filter ? v[2:2] : v
299+
@test monos == MP.monomials(w, max(0, min(i, j) - k):(j-k))
293300
end
294301
end
295-
icert = Certificate.ideal_certificate(set.certificate)
302+
icert = Certificate.ideal_certificate(certificate)
296303
@test icert isa Certificate.Newton
297304
end
298305

299306
@testset "Putinar $i $j $k" for (i, j, k) in [(1, 1, 2), (1, 3, 2), (3, 2, 1)] #, (4, 2, 1)]
300-
test_putinar_ijk(i, j, k)
307+
@testset "post_filter=$post" for post in [true, false]
308+
test_putinar_ijk(i, j, k, post)
309+
end
301310
end
302311

303312
include("ceg_test.jl")

0 commit comments

Comments
 (0)