Skip to content

Commit b790dfb

Browse files
feat: add significantly improved polynomial transformation
1 parent cf2d5c6 commit b790dfb

File tree

1 file changed

+164
-113
lines changed

1 file changed

+164
-113
lines changed

src/systems/nonlinear/homotopy_continuation.jl

Lines changed: 164 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -47,50 +47,25 @@ end
4747

4848
abstract type PolynomialTransformationError <: Exception end
4949

50-
struct MultivarTerm <: PolynomialTransformationError
51-
term::Any
52-
vars::Any
50+
struct UnmatchedUnknowns <: PolynomialTransformationError
51+
unmatched::Vector{BasicSymbolic}
5352
end
5453

55-
function Base.showerror(io::IO, err::MultivarTerm)
54+
function Base.showerror(io::IO, err::UnmatchedUnknowns)
5655
println(io,
57-
"Cannot convert system to polynomial: Found term $(err.term) which is a function of multiple unknowns $(err.vars).")
56+
"Cannot convert system to polynomial: could not find terms to solve for unknowns $(err.unmatched).")
5857
end
5958

60-
struct MultipleTermsOfSameVar <: PolynomialTransformationError
61-
terms::Any
62-
var::Any
59+
struct UnmatchedTerms <: PolynomialTransformationError
60+
unmatched::Vector{BasicSymbolic}
6361
end
6462

65-
function Base.showerror(io::IO, err::MultipleTermsOfSameVar)
66-
println(io,
67-
"Cannot convert system to polynomial: Found multiple non-polynomial terms $(err.terms) involving the same unknown $(err.var).")
68-
end
69-
70-
struct SymbolicSolveFailure <: PolynomialTransformationError
71-
term::Any
72-
var::Any
73-
end
74-
75-
function Base.showerror(io::IO, err::SymbolicSolveFailure)
76-
println(io,
77-
"Cannot convert system to polynomial: Unable to symbolically solve $(err.term) for $(err.var).")
78-
end
79-
80-
struct NemoNotLoaded <: PolynomialTransformationError end
81-
82-
function Base.showerror(io::IO, err::NemoNotLoaded)
83-
println(io,
84-
"ModelingToolkit may be able to solve this system as a polynomial system if `Nemo` is loaded. Run `import Nemo` and try again.")
85-
end
86-
87-
struct VariablesAsPolyAndNonPoly <: PolynomialTransformationError
88-
vars::Any
63+
function Base.showerror(io::IO, err::UnmatchedTerms)
64+
println(io, "Cannot convert system to polynomial: too many non-polynomial terms in system. Unmatched terms are $(err.unmatched).")
8965
end
9066

91-
function Base.showerror(io::IO, err::VariablesAsPolyAndNonPoly)
92-
println(io,
93-
"Cannot convert convert system to polynomial: Variables $(err.vars) occur in both polynomial and non-polynomial terms in the system.")
67+
function no_nemo_warning()
68+
@warn "ModelingToolkit may be able to symbolically solve some non-polynomial terms in this system for all roots if `Nemo` is loaded. Run `import Nemo` and try again to enable this functionality and possibly obtain additional roots."
9469
end
9570

9671
struct NotPolynomialError <: Exception
@@ -216,6 +191,24 @@ struct PolynomialTransformationData
216191
inv_term::Vector{BasicSymbolic}
217192
end
218193

194+
"""
195+
$(TYPEDEF)
196+
197+
Information for how to solve for unknowns involved in non-symbolically-solvable
198+
non-polynomial terms to turn the system into a polynomial. Used in
199+
`PolynomialTransformation`.
200+
"""
201+
struct NonlinearSolveTransformation
202+
"""
203+
The system which solves for the unknowns of the parent system.
204+
"""
205+
sys::NonlinearSystem
206+
"""
207+
The input variables to this system representing solutions of non-polynomial terms.
208+
"""
209+
inputvars::Vector{BasicSymbolic}
210+
end
211+
219212
"""
220213
$(TYPEDEF)
221214
@@ -224,20 +217,19 @@ system.
224217
"""
225218
struct PolynomialTransformation
226219
"""
227-
Substitutions mapping non-polynomial terms to temporary unknowns. The system
228-
is a polynomial in the new unknowns. Currently, each non-polynomial term is a
229-
function of a single unknown of the original system.
220+
The stages in which to recover the solution in terms of the original unknowns, in
221+
order.
230222
"""
231-
substitution_rules::Dict{BasicSymbolic, BasicSymbolic}
223+
solve_stages::Vector{Any}
232224
"""
233-
A vector of expressions involving unknowns of the transformed system, mapping
234-
back to solutions of the original system.
225+
The (previous) stages each stage depends on.
235226
"""
236-
all_solutions::Vector{Vector{BasicSymbolic}}
227+
stage_dependencies::Vector{Vector{Int}}
237228
"""
238-
The new unknowns of the transformed system.
229+
Mapping from terms to new unknowns they are replaced by. The system is a
230+
polynomial in the new unknowns.
239231
"""
240-
new_dvs::Vector{BasicSymbolic}
232+
substitution_rules::Dict{BasicSymbolic, BasicSymbolic}
241233
"""
242234
The polynomial data for each equation.
243235
"""
@@ -270,76 +262,135 @@ function PolynomialTransformation(sys::NonlinearSystem)
270262
d -> d.non_polynomial_terms, vcat, polydata; init = BasicSymbolic[])
271263
unique!(all_non_poly_terms)
272264

273-
# each variable can only be replaced by one non-polynomial expression involving
274-
# that variable. Keep track of this mapping.
275-
var_to_nonpoly = Dict{BasicSymbolic, PolynomialTransformationData}()
265+
all_solo_vars = mapreduce(d -> d.solo_terms, union, polydata; init = Set{BasicSymbolic}())
266+
267+
# Graph matches variables to candidates for unknowns of the polynomial system that
268+
# they occur in. These unknowns can be solo variables that appear outside of
269+
# non-polynomial terms in the system, or non-polynomials.
270+
graph = BipartiteGraph(length(dvs), 0)
271+
# all solo variables are candidates for unknowns
272+
graph_srcs = dvs
273+
graph_dsts = BasicSymbolic[]
274+
for (i, var) in enumerate(dvs)
275+
var in all_solo_vars || continue
276+
push!(graph_dsts, var)
277+
vert = add_vertex!(graph, DST)
278+
add_edge!(graph, i, vert)
279+
end
276280

277-
is_poly = true
278-
transformation_err = nothing
281+
# buffer to prevent reallocations
282+
dvs_in_term = Set()
283+
# for efficient queries
284+
dvs_to_src = Dict(graph_srcs .=> eachindex(graph_srcs))
285+
# build out graph with other non-polynomial terms
279286
for t in all_non_poly_terms
280-
# if the term involves multiple unknowns, we can't invert it
281-
dvs_in_term = map(x -> occursin(x, t), dvs)
282-
if count(dvs_in_term) > 1
283-
transformation_err = MultivarTerm(t, dvs[dvs_in_term])
284-
is_poly = false
285-
break
286-
end
287-
# we already have a substitution solving for `var`
288-
var = dvs[findfirst(dvs_in_term)]
289-
if haskey(var_to_nonpoly, var) && !isequal(var_to_nonpoly[var].term, t)
290-
transformation_err = MultipleTermsOfSameVar([t, var_to_nonpoly[var].term], var)
291-
is_poly = false
292-
break
293-
end
294-
# we want to solve `term - new_var` for `var`
295-
new_var = gensym(Symbol(var))
296-
new_var = unwrap(only(@variables $new_var))
297-
invterm = Symbolics.ia_solve(
298-
t - new_var, var; complex_roots = false, periodic_roots = false, warns = false)
299-
# if we can't invert it, quit
300-
if invterm === nothing || isempty(invterm)
301-
transformation_err = SymbolicSolveFailure(t, var)
302-
is_poly = false
303-
break
304-
end
305-
# `ia_solve` returns lazy terms i.e. `asin(1.0)` instead of `pi/2`
306-
# this just evaluates the constant expressions
307-
invterm = Symbolics.substitute.(invterm, (Dict(),))
308-
# RootsOf implies Symbolics couldn't solve the inner polynomial because
309-
# `Nemo` wasn't loaded.
310-
if any(x -> iscall(x) && operation(x) == Symbolics.RootsOf, invterm)
311-
transformation_err = NemoNotLoaded()
312-
is_poly = false
313-
break
287+
empty!(dvs_in_term)
288+
vars!(dvs_in_term, t)
289+
intersect!(dvs_in_term, dvs)
290+
push!(graph_dsts, t)
291+
vert = add_vertex!(graph, DST)
292+
for var in dvs_in_term
293+
add_edge!(graph, dvs_to_src[var], vert)
314294
end
315-
var_to_nonpoly[var] = PolynomialTransformationData(new_var, t, invterm)
316295
end
317296

297+
# Match variables to the candidate unknown we'll use to solve for them.
298+
# This is a poor man's version of `structural_simplify`, but if we create
299+
# and simplify a `NonlinearSystem` it makes doing symbolic solving more
300+
# annoying.
301+
matching = BipartiteGraphs.complete(maximal_matching(graph))
302+
inv_matching = invview(matching)
303+
# matching is from destination to source vertices
304+
unassigned_dsts = filter(i -> matching[i] == unassigned, 𝑠vertices(graph))
305+
unassigned_srcs = filter(i -> inv_matching[i] == unassigned, 𝑑vertices(graph))
306+
318307
# return the error instead of throwing it, so the user can choose what to do
319308
# without having to catch the exception
320-
if !is_poly
321-
return NotPolynomialError(transformation_err, eqs, polydata)
309+
if !isempty(unassigned_srcs)
310+
return NotPolynomialError(UnmatchedUnknowns(graph_srcs[unassigned_srcs]), eqs, polydata)
311+
end
312+
if !isempty(unassigned_dsts)
313+
return NotPolynomialError(UnmatchedTerms(graph_dsts[unassigned_dsts]), eqs, polydata)
314+
end
315+
316+
# At this point, the matching is perfect. Find the SCCs so we know
317+
# which terms to solve for which variables.
318+
digraph = DiCMOBiGraph{false}(graph, matching)
319+
var_sccs = Graphs.strongly_connected_components(digraph)
320+
foreach(sort!, var_sccs)
321+
# construct a condensation graph of the SCCs so we can topologically sort them
322+
scc_graph = MatchedCondensationGraph(digraph, var_sccs)
323+
toporder = topological_sort(scc_graph)
324+
var_sccs = var_sccs[toporder]
325+
# get the corresponding terms
326+
term_sccs = map(var_sccs) do scc
327+
map(scc) do src
328+
inv_matching[src]
329+
end
322330
end
323331

332+
# keep track of which previous SCCs each SCC depends on
333+
dependencies = Vector{Int}[]
334+
# the method to solve each stage
335+
solve_stages = []
336+
# mapping from terms to the new unknowns they are replaced by
324337
subrules = Dict{BasicSymbolic, BasicSymbolic}()
325-
# corresponding to each unknown in `dvs`, the list of its possible solutions
326-
# in terms of the new unknown.
327-
combinations = Vector{BasicSymbolic}[]
328-
new_dvs = BasicSymbolic[]
329-
for x in dvs
330-
if haskey(var_to_nonpoly, x)
331-
_data = var_to_nonpoly[x]
332-
# map term to new unknown
333-
subrules[_data.term] = _data.new_var
334-
push!(combinations, _data.inv_term)
335-
push!(new_dvs, _data.new_var)
336-
else
337-
push!(combinations, BasicSymbolic[x])
338-
push!(new_dvs, x)
338+
# if we've already emitted the no nemo warning
339+
warned_no_nemo = false
340+
for (i, (vscc, tscc)) in enumerate(zip(var_sccs, term_sccs))
341+
# dependencies are simply outneighbors
342+
push!(dependencies, collect(Graphs.outneighbors(scc_graph, i)))
343+
344+
# whether the SCC is solvable with a single variable
345+
single_scc_solvable = length(vscc) == 1
346+
# for single-variable SCCs, we use `ia_solve`
347+
if single_scc_solvable
348+
varidx = vscc[]
349+
termidx = tscc[]
350+
var = graph_srcs[varidx]
351+
t = graph_dsts[termidx]
352+
# Create a new variable and representing the non-polynomial term...
353+
new_var = unwrap(similar_variable(var, Symbol(var)))
354+
# ...and solve for `var` in terms of this new variable.
355+
invterm = Symbolics.ia_solve(t - new_var, var; complex_roots = false, periodic_roots = false, warns = false)
356+
# `ia_solve` returns lazy terms i.e. `asin(1.0)` instead of `pi/2`
357+
# this just evaluates the constant expressions
358+
invterm = Symbolics.substitute.(invterm, (Dict(),))
359+
# if `ia_solve` returns `nothing`, the broadcast above turns it into `(nothing,)`
360+
if invterm === (nothing,) || isempty(invterm)
361+
# if we can't invert it, quit
362+
single_scc_solvable = false
363+
elseif any(x -> iscall(x) && operation(x) == Symbolics.RootsOf, invterm)
364+
# RootsOf implies Symbolics couldn't solve the inner polynomial because
365+
# `Nemo` wasn't loaded.
366+
warned_no_nemo || no_nemo_warning()
367+
warned_no_nemo = true
368+
single_scc_solvable = false
369+
else
370+
subrules[t] = new_var
371+
push!(solve_stages, PolynomialTransformationData(new_var, t, invterm))
372+
end
373+
end
374+
375+
# the SCC was solved with a single variable
376+
single_scc_solvable && continue
377+
378+
# Solve using a `NonlinearSolve`.
379+
vars = graph_srcs[vscc]
380+
ts = graph_dsts[tscc]
381+
# the new variables are inputs to the system, so they're parameters
382+
new_vars = map(vars) do var
383+
toparam.(unwrap.(similar_variable(var, Symbol(var))))
384+
end
385+
eqs = collect(0 .~ (ts .- new_vars))
386+
scc_sys = complete(NonlinearSystem(eqs; name = Symbol(:scc_, i)))
387+
push!(solve_stages, NonlinearSolveTransformation(scc_sys, new_vars))
388+
for (new_var, t) in zip(new_vars, ts)
389+
subrules[t] = new_var
339390
end
340391
end
341-
all_solutions = vec(collect.(collect(Iterators.product(combinations...))))
342-
return PolynomialTransformation(subrules, all_solutions, new_dvs, polydata)
392+
393+
return PolynomialTransformation(solve_stages, dependencies, subrules, polydata)
343394
end
344395

345396
"""
@@ -352,6 +403,15 @@ in the equations, to rule out invalid roots.
352403
struct PolynomialTransformationResult
353404
sys::NonlinearSystem
354405
denominators::Vector{BasicSymbolic}
406+
"""
407+
The stages in which to recover the solution in terms of the original unknowns, in
408+
order.
409+
"""
410+
solve_stages::Vector{Any}
411+
"""
412+
The (previous) stages each stage depends on.
413+
"""
414+
stage_dependencies::Vector{Vector{Int}}
355415
end
356416

357417
"""
@@ -367,22 +427,13 @@ function transform_system(sys::NonlinearSystem, transformation::PolynomialTransf
367427
dvs = unknowns(sys)
368428
eqs = full_equations(sys)
369429
polydata = transformation.polydata
370-
new_dvs = transformation.new_dvs
371-
all_solutions = transformation.all_solutions
430+
new_dvs = collect(values(subrules))
372431

373432
eqs2 = Equation[]
374433
denoms = BasicSymbolic[]
375434
for eq in eqs
376435
t = eq.rhs - eq.lhs
377436
t = Symbolics.fixpoint_sub(t, subrules; maxiters = length(dvs))
378-
# the substituted variable occurs outside the substituted term
379-
poly_and_nonpoly = map(dvs) do x
380-
all(!isequal(x), new_dvs) && occursin(x, t)
381-
end
382-
if any(poly_and_nonpoly)
383-
return NotPolynomialError(
384-
VariablesAsPolyAndNonPoly(dvs[poly_and_nonpoly]), eqs, polydata)
385-
end
386437
num, den = handle_rational_polynomials(t, new_dvs; fraction_cancel_fn)
387438
# make factors different elements, otherwise the nonzero factors artificially
388439
# inflate the error of the zero factor.
@@ -403,7 +454,7 @@ function transform_system(sys::NonlinearSystem, transformation::PolynomialTransf
403454
# remove observed equations to avoid adding them in codegen
404455
@set! sys2.observed = Equation[]
405456
@set! sys2.substitutions = nothing
406-
return PolynomialTransformationResult(sys2, denoms)
457+
return PolynomialTransformationResult(sys2, denoms, transformation.solve_stages, transformation.stage_dependencies)
407458
end
408459

409460
"""

0 commit comments

Comments
 (0)