Skip to content

Commit 37254ac

Browse files
Merge pull request #3214 from AayushSabharwal/as/hc-poly-of-fn
feat: support polynomials of invertible functions in `HomotopyContinuationExt`
2 parents 5ee46d7 + 13500f2 commit 37254ac

File tree

5 files changed

+232
-34
lines changed

5 files changed

+232
-34
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
133133
StaticArrays = "0.10, 0.11, 0.12, 1.0"
134134
SymbolicIndexingInterface = "0.3.35"
135135
SymbolicUtils = "3.7"
136-
Symbolics = "6.15.4"
136+
Symbolics = "6.19"
137137
URIs = "1"
138138
UnPack = "0.1, 1.0"
139139
Unitful = "1.1"

ext/MTKHomotopyContinuationExt.jl

Lines changed: 179 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -54,16 +54,73 @@ end
5454

5555
PolynomialData() = PolynomialData(BasicSymbolic[], NonPolynomialReason.T[], false)
5656

57+
abstract type PolynomialTransformationError <: Exception end
58+
59+
struct MultivarTerm <: PolynomialTransformationError
60+
term::Any
61+
vars::Any
62+
end
63+
64+
function Base.showerror(io::IO, err::MultivarTerm)
65+
println(io,
66+
"Cannot convert system to polynomial: Found term $(err.term) which is a function of multiple unknowns $(err.vars).")
67+
end
68+
69+
struct MultipleTermsOfSameVar <: PolynomialTransformationError
70+
terms::Any
71+
var::Any
72+
end
73+
74+
function Base.showerror(io::IO, err::MultipleTermsOfSameVar)
75+
println(io,
76+
"Cannot convert system to polynomial: Found multiple non-polynomial terms $(err.terms) involving the same unknown $(err.var).")
77+
end
78+
79+
struct SymbolicSolveFailure <: PolynomialTransformationError
80+
term::Any
81+
var::Any
82+
end
83+
84+
function Base.showerror(io::IO, err::SymbolicSolveFailure)
85+
println(io,
86+
"Cannot convert system to polynomial: Unable to symbolically solve $(err.term) for $(err.var).")
87+
end
88+
89+
struct NemoNotLoaded <: PolynomialTransformationError end
90+
91+
function Base.showerror(io::IO, err::NemoNotLoaded)
92+
println(io,
93+
"ModelingToolkit may be able to solve this system as a polynomial system if `Nemo` is loaded. Run `import Nemo` and try again.")
94+
end
95+
96+
struct VariablesAsPolyAndNonPoly <: PolynomialTransformationError
97+
vars::Any
98+
end
99+
100+
function Base.showerror(io::IO, err::VariablesAsPolyAndNonPoly)
101+
println(io,
102+
"Cannot convert convert system to polynomial: Variables $(err.vars) occur in both polynomial and non-polynomial terms in the system.")
103+
end
104+
57105
struct NotPolynomialError <: Exception
58-
eq::Equation
59-
data::PolynomialData
106+
transformation_err::Union{PolynomialTransformationError, Nothing}
107+
eq::Vector{Equation}
108+
data::Vector{PolynomialData}
60109
end
61110

62111
function Base.showerror(io::IO, err::NotPolynomialError)
63-
println(io,
64-
"Equation $(err.eq) is not a polynomial in the unknowns for the following reasons:")
65-
for (term, reason) in zip(err.data.non_polynomial_terms, err.data.reasons)
66-
println(io, display_reason(reason, term))
112+
if err.transformation_err !== nothing
113+
Base.showerror(io, err.transformation_err)
114+
end
115+
for (eq, data) in zip(err.eq, err.data)
116+
if isempty(data.non_polynomial_terms)
117+
continue
118+
end
119+
println(io,
120+
"Equation $(eq) is not a polynomial in the unknowns for the following reasons:")
121+
for (term, reason) in zip(data.non_polynomial_terms, data.reasons)
122+
println(io, display_reason(reason, term))
123+
end
67124
end
68125
end
69126

@@ -86,7 +143,9 @@ function process_polynomial!(data::PolynomialData, x, wrt)
86143
any(isequal(x), wrt) && return true
87144

88145
if operation(x) in (*, +, -, /)
89-
return all(y -> is_polynomial!(data, y, wrt), arguments(x))
146+
# `map` because `all` will early exit, but we want to search
147+
# through everything to get all the non-polynomial terms
148+
return all(map(y -> is_polynomial!(data, y, wrt), arguments(x)))
90149
end
91150
if operation(x) == (^)
92151
b, p = arguments(x)
@@ -105,10 +164,6 @@ function process_polynomial!(data::PolynomialData, x, wrt)
105164
push!(data.reasons, NonPolynomialReason.ExponentContainsUnknowns)
106165
end
107166
base_polynomial = is_polynomial!(data, b, wrt)
108-
if !base_polynomial
109-
push!(data.non_polynomial_terms, x)
110-
push!(data.reasons, NonPolynomialReason.BaseNotPolynomial)
111-
end
112167
return base_polynomial && !exponent_has_unknowns && is_pow_integer
113168
end
114169
push!(data.non_polynomial_terms, x)
@@ -234,6 +289,12 @@ end
234289

235290
SymbolicIndexingInterface.parameter_values(s::MTKHomotopySystem) = s.p
236291

292+
struct PolynomialTransformationData
293+
new_var::BasicSymbolic
294+
term::BasicSymbolic
295+
inv_term::Vector
296+
end
297+
237298
"""
238299
$(TYPEDSIGNATURES)
239300
@@ -265,18 +326,95 @@ function MTK.HomotopyContinuationProblem(
265326
# CSE/hashconsing.
266327
eqs = full_equations(sys)
267328

268-
denoms = []
269-
has_parametric_exponents = false
270-
eqs2 = map(eqs) do eq
329+
polydata = map(eqs) do eq
271330
data = PolynomialData()
272331
process_polynomial!(data, eq.lhs, dvs)
273332
process_polynomial!(data, eq.rhs, dvs)
274-
has_parametric_exponents |= data.has_parametric_exponent
275-
if !isempty(data.non_polynomial_terms)
276-
throw(NotPolynomialError(eq, data))
333+
data
334+
end
335+
336+
has_parametric_exponents = any(d -> d.has_parametric_exponent, polydata)
337+
338+
all_non_poly_terms = mapreduce(d -> d.non_polynomial_terms, vcat, polydata)
339+
unique!(all_non_poly_terms)
340+
341+
var_to_nonpoly = Dict{BasicSymbolic, PolynomialTransformationData}()
342+
343+
is_poly = true
344+
transformation_err = nothing
345+
for t in all_non_poly_terms
346+
# if the term involves multiple unknowns, we can't invert it
347+
dvs_in_term = map(x -> occursin(x, t), dvs)
348+
if count(dvs_in_term) > 1
349+
transformation_err = MultivarTerm(t, dvs[dvs_in_term])
350+
is_poly = false
351+
break
352+
end
353+
# we already have a substitution solving for `var`
354+
var = dvs[findfirst(dvs_in_term)]
355+
if haskey(var_to_nonpoly, var) && !isequal(var_to_nonpoly[var].term, t)
356+
transformation_err = MultipleTermsOfSameVar([t, var_to_nonpoly[var].term], var)
357+
is_poly = false
358+
break
359+
end
360+
# we want to solve `term - new_var` for `var`
361+
new_var = gensym(Symbol(var))
362+
new_var = unwrap(only(@variables $new_var))
363+
invterm = Symbolics.ia_solve(
364+
t - new_var, var; complex_roots = false, periodic_roots = false, warns = false)
365+
# if we can't invert it, quit
366+
if invterm === nothing || isempty(invterm)
367+
transformation_err = SymbolicSolveFailure(t, var)
368+
is_poly = false
369+
break
370+
end
371+
# `ia_solve` returns lazy terms i.e. `asin(1.0)` instead of `pi/2`
372+
# this just evaluates the constant expressions
373+
invterm = Symbolics.substitute.(invterm, (Dict(),))
374+
# RootsOf implies Symbolics couldn't solve the inner polynomial because
375+
# `Nemo` wasn't loaded.
376+
if any(x -> MTK.iscall(x) && MTK.operation(x) == Symbolics.RootsOf, invterm)
377+
transformation_err = NemoNotLoaded()
378+
is_poly = false
379+
break
277380
end
278-
num, den = handle_rational_polynomials(eq.rhs - eq.lhs, dvs)
381+
var_to_nonpoly[var] = PolynomialTransformationData(new_var, t, invterm)
382+
end
383+
384+
if !is_poly
385+
throw(NotPolynomialError(transformation_err, eqs, polydata))
386+
end
387+
388+
subrules = Dict()
389+
combinations = Vector[]
390+
new_dvs = []
391+
for x in dvs
392+
if haskey(var_to_nonpoly, x)
393+
_data = var_to_nonpoly[x]
394+
subrules[_data.term] = _data.new_var
395+
push!(combinations, _data.inv_term)
396+
push!(new_dvs, _data.new_var)
397+
else
398+
push!(combinations, [x])
399+
push!(new_dvs, x)
400+
end
401+
end
402+
all_solutions = collect.(collect(Iterators.product(combinations...)))
279403

404+
denoms = []
405+
eqs2 = map(eqs) do eq
406+
t = eq.rhs - eq.lhs
407+
t = Symbolics.fixpoint_sub(t, subrules; maxiters = length(dvs))
408+
# the substituted variable occurs outside the substituted term
409+
poly_and_nonpoly = map(dvs) do x
410+
haskey(var_to_nonpoly, x) && occursin(x, t)
411+
end
412+
if any(poly_and_nonpoly)
413+
throw(NotPolynomialError(
414+
VariablesAsPolyAndNonPoly(dvs[poly_and_nonpoly]), eqs, polydata))
415+
end
416+
417+
num, den = handle_rational_polynomials(t, new_dvs)
280418
# make factors different elements, otherwise the nonzero factors artificially
281419
# inflate the error of the zero factor.
282420
if iscall(den) && operation(den) == *
@@ -292,16 +430,19 @@ function MTK.HomotopyContinuationProblem(
292430
end
293431

294432
sys2 = MTK.@set sys.eqs = eqs2
433+
MTK.@set! sys2.unknowns = new_dvs
295434
# remove observed equations to avoid adding them in codegen
296435
MTK.@set! sys2.observed = Equation[]
297436
MTK.@set! sys2.substitutions = nothing
298437

299-
nlfn, u0, p = MTK.process_SciMLProblem(NonlinearFunction{true}, sys2, u0map, parammap;
300-
jac = true, eval_expression, eval_module)
438+
_, u0, p = MTK.process_SciMLProblem(
439+
MTK.EmptySciMLFunction, sys, u0map, parammap; eval_expression, eval_module)
440+
nlfn = NonlinearFunction{true}(sys2; jac = true, eval_expression, eval_module)
301441

302-
denominator = MTK.build_explicit_observed_function(sys, denoms)
442+
denominator = MTK.build_explicit_observed_function(sys2, denoms)
443+
unpack_solution = MTK.build_explicit_observed_function(sys2, all_solutions)
303444

304-
hvars = symbolics_to_hc.(dvs)
445+
hvars = symbolics_to_hc.(new_dvs)
305446
mtkhsys = MTKHomotopySystem(nlfn.f, p, nlfn.jac, hvars, length(eqs))
306447

307448
obsfn = MTK.ObservedFunctionCache(sys; eval_expression, eval_module)
@@ -319,7 +460,7 @@ function MTK.HomotopyContinuationProblem(
319460
solver_and_starts = HomotopyContinuation.solver_startsolutions(mtkhsys; kwargs...)
320461
end
321462
return MTK.HomotopyContinuationProblem(
322-
u0, mtkhsys, denominator, sys, obsfn, solver_and_starts)
463+
u0, mtkhsys, denominator, sys, obsfn, solver_and_starts, unpack_solution)
323464
end
324465

325466
"""
@@ -353,25 +494,35 @@ function CommonSolve.solve(prob::MTK.HomotopyContinuationProblem,
353494
if isempty(realsols)
354495
u = state_values(prob)
355496
retcode = SciMLBase.ReturnCode.ConvergenceFailure
497+
resid = prob.homotopy_continuation_system(u)
356498
else
357499
T = eltype(state_values(prob))
358-
distance, idx = findmin(realsols) do result
500+
distance = T(Inf)
501+
u = state_values(prob)
502+
resid = nothing
503+
for result in realsols
359504
if any(<=(denominator_abstol),
360505
prob.denominator(real.(result.solution), parameter_values(prob)))
361-
return T(Inf)
506+
continue
507+
end
508+
for truesol in prob.unpack_solution(result.solution, parameter_values(prob))
509+
dist = norm(truesol - state_values(prob))
510+
if dist < distance
511+
distance = dist
512+
u = T.(real.(truesol))
513+
resid = T.(real.(prob.homotopy_continuation_system(result.solution)))
514+
end
362515
end
363-
norm(result.solution - state_values(prob))
364516
end
365517
# all roots cause denominator to be zero
366518
if isinf(distance)
367519
u = state_values(prob)
520+
resid = prob.homotopy_continuation_system(u)
368521
retcode = SciMLBase.ReturnCode.Infeasible
369522
else
370-
u = real.(realsols[idx].solution)
371523
retcode = SciMLBase.ReturnCode.Success
372524
end
373525
end
374-
resid = prob.homotopy_continuation_system(u)
375526

376527
return SciMLBase.build_solution(
377528
prob, :HomotopyContinuation, u, resid; retcode, original = sol)

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -690,7 +690,7 @@ A type of Nonlinear problem which specializes on polynomial systems and uses
690690
HomotopyContinuation.jl to solve the system. Requires importing HomotopyContinuation.jl to
691691
create and solve.
692692
"""
693-
struct HomotopyContinuationProblem{uType, H, D, O, SS} <:
693+
struct HomotopyContinuationProblem{uType, H, D, O, SS, U} <:
694694
SciMLBase.AbstractNonlinearProblem{uType, true}
695695
"""
696696
The initial values of states in the system. If there are multiple real roots of
@@ -721,6 +721,12 @@ struct HomotopyContinuationProblem{uType, H, D, O, SS} <:
721721
`HomotopyContinuation.solver_startsystems`.
722722
"""
723723
solver_and_starts::SS
724+
"""
725+
A function which takes a solution of the transformed system, and returns a vector
726+
of solutions for the original system. This is utilized when converting systems
727+
to polynomials.
728+
"""
729+
unpack_solution::U
724730
end
725731

726732
function HomotopyContinuationProblem(::AbstractSystem, _u0, _p; kwargs...)

test/extensions/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
66
HomotopyContinuation = "f213a82b-91d6-5c5d-acf7-10f1c761b327"
77
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
88
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
9+
Nemo = "2edaba10-b0f1-5616-af89-8c11ac63239a"
910
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
1011
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
1112
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"

test/extensions/homotopy_continuation.jl

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,19 +75,59 @@ end
7575
@testset "Polynomial check and warnings" begin
7676
@variables x = 1.0
7777
@mtkbuild sys = NonlinearSystem([x^1.5 + x^2 - 1 ~ 0])
78-
@test_throws ["Exponent", "not an integer", "not a polynomial"] HomotopyContinuationProblem(
78+
@test_throws ["Cannot convert", "Unable", "symbolically solve",
79+
"Exponent", "not an integer", "not a polynomial"] HomotopyContinuationProblem(
7980
sys, [])
8081
@mtkbuild sys = NonlinearSystem([x^x - x ~ 0])
81-
@test_throws ["Exponent", "unknowns", "not a polynomial"] HomotopyContinuationProblem(
82+
@test_throws ["Cannot convert", "Unable", "symbolically solve",
83+
"Exponent", "unknowns", "not a polynomial"] HomotopyContinuationProblem(
8284
sys, [])
8385
@mtkbuild sys = NonlinearSystem([((x^2) / sin(x))^2 + x ~ 0])
84-
@test_throws ["recognized", "sin", "not a polynomial"] HomotopyContinuationProblem(
86+
@test_throws ["Cannot convert", "both polynomial", "non-polynomial",
87+
"recognized", "sin", "not a polynomial"] HomotopyContinuationProblem(
8588
sys, [])
8689

8790
@variables y = 2.0
8891
@mtkbuild sys = NonlinearSystem([x^2 + y^2 + 2 ~ 0, y ~ sin(x)])
89-
@test_throws ["recognized", "sin", "not a polynomial"] HomotopyContinuationProblem(
92+
@test_throws ["Cannot convert", "recognized", "sin", "not a polynomial"] HomotopyContinuationProblem(
9093
sys, [])
94+
95+
@mtkbuild sys = NonlinearSystem([x^2 + y^2 - 2 ~ 0, sin(x + y) ~ 0])
96+
@test_throws ["Cannot convert", "function of multiple unknowns"] HomotopyContinuationProblem(
97+
sys, [])
98+
99+
@mtkbuild sys = NonlinearSystem([sin(x)^2 + 1 ~ 0, cos(y) - cos(x) - 1 ~ 0])
100+
@test_throws ["Cannot convert", "multiple non-polynomial terms", "same unknown"] HomotopyContinuationProblem(
101+
sys, [])
102+
103+
@mtkbuild sys = NonlinearSystem([sin(x^2)^2 + sin(x^2) - 1 ~ 0])
104+
@test_throws ["import Nemo"] HomotopyContinuationProblem(sys, [])
105+
end
106+
107+
import Nemo
108+
109+
@testset "With Nemo" begin
110+
@variables x = 2.0
111+
@mtkbuild sys = NonlinearSystem([sin(x^2)^2 + sin(x^2) - 1 ~ 0])
112+
prob = HomotopyContinuationProblem(sys, [])
113+
@test prob[1] 2.0
114+
sol = solve(prob; threading = false)
115+
_x = sol[1]
116+
@test sin(_x^2)^2 + sin(_x^2) - 10.0 atol=1e-12
117+
end
118+
119+
@testset "Function of polynomial" begin
120+
@variables x=0.25 y=0.125
121+
a = sin(x^2 - 4x + 1)
122+
b = cos(3log(y) + 4)
123+
@mtkbuild sys = NonlinearSystem([(a^2 - 4a * b + 4b^2) / (a - 0.25) ~ 0
124+
(a^2 - 0.75a + 0.125) ~ 0])
125+
prob = HomotopyContinuationProblem(sys, [])
126+
@test prob[x] 0.25
127+
@test prob[y] 0.125
128+
sol = solve(prob; threading = false)
129+
@test sol[a]0.5 atol=1e-6
130+
@test sol[b]0.25 atol=1e-6
91131
end
92132

93133
@testset "Rational functions" begin

0 commit comments

Comments
 (0)