Skip to content

Commit 2770a8f

Browse files
Merge pull request #3114 from AayushSabharwal/as/homotopy-cont
feat: initial implementation of HomotopyContinuation interface
2 parents 1f2d943 + ecf01b3 commit 2770a8f

File tree

8 files changed

+347
-4
lines changed

8 files changed

+347
-4
lines changed

Project.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
88
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
99
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
1010
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
11+
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
1112
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
1213
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
1314
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
@@ -61,12 +62,14 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
6162
BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665"
6263
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
6364
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
65+
HomotopyContinuation = "f213a82b-91d6-5c5d-acf7-10f1c761b327"
6466
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
6567

6668
[extensions]
6769
MTKBifurcationKitExt = "BifurcationKit"
6870
MTKChainRulesCoreExt = "ChainRulesCore"
6971
MTKDeepDiffsExt = "DeepDiffs"
72+
MTKHomotopyContinuationExt = "HomotopyContinuation"
7073
MTKLabelledArraysExt = "LabelledArrays"
7174

7275
[compat]
@@ -76,6 +79,7 @@ BifurcationKit = "0.3"
7679
BlockArrays = "1.1"
7780
ChainRulesCore = "1"
7881
Combinatorics = "1"
82+
CommonSolve = "0.2.4"
7983
Compat = "3.42, 4"
8084
ConstructionBase = "1"
8185
DataInterpolations = "6.4"
@@ -97,6 +101,7 @@ ForwardDiff = "0.10.3"
97101
FunctionWrappers = "1.1"
98102
FunctionWrappersWrappers = "0.1"
99103
Graphs = "1.5.2"
104+
HomotopyContinuation = "2.11"
100105
InteractiveUtils = "1"
101106
JuliaFormatter = "1.0.47"
102107
JumpProcesses = "9.13.1"

ext/MTKHomotopyContinuationExt.jl

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
module MTKHomotopyContinuationExt
2+
3+
using ModelingToolkit
4+
using ModelingToolkit.SciMLBase
5+
using ModelingToolkit.Symbolics: unwrap, symtype
6+
using ModelingToolkit.SymbolicIndexingInterface
7+
using ModelingToolkit.DocStringExtensions
8+
using HomotopyContinuation
9+
using ModelingToolkit: iscomplete, parameters, has_index_cache, get_index_cache, get_u0,
10+
get_u0_p, check_eqs_u0, CommonSolve
11+
12+
const MTK = ModelingToolkit
13+
14+
function contains_variable(x, wrt)
15+
any(y -> occursin(y, x), wrt)
16+
end
17+
18+
"""
19+
$(TYPEDSIGNATURES)
20+
21+
Check if `x` is polynomial with respect to the variables in `wrt`.
22+
"""
23+
function is_polynomial(x, wrt)
24+
x = unwrap(x)
25+
symbolic_type(x) == NotSymbolic() && return true
26+
iscall(x) || return true
27+
contains_variable(x, wrt) || return true
28+
any(isequal(x), wrt) && return true
29+
30+
if operation(x) in (*, +, -)
31+
return all(y -> is_polynomial(y, wrt), arguments(x))
32+
end
33+
if operation(x) == (^)
34+
b, p = arguments(x)
35+
is_pow_integer = symtype(p) <: Integer
36+
if !is_pow_integer
37+
if symbolic_type(p) == NotSymbolic()
38+
@warn "In $x: Exponent $p is not an integer"
39+
else
40+
@warn "In $x: Exponent $p is not an integer. Use `@parameters p::Integer` to declare integer parameters."
41+
end
42+
end
43+
exponent_has_unknowns = contains_variable(p, wrt)
44+
if exponent_has_unknowns
45+
@warn "In $x: Exponent $p cannot contain unknowns of the system."
46+
end
47+
base_polynomial = is_polynomial(b, wrt)
48+
if !base_polynomial
49+
@warn "In $x: Base is not a polynomial"
50+
end
51+
return base_polynomial && !exponent_has_unknowns && is_pow_integer
52+
end
53+
@warn "In $x: Unrecognized operation $(operation(x)). Allowed polynomial operations are `*, +, -, ^`"
54+
return false
55+
end
56+
57+
"""
58+
$(TYPEDSIGNATURES)
59+
60+
Convert `expr` from a symbolics expression to one that uses `HomotopyContinuation.ModelKit`.
61+
"""
62+
function symbolics_to_hc(expr)
63+
if iscall(expr)
64+
if operation(expr) == getindex
65+
args = arguments(expr)
66+
return ModelKit.Variable(getname(args[1]), args[2:end]...)
67+
else
68+
return operation(expr)(symbolics_to_hc.(arguments(expr))...)
69+
end
70+
elseif symbolic_type(expr) == NotSymbolic()
71+
return expr
72+
else
73+
return ModelKit.Variable(getname(expr))
74+
end
75+
end
76+
77+
"""
78+
$(TYPEDEF)
79+
80+
A subtype of `HomotopyContinuation.AbstractSystem` used to solve `HomotopyContinuationProblem`s.
81+
"""
82+
struct MTKHomotopySystem{F, P, J, V} <: HomotopyContinuation.AbstractSystem
83+
"""
84+
The generated function for the residual of the polynomial system. In-place.
85+
"""
86+
f::F
87+
"""
88+
The parameter object.
89+
"""
90+
p::P
91+
"""
92+
The generated function for the jacobian of the polynomial system. In-place.
93+
"""
94+
jac::J
95+
"""
96+
The `HomotopyContinuation.ModelKit.Variable` representation of the unknowns of
97+
the system.
98+
"""
99+
vars::V
100+
"""
101+
The number of polynomials in the system. Must also be equal to `length(vars)`.
102+
"""
103+
nexprs::Int
104+
end
105+
106+
Base.size(sys::MTKHomotopySystem) = (sys.nexprs, length(sys.vars))
107+
ModelKit.variables(sys::MTKHomotopySystem) = sys.vars
108+
109+
function (sys::MTKHomotopySystem)(x, p = nothing)
110+
sys.f(x, sys.p)
111+
end
112+
113+
function ModelKit.evaluate!(u, sys::MTKHomotopySystem, x, p = nothing)
114+
sys.f(u, x, sys.p)
115+
end
116+
117+
function ModelKit.evaluate_and_jacobian!(u, U, sys::MTKHomotopySystem, x, p = nothing)
118+
sys.f(u, x, sys.p)
119+
sys.jac(U, x, sys.p)
120+
end
121+
122+
SymbolicIndexingInterface.parameter_values(s::MTKHomotopySystem) = s.p
123+
124+
"""
125+
$(TYPEDSIGNATURES)
126+
127+
Create a `HomotopyContinuationProblem` from a `NonlinearSystem` with polynomial equations.
128+
The problem will be solved by HomotopyContinuation.jl. The resultant `NonlinearSolution`
129+
will contain the polynomial root closest to the point specified by `u0map` (if real roots
130+
exist for the system).
131+
"""
132+
function MTK.HomotopyContinuationProblem(
133+
sys::NonlinearSystem, u0map, parammap = nothing; eval_expression = false,
134+
eval_module = ModelingToolkit, kwargs...)
135+
if !iscomplete(sys)
136+
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `HomotopyContinuationProblem`")
137+
end
138+
139+
dvs = unknowns(sys)
140+
eqs = equations(sys)
141+
142+
for eq in eqs
143+
if !is_polynomial(eq.lhs, dvs) || !is_polynomial(eq.rhs, dvs)
144+
error("Equation $eq is not a polynomial in the unknowns. See warnings for further details.")
145+
end
146+
end
147+
148+
nlfn, u0, p = MTK.process_SciMLProblem(NonlinearFunction{true}, sys, u0map, parammap;
149+
jac = true, eval_expression, eval_module)
150+
151+
hvars = symbolics_to_hc.(dvs)
152+
mtkhsys = MTKHomotopySystem(nlfn.f, p, nlfn.jac, hvars, length(eqs))
153+
154+
obsfn = MTK.ObservedFunctionCache(sys; eval_expression, eval_module)
155+
156+
return MTK.HomotopyContinuationProblem(u0, mtkhsys, sys, obsfn)
157+
end
158+
159+
"""
160+
$(TYPEDSIGNATURES)
161+
162+
Solve a `HomotopyContinuationProblem`. Ignores the algorithm passed to it, and always
163+
uses `HomotopyContinuation.jl`. All keyword arguments are forwarded to
164+
`HomotopyContinuation.solve`. The original solution as returned by `HomotopyContinuation.jl`
165+
will be available in the `.original` field of the returned `NonlinearSolution`.
166+
167+
All keyword arguments have their default values in HomotopyContinuation.jl, except
168+
`show_progress` which defaults to `false`.
169+
"""
170+
function CommonSolve.solve(prob::MTK.HomotopyContinuationProblem,
171+
alg = nothing; show_progress = false, kwargs...)
172+
sol = HomotopyContinuation.solve(
173+
prob.homotopy_continuation_system; show_progress, kwargs...)
174+
realsols = HomotopyContinuation.results(sol; only_real = true)
175+
if isempty(realsols)
176+
u = state_values(prob)
177+
resid = prob.homotopy_continuation_system(u)
178+
retcode = SciMLBase.ReturnCode.ConvergenceFailure
179+
else
180+
distance, idx = findmin(realsols) do result
181+
norm(result.solution - state_values(prob))
182+
end
183+
u = real.(realsols[idx].solution)
184+
resid = prob.homotopy_continuation_system(u)
185+
retcode = SciMLBase.ReturnCode.Success
186+
end
187+
188+
return SciMLBase.build_solution(
189+
prob, :HomotopyContinuation, u, resid; retcode, original = sol)
190+
end
191+
192+
end

src/ModelingToolkit.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ using Reexport
5454
using RecursiveArrayTools
5555
import Graphs: SimpleDiGraph, add_edge!, incidence_matrix
5656
import BlockArrays: BlockedArray, Block, blocksize, blocksizes
57+
import CommonSolve
5758

5859
using RuntimeGeneratedFunctions
5960
using RuntimeGeneratedFunctions: drop_expr
@@ -281,4 +282,6 @@ export Clock, SolverStepClock, TimeDomain
281282

282283
export MTKParameters, reorder_dimension_by_tunables!, reorder_dimension_by_tunables
283284

285+
export HomotopyContinuationProblem
286+
284287
end # module

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -565,3 +565,55 @@ function Base.:(==)(sys1::NonlinearSystem, sys2::NonlinearSystem)
565565
_eq_unordered(get_ps(sys1), get_ps(sys2)) &&
566566
all(s1 == s2 for (s1, s2) in zip(get_systems(sys1), get_systems(sys2)))
567567
end
568+
569+
"""
570+
$(TYPEDEF)
571+
572+
A type of Nonlinear problem which specializes on polynomial systems and uses
573+
HomotopyContinuation.jl to solve the system. Requires importing HomotopyContinuation.jl to
574+
create and solve.
575+
"""
576+
struct HomotopyContinuationProblem{uType, H, O} <:
577+
SciMLBase.AbstractNonlinearProblem{uType, true}
578+
"""
579+
The initial values of states in the system. If there are multiple real roots of
580+
the system, the one closest to this point is returned.
581+
"""
582+
u0::uType
583+
"""
584+
A subtype of `HomotopyContinuation.AbstractSystem` to solve. Also contains the
585+
parameter object.
586+
"""
587+
homotopy_continuation_system::H
588+
"""
589+
The `NonlinearSystem` used to create this problem. Used for symbolic indexing.
590+
"""
591+
sys::NonlinearSystem
592+
"""
593+
A function which generates and returns observed expressions for the given system.
594+
"""
595+
obsfn::O
596+
end
597+
598+
function HomotopyContinuationProblem(::AbstractSystem, _u0, _p; kwargs...)
599+
error("HomotopyContinuation.jl is required to create and solve `HomotopyContinuationProblem`s. Please run `Pkg.add(\"HomotopyContinuation\")` to continue.")
600+
end
601+
602+
SymbolicIndexingInterface.symbolic_container(p::HomotopyContinuationProblem) = p.sys
603+
SymbolicIndexingInterface.state_values(p::HomotopyContinuationProblem) = p.u0
604+
function SymbolicIndexingInterface.set_state!(p::HomotopyContinuationProblem, args...)
605+
set_state!(p.u0, args...)
606+
end
607+
function SymbolicIndexingInterface.parameter_values(p::HomotopyContinuationProblem)
608+
parameter_values(p.homotopy_continuation_system)
609+
end
610+
function SymbolicIndexingInterface.set_parameter!(p::HomotopyContinuationProblem, args...)
611+
set_parameter!(parameter_values(p), args...)
612+
end
613+
function SymbolicIndexingInterface.observed(p::HomotopyContinuationProblem, sym)
614+
if p.obsfn !== nothing
615+
return p.obsfn(sym)
616+
else
617+
return SymbolicIndexingInterface.observed(p.sys, sym)
618+
end
619+
end

test/downstream/linearize.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,13 +121,13 @@ lsys = ModelingToolkit.reorder_unknowns(lsys0, unknowns(ssys), desired_order)
121121
lsyss, _ = ModelingToolkit.linearize_symbolic(pid, [reference.u, measurement.u],
122122
[ctr_output.u])
123123

124-
@test substitute(
124+
@test ModelingToolkit.fixpoint_sub(
125125
lsyss.A, ModelingToolkit.defaults_and_guesses(pid)) == lsys.A
126-
@test substitute(
126+
@test ModelingToolkit.fixpoint_sub(
127127
lsyss.B, ModelingToolkit.defaults_and_guesses(pid)) == lsys.B
128-
@test substitute(
128+
@test ModelingToolkit.fixpoint_sub(
129129
lsyss.C, ModelingToolkit.defaults_and_guesses(pid)) == lsys.C
130-
@test substitute(
130+
@test ModelingToolkit.fixpoint_sub(
131131
lsyss.D, ModelingToolkit.defaults_and_guesses(pid)) == lsys.D
132132

133133
# Test with the reverse desired unknown order as well to verify that similarity transform and reoreder_unknowns really works

test/extensions/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@ BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665"
33
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
44
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
55
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
6+
HomotopyContinuation = "f213a82b-91d6-5c5d-acf7-10f1c761b327"
67
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
78
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
89
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
910
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
1011
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
1112
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
13+
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
1214
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

0 commit comments

Comments
 (0)