Skip to content

Commit 05e89ab

Browse files
feat: initial implementation of HomotopyContinuation interface
1 parent 41de08f commit 05e89ab

File tree

4 files changed

+161
-0
lines changed

4 files changed

+161
-0
lines changed

Project.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
88
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
99
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
1010
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
11+
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
1112
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
1213
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
1314
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
@@ -61,12 +62,14 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
6162
BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665"
6263
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
6364
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
65+
HomotopyContinuation = "f213a82b-91d6-5c5d-acf7-10f1c761b327"
6466
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
6567

6668
[extensions]
6769
MTKBifurcationKitExt = "BifurcationKit"
6870
MTKChainRulesCoreExt = "ChainRulesCore"
6971
MTKDeepDiffsExt = "DeepDiffs"
72+
MTKHomotopyContinuationExt = "HomotopyContinuation"
7073
MTKLabelledArraysExt = "LabelledArrays"
7174

7275
[compat]
@@ -76,6 +79,7 @@ BifurcationKit = "0.3"
7679
BlockArrays = "1.1"
7780
ChainRulesCore = "1"
7881
Combinatorics = "1"
82+
CommonSolve = "0.2.4"
7983
Compat = "3.42, 4"
8084
ConstructionBase = "1"
8185
DataInterpolations = "6.4"
@@ -97,6 +101,7 @@ ForwardDiff = "0.10.3"
97101
FunctionWrappers = "1.1"
98102
FunctionWrappersWrappers = "0.1"
99103
Graphs = "1.5.2"
104+
HomotopyContinuation = "2.11"
100105
InteractiveUtils = "1"
101106
JuliaFormatter = "1.0.47"
102107
JumpProcesses = "9.13.1"

ext/MTKHomotopyContinuationExt.jl

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
module MTKHomotopyContinuationExt
2+
3+
using ModelingToolkit
4+
using ModelingToolkit.SciMLBase
5+
using ModelingToolkit.Symbolics: unwrap
6+
using ModelingToolkit.SymbolicIndexingInterface
7+
using HomotopyContinuation
8+
using ModelingToolkit: iscomplete, parameters, has_index_cache, get_index_cache, get_u0,
9+
get_u0_p, check_eqs_u0, CommonSolve
10+
11+
const MTK = ModelingToolkit
12+
13+
function contains_variable(x, wrt)
14+
any(isequal(x), wrt) && return true
15+
iscall(x) || return false
16+
return any(y -> contains_variable(y, wrt), arguments(x))
17+
end
18+
19+
function is_polynomial(x, wrt)
20+
x = unwrap(x)
21+
symbolic_type(x) == NotSymbolic() && return true
22+
iscall(x) || return true
23+
contains_variable(x, wrt) || return true
24+
any(isequal(x), wrt) && return true
25+
26+
if operation(x) in (*, +, -)
27+
return all(y -> is_polynomial(y, wrt), arguments(x))
28+
end
29+
if operation(x) == (^)
30+
b, p = arguments(x)
31+
return is_polynomial(b, wrt) && !contains_variable(p, wrt)
32+
end
33+
return false
34+
end
35+
36+
function symbolics_to_hc(expr)
37+
if iscall(expr)
38+
if operation(expr) == getindex
39+
args = arguments(expr)
40+
return ModelKit.Variable(getname(args[1]), args[2:end]...)
41+
else
42+
return operation(expr)(symbolics_to_hc.(arguments(expr))...)
43+
end
44+
elseif symbolic_type(expr) == NotSymbolic()
45+
return expr
46+
else
47+
return ModelKit.Variable(getname(expr))
48+
end
49+
end
50+
51+
struct MTKHomotopySystem{F, P, J, V} <: HomotopyContinuation.AbstractSystem
52+
f::F
53+
p::P
54+
jac::J
55+
vars::V
56+
nexprs::Int
57+
end
58+
59+
Base.size(sys::MTKHomotopySystem) = (sys.nexprs, length(sys.vars))
60+
ModelKit.variables(sys::MTKHomotopySystem) = sys.vars
61+
62+
function (sys::MTKHomotopySystem)(x, p = nothing)
63+
sys.f(x, sys.p)
64+
end
65+
66+
function ModelKit.evaluate!(u, sys::MTKHomotopySystem, x, p = nothing)
67+
sys.f(u, x, sys.p)
68+
end
69+
70+
function ModelKit.evaluate_and_jacobian!(u, U, sys::MTKHomotopySystem, x, p = nothing)
71+
sys.f(u, x, sys.p)
72+
sys.jac(U, x, sys.p)
73+
end
74+
75+
SymbolicIndexingInterface.parameter_values(s::MTKHomotopySystem) = s.p
76+
77+
function MTK.HomotopyContinuationProblem(
78+
sys::NonlinearSystem, u0map, parammap; compile = :all, eval_expression = false, eval_module = ModelingToolkit, kwargs...)
79+
if !iscomplete(sys)
80+
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `HomotopyContinuationProblem`")
81+
end
82+
83+
dvs = unknowns(sys)
84+
eqs = equations(sys)
85+
86+
for eq in eqs
87+
if !is_polynomial(eq.lhs, dvs) || !is_polynomial(eq.rhs, dvs)
88+
error("Equation $eq is not a polynomial in the unknowns")
89+
end
90+
end
91+
92+
nlfn, u0, p = MTK.process_SciMLProblem(NonlinearFunction{true}, sys, u0map, parammap;
93+
jac = true, eval_expression, eval_module)
94+
95+
hvars = symbolics_to_hc.(dvs)
96+
mtkhsys = MTKHomotopySystem(nlfn.f, p, nlfn.jac, hvars, length(eqs))
97+
98+
obsfn = MTK.ObservedFunctionCache(sys; eval_expression, eval_module)
99+
100+
return MTK.HomotopyContinuationProblem(u0, mtkhsys, sys, obsfn)
101+
end
102+
103+
function CommonSolve.solve(prob::MTK.HomotopyContinuationProblem; kwargs...)
104+
sol = HomotopyContinuation.solve(prob.homotopy_continuation_system; kwargs...)
105+
realsols = HomotopyContinuation.results(sol; only_real = true)
106+
if isempty(realsols)
107+
u = state_values(prob)
108+
resid = prob.homotopy_continuation_system(u)
109+
retcode = SciMLBase.ReturnCode.ConvergenceFailure
110+
else
111+
distance, idx = findmin(realsols) do result
112+
norm(result.solution - state_values(prob))
113+
end
114+
u = real.(realsols[idx].solution)
115+
resid = prob.homotopy_continuation_system(u)
116+
retcode = SciMLBase.ReturnCode.Success
117+
end
118+
119+
return SciMLBase.build_solution(
120+
prob, :HomotopyContinuation, u, resid; retcode, original = sol)
121+
end
122+
123+
end

src/ModelingToolkit.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ using Reexport
5454
using RecursiveArrayTools
5555
import Graphs: SimpleDiGraph, add_edge!, incidence_matrix
5656
import BlockArrays: BlockedArray, Block, blocksize, blocksizes
57+
import CommonSolve
5758

5859
using RuntimeGeneratedFunctions
5960
using RuntimeGeneratedFunctions: drop_expr
@@ -281,4 +282,6 @@ export Clock, SolverStepClock, TimeDomain
281282

282283
export MTKParameters, reorder_dimension_by_tunables!, reorder_dimension_by_tunables
283284

285+
export HomotopyContinuationProblem
286+
284287
end # module

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -565,3 +565,33 @@ function Base.:(==)(sys1::NonlinearSystem, sys2::NonlinearSystem)
565565
_eq_unordered(get_ps(sys1), get_ps(sys2)) &&
566566
all(s1 == s2 for (s1, s2) in zip(get_systems(sys1), get_systems(sys2)))
567567
end
568+
569+
struct HomotopyContinuationProblem{uType, H, O} <: SciMLBase.AbstractNonlinearProblem{uType, true}
570+
u0::uType
571+
homotopy_continuation_system::H
572+
sys::NonlinearSystem
573+
obsfn::O
574+
end
575+
576+
function HomotopyContinuationProblem(args...; kwargs...)
577+
error("Requires HomotopyContinuationExt")
578+
end
579+
580+
SymbolicIndexingInterface.symbolic_container(p::HomotopyContinuationProblem) = p.sys
581+
SymbolicIndexingInterface.state_values(p::HomotopyContinuationProblem) = p.u0
582+
function SymbolicIndexingInterface.set_state!(p::HomotopyContinuationProblem, args...)
583+
set_state!(p.u0, args...)
584+
end
585+
function SymbolicIndexingInterface.parameter_values(p::HomotopyContinuationProblem)
586+
parameter_values(p.homotopy_continuation_system)
587+
end
588+
function SymbolicIndexingInterface.set_parameter!(p::HomotopyContinuationProblem, args...)
589+
set_parameter!(parameter_values(p), args...)
590+
end
591+
function SymbolicIndexingInterface.observed(p::HomotopyContinuationProblem, sym)
592+
if p.obsfn !== nothing
593+
return p.obsfn(sym)
594+
else
595+
return SymbolicIndexingInterface.observed(p.sys, sym)
596+
end
597+
end

0 commit comments

Comments
 (0)