Skip to content

Commit 5306a7a

Browse files
Merge pull request #3249 from AayushSabharwal/as/hc-everywhere
feat: use `HomotopyContinuationProblem` in `NonlinearProblem` if possible
2 parents ab5747f + f428df4 commit 5306a7a

File tree

5 files changed

+609
-405
lines changed

5 files changed

+609
-405
lines changed

ext/MTKHomotopyContinuationExt.jl

Lines changed: 27 additions & 334 deletions
Original file line numberDiff line numberDiff line change
@@ -11,217 +11,6 @@ using ModelingToolkit: iscomplete, parameters, has_index_cache, get_index_cache,
1111

1212
const MTK = ModelingToolkit
1313

14-
function contains_variable(x, wrt)
15-
any(y -> occursin(y, x), wrt)
16-
end
17-
18-
"""
19-
Possible reasons why a term is not polynomial
20-
"""
21-
MTK.EnumX.@enumx NonPolynomialReason begin
22-
NonIntegerExponent
23-
ExponentContainsUnknowns
24-
BaseNotPolynomial
25-
UnrecognizedOperation
26-
end
27-
28-
function display_reason(reason::NonPolynomialReason.T, sym)
29-
if reason == NonPolynomialReason.NonIntegerExponent
30-
pow = arguments(sym)[2]
31-
"In $sym: Exponent $pow is not an integer"
32-
elseif reason == NonPolynomialReason.ExponentContainsUnknowns
33-
pow = arguments(sym)[2]
34-
"In $sym: Exponent $pow contains unknowns of the system"
35-
elseif reason == NonPolynomialReason.BaseNotPolynomial
36-
base = arguments(sym)[1]
37-
"In $sym: Base $base is not a polynomial in the unknowns"
38-
elseif reason == NonPolynomialReason.UnrecognizedOperation
39-
op = operation(sym)
40-
"""
41-
In $sym: Operation $op is not recognized. Allowed polynomial operations are \
42-
`*, /, +, -, ^`.
43-
"""
44-
else
45-
error("This should never happen. Please open an issue in ModelingToolkit.jl.")
46-
end
47-
end
48-
49-
mutable struct PolynomialData
50-
non_polynomial_terms::Vector{BasicSymbolic}
51-
reasons::Vector{NonPolynomialReason.T}
52-
has_parametric_exponent::Bool
53-
end
54-
55-
PolynomialData() = PolynomialData(BasicSymbolic[], NonPolynomialReason.T[], false)
56-
57-
abstract type PolynomialTransformationError <: Exception end
58-
59-
struct MultivarTerm <: PolynomialTransformationError
60-
term::Any
61-
vars::Any
62-
end
63-
64-
function Base.showerror(io::IO, err::MultivarTerm)
65-
println(io,
66-
"Cannot convert system to polynomial: Found term $(err.term) which is a function of multiple unknowns $(err.vars).")
67-
end
68-
69-
struct MultipleTermsOfSameVar <: PolynomialTransformationError
70-
terms::Any
71-
var::Any
72-
end
73-
74-
function Base.showerror(io::IO, err::MultipleTermsOfSameVar)
75-
println(io,
76-
"Cannot convert system to polynomial: Found multiple non-polynomial terms $(err.terms) involving the same unknown $(err.var).")
77-
end
78-
79-
struct SymbolicSolveFailure <: PolynomialTransformationError
80-
term::Any
81-
var::Any
82-
end
83-
84-
function Base.showerror(io::IO, err::SymbolicSolveFailure)
85-
println(io,
86-
"Cannot convert system to polynomial: Unable to symbolically solve $(err.term) for $(err.var).")
87-
end
88-
89-
struct NemoNotLoaded <: PolynomialTransformationError end
90-
91-
function Base.showerror(io::IO, err::NemoNotLoaded)
92-
println(io,
93-
"ModelingToolkit may be able to solve this system as a polynomial system if `Nemo` is loaded. Run `import Nemo` and try again.")
94-
end
95-
96-
struct VariablesAsPolyAndNonPoly <: PolynomialTransformationError
97-
vars::Any
98-
end
99-
100-
function Base.showerror(io::IO, err::VariablesAsPolyAndNonPoly)
101-
println(io,
102-
"Cannot convert convert system to polynomial: Variables $(err.vars) occur in both polynomial and non-polynomial terms in the system.")
103-
end
104-
105-
struct NotPolynomialError <: Exception
106-
transformation_err::Union{PolynomialTransformationError, Nothing}
107-
eq::Vector{Equation}
108-
data::Vector{PolynomialData}
109-
end
110-
111-
function Base.showerror(io::IO, err::NotPolynomialError)
112-
if err.transformation_err !== nothing
113-
Base.showerror(io, err.transformation_err)
114-
end
115-
for (eq, data) in zip(err.eq, err.data)
116-
if isempty(data.non_polynomial_terms)
117-
continue
118-
end
119-
println(io,
120-
"Equation $(eq) is not a polynomial in the unknowns for the following reasons:")
121-
for (term, reason) in zip(data.non_polynomial_terms, data.reasons)
122-
println(io, display_reason(reason, term))
123-
end
124-
end
125-
end
126-
127-
function is_polynomial!(data, y, wrt)
128-
process_polynomial!(data, y, wrt)
129-
isempty(data.reasons)
130-
end
131-
132-
"""
133-
$(TYPEDSIGNATURES)
134-
135-
Return information about the polynmial `x` with respect to variables in `wrt`,
136-
writing said information to `data`.
137-
"""
138-
function process_polynomial!(data::PolynomialData, x, wrt)
139-
x = unwrap(x)
140-
symbolic_type(x) == NotSymbolic() && return true
141-
iscall(x) || return true
142-
contains_variable(x, wrt) || return true
143-
any(isequal(x), wrt) && return true
144-
145-
if operation(x) in (*, +, -, /)
146-
# `map` because `all` will early exit, but we want to search
147-
# through everything to get all the non-polynomial terms
148-
return all(map(y -> is_polynomial!(data, y, wrt), arguments(x)))
149-
end
150-
if operation(x) == (^)
151-
b, p = arguments(x)
152-
is_pow_integer = symtype(p) <: Integer
153-
if !is_pow_integer
154-
push!(data.non_polynomial_terms, x)
155-
push!(data.reasons, NonPolynomialReason.NonIntegerExponent)
156-
end
157-
if symbolic_type(p) != NotSymbolic()
158-
data.has_parametric_exponent = true
159-
end
160-
161-
exponent_has_unknowns = contains_variable(p, wrt)
162-
if exponent_has_unknowns
163-
push!(data.non_polynomial_terms, x)
164-
push!(data.reasons, NonPolynomialReason.ExponentContainsUnknowns)
165-
end
166-
base_polynomial = is_polynomial!(data, b, wrt)
167-
return base_polynomial && !exponent_has_unknowns && is_pow_integer
168-
end
169-
push!(data.non_polynomial_terms, x)
170-
push!(data.reasons, NonPolynomialReason.UnrecognizedOperation)
171-
return false
172-
end
173-
174-
"""
175-
$(TYPEDSIGNATURES)
176-
177-
Given a `x`, a polynomial in variables in `wrt` which may contain rational functions,
178-
express `x` as a single rational function with polynomial `num` and denominator `den`.
179-
Return `(num, den)`.
180-
"""
181-
function handle_rational_polynomials(x, wrt)
182-
x = unwrap(x)
183-
symbolic_type(x) == NotSymbolic() && return x, 1
184-
iscall(x) || return x, 1
185-
contains_variable(x, wrt) || return x, 1
186-
any(isequal(x), wrt) && return x, 1
187-
188-
# simplify_fractions cancels out some common factors
189-
# and expands (a / b)^c to a^c / b^c, so we only need
190-
# to handle these cases
191-
x = simplify_fractions(x)
192-
op = operation(x)
193-
args = arguments(x)
194-
195-
if op == /
196-
# numerator and denominator are trivial
197-
num, den = args
198-
# but also search for rational functions in numerator
199-
n, d = handle_rational_polynomials(num, wrt)
200-
num, den = n, den * d
201-
elseif op == +
202-
num = 0
203-
den = 1
204-
205-
# we don't need to do common denominator
206-
# because we don't care about cases where denominator
207-
# is zero. The expression is zero when all the numerators
208-
# are zero.
209-
for arg in args
210-
n, d = handle_rational_polynomials(arg, wrt)
211-
num += n
212-
den *= d
213-
end
214-
else
215-
return x, 1
216-
end
217-
# if the denominator isn't a polynomial in `wrt`, better to not include it
218-
# to reduce the size of the gcd polynomial
219-
if !contains_variable(den, wrt)
220-
return num / den, 1
221-
end
222-
return num, den
223-
end
224-
22514
"""
22615
$(TYPEDSIGNATURES)
22716
@@ -289,12 +78,6 @@ end
28978

29079
SymbolicIndexingInterface.parameter_values(s::MTKHomotopySystem) = s.p
29180

292-
struct PolynomialTransformationData
293-
new_var::BasicSymbolic
294-
term::BasicSymbolic
295-
inv_term::Vector
296-
end
297-
29881
"""
29982
$(TYPEDSIGNATURES)
30083
@@ -312,128 +95,37 @@ Keyword arguments:
31295
All other keyword arguments are forwarded to `HomotopyContinuation.solver_startsystems`.
31396
"""
31497
function MTK.HomotopyContinuationProblem(
315-
sys::NonlinearSystem, u0map, parammap = nothing; eval_expression = false,
316-
eval_module = ModelingToolkit, warn_parametric_exponent = true, kwargs...)
98+
sys::NonlinearSystem, u0map, parammap = nothing; kwargs...)
99+
prob = MTK._safe_HomotopyContinuationProblem(sys, u0map, parammap; kwargs...)
100+
prob isa MTK.HomotopyContinuationProblem || throw(prob)
101+
return prob
102+
end
103+
104+
function MTK._safe_HomotopyContinuationProblem(sys, u0map, parammap = nothing; kwargs...)
317105
if !iscomplete(sys)
318106
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `HomotopyContinuationProblem`")
319107
end
320-
321-
dvs = unknowns(sys)
322-
# we need to consider `full_equations` because observed also should be
323-
# polynomials (if used in equations) and we don't know if observed is used
324-
# in denominator.
325-
# This is not the most efficient, and would be improved significantly with
326-
# CSE/hashconsing.
327-
eqs = full_equations(sys)
328-
329-
polydata = map(eqs) do eq
330-
data = PolynomialData()
331-
process_polynomial!(data, eq.lhs, dvs)
332-
process_polynomial!(data, eq.rhs, dvs)
333-
data
108+
transformation = MTK.PolynomialTransformation(sys)
109+
if transformation isa MTK.NotPolynomialError
110+
return transformation
334111
end
335-
336-
has_parametric_exponents = any(d -> d.has_parametric_exponent, polydata)
337-
338-
all_non_poly_terms = mapreduce(d -> d.non_polynomial_terms, vcat, polydata)
339-
unique!(all_non_poly_terms)
340-
341-
var_to_nonpoly = Dict{BasicSymbolic, PolynomialTransformationData}()
342-
343-
is_poly = true
344-
transformation_err = nothing
345-
for t in all_non_poly_terms
346-
# if the term involves multiple unknowns, we can't invert it
347-
dvs_in_term = map(x -> occursin(x, t), dvs)
348-
if count(dvs_in_term) > 1
349-
transformation_err = MultivarTerm(t, dvs[dvs_in_term])
350-
is_poly = false
351-
break
352-
end
353-
# we already have a substitution solving for `var`
354-
var = dvs[findfirst(dvs_in_term)]
355-
if haskey(var_to_nonpoly, var) && !isequal(var_to_nonpoly[var].term, t)
356-
transformation_err = MultipleTermsOfSameVar([t, var_to_nonpoly[var].term], var)
357-
is_poly = false
358-
break
359-
end
360-
# we want to solve `term - new_var` for `var`
361-
new_var = gensym(Symbol(var))
362-
new_var = unwrap(only(@variables $new_var))
363-
invterm = Symbolics.ia_solve(
364-
t - new_var, var; complex_roots = false, periodic_roots = false, warns = false)
365-
# if we can't invert it, quit
366-
if invterm === nothing || isempty(invterm)
367-
transformation_err = SymbolicSolveFailure(t, var)
368-
is_poly = false
369-
break
370-
end
371-
# `ia_solve` returns lazy terms i.e. `asin(1.0)` instead of `pi/2`
372-
# this just evaluates the constant expressions
373-
invterm = Symbolics.substitute.(invterm, (Dict(),))
374-
# RootsOf implies Symbolics couldn't solve the inner polynomial because
375-
# `Nemo` wasn't loaded.
376-
if any(x -> MTK.iscall(x) && MTK.operation(x) == Symbolics.RootsOf, invterm)
377-
transformation_err = NemoNotLoaded()
378-
is_poly = false
379-
break
380-
end
381-
var_to_nonpoly[var] = PolynomialTransformationData(new_var, t, invterm)
382-
end
383-
384-
if !is_poly
385-
throw(NotPolynomialError(transformation_err, eqs, polydata))
386-
end
387-
388-
subrules = Dict()
389-
combinations = Vector[]
390-
new_dvs = []
391-
for x in dvs
392-
if haskey(var_to_nonpoly, x)
393-
_data = var_to_nonpoly[x]
394-
subrules[_data.term] = _data.new_var
395-
push!(combinations, _data.inv_term)
396-
push!(new_dvs, _data.new_var)
397-
else
398-
push!(combinations, [x])
399-
push!(new_dvs, x)
400-
end
401-
end
402-
all_solutions = collect.(collect(Iterators.product(combinations...)))
403-
404-
denoms = []
405-
eqs2 = map(eqs) do eq
406-
t = eq.rhs - eq.lhs
407-
t = Symbolics.fixpoint_sub(t, subrules; maxiters = length(dvs))
408-
# the substituted variable occurs outside the substituted term
409-
poly_and_nonpoly = map(dvs) do x
410-
haskey(var_to_nonpoly, x) && occursin(x, t)
411-
end
412-
if any(poly_and_nonpoly)
413-
throw(NotPolynomialError(
414-
VariablesAsPolyAndNonPoly(dvs[poly_and_nonpoly]), eqs, polydata))
415-
end
416-
417-
num, den = handle_rational_polynomials(t, new_dvs)
418-
# make factors different elements, otherwise the nonzero factors artificially
419-
# inflate the error of the zero factor.
420-
if iscall(den) && operation(den) == *
421-
for arg in arguments(den)
422-
# ignore constant factors
423-
symbolic_type(arg) == NotSymbolic() && continue
424-
push!(denoms, abs(arg))
425-
end
426-
elseif symbolic_type(den) != NotSymbolic()
427-
push!(denoms, abs(den))
428-
end
429-
return 0 ~ num
112+
result = MTK.transform_system(sys, transformation)
113+
if result isa MTK.NotPolynomialError
114+
return result
430115
end
116+
MTK.HomotopyContinuationProblem(sys, transformation, result, u0map, parammap; kwargs...)
117+
end
431118

432-
sys2 = MTK.@set sys.eqs = eqs2
433-
MTK.@set! sys2.unknowns = new_dvs
434-
# remove observed equations to avoid adding them in codegen
435-
MTK.@set! sys2.observed = Equation[]
436-
MTK.@set! sys2.substitutions = nothing
119+
function MTK.HomotopyContinuationProblem(
120+
sys::MTK.NonlinearSystem, transformation::MTK.PolynomialTransformation,
121+
result::MTK.PolynomialTransformationResult, u0map,
122+
parammap = nothing; eval_expression = false,
123+
eval_module = ModelingToolkit, warn_parametric_exponent = true, kwargs...)
124+
sys2 = result.sys
125+
denoms = result.denominators
126+
polydata = transformation.polydata
127+
new_dvs = transformation.new_dvs
128+
all_solutions = transformation.all_solutions
437129

438130
_, u0, p = MTK.process_SciMLProblem(
439131
MTK.EmptySciMLFunction, sys, u0map, parammap; eval_expression, eval_module)
@@ -443,10 +135,11 @@ function MTK.HomotopyContinuationProblem(
443135
unpack_solution = MTK.build_explicit_observed_function(sys2, all_solutions)
444136

445137
hvars = symbolics_to_hc.(new_dvs)
446-
mtkhsys = MTKHomotopySystem(nlfn.f, p, nlfn.jac, hvars, length(eqs))
138+
mtkhsys = MTKHomotopySystem(nlfn.f, p, nlfn.jac, hvars, length(new_dvs))
447139

448140
obsfn = MTK.ObservedFunctionCache(sys; eval_expression, eval_module)
449141

142+
has_parametric_exponents = any(d -> d.has_parametric_exponent, polydata)
450143
if has_parametric_exponents
451144
if warn_parametric_exponent
452145
@warn """

0 commit comments

Comments
 (0)