Skip to content

Commit 52e3647

Browse files
feat: initial implementation of HomotopyContinuation interface
1 parent 816fde7 commit 52e3647

File tree

4 files changed

+142
-0
lines changed

4 files changed

+142
-0
lines changed

Project.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
88
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
99
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
1010
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
11+
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
1112
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
1213
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
1314
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
@@ -61,12 +62,14 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
6162
BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665"
6263
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
6364
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
65+
HomotopyContinuation = "f213a82b-91d6-5c5d-acf7-10f1c761b327"
6466
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
6567

6668
[extensions]
6769
MTKBifurcationKitExt = "BifurcationKit"
6870
MTKChainRulesCoreExt = "ChainRulesCore"
6971
MTKDeepDiffsExt = "DeepDiffs"
72+
MTKHomotopyContinuationExt = "HomotopyContinuation"
7073
MTKLabelledArraysExt = "LabelledArrays"
7174

7275
[compat]
@@ -76,6 +79,7 @@ BifurcationKit = "0.3"
7679
BlockArrays = "1.1"
7780
ChainRulesCore = "1"
7881
Combinatorics = "1"
82+
CommonSolve = "0.2.4"
7983
Compat = "3.42, 4"
8084
ConstructionBase = "1"
8185
DataInterpolations = "6.4"
@@ -97,6 +101,7 @@ ForwardDiff = "0.10.3"
97101
FunctionWrappers = "1.1"
98102
FunctionWrappersWrappers = "0.1"
99103
Graphs = "1.5.2"
104+
HomotopyContinuation = "2.11"
100105
InteractiveUtils = "1"
101106
JuliaFormatter = "1.0.47"
102107
JumpProcesses = "9.13.1"

ext/MTKHomotopyContinuationExt.jl

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
module MTKHomotopyContinuationExt
2+
3+
using ModelingToolkit
4+
using ModelingToolkit.Symbolics: unwrap
5+
using ModelingToolkit.SymbolicIndexingInterface
6+
using HomotopyContinuation
7+
using ModelingToolkit: iscomplete, parameters, has_index_cache, get_index_cache, get_u0,
8+
get_u0_p, check_eqs_u0, CommonSolve
9+
10+
function contains_variable(x, wrt)
11+
any(isequal(x), wrt) && return true
12+
istree(x) || return false
13+
return any(y -> contains_variable(y, wrt), arguments(x))
14+
end
15+
16+
function is_polynomial(x, wrt)
17+
x = unwrap(x)
18+
symbolic_type(x) == NotSymbolic() && return true
19+
istree(x) || return true
20+
contains_variable(x, wrt) || return true
21+
22+
if operation(x) in (*, +, -)
23+
return all(y -> is_polynomial(y, wrt), arguments(x))
24+
end
25+
if operation(x) == (^)
26+
b, p = arguments(x)
27+
return is_polynomial(b, wrt) && !contains_variable(p, wrt)
28+
end
29+
return false
30+
end
31+
32+
function symbolics_to_hc(expr)
33+
if iscall(expr)
34+
if operation(expr) == getindex
35+
args = arguments(expr)
36+
return ModelKit.Variable(getname(args[1]), args[2:end]...)
37+
else
38+
return operation(expr)(symbolics_to_hc.(arguments(expr))...)
39+
end
40+
elseif symbolic_type(expr) == NotSymbolic()
41+
return expr
42+
else
43+
return ModelKit.Variable(getname(expr))
44+
end
45+
end
46+
47+
struct MTKHomotopySystem{F, P, J, V} <: HomotopyContinuation.AbstractSystem
48+
f::F
49+
p::P
50+
jac::J
51+
vars::V
52+
nexprs::Int
53+
end
54+
55+
Base.size(sys::MTKHomotopySystem) = (sys.nexprs, length(sys.vars))
56+
ModelKit.variables(sys::MTKHomotopySystem) = sys.vars
57+
58+
function (sys::MTKHomotopySystem)(x, p = nothing)
59+
sys.f(x, sys.p)
60+
end
61+
62+
function ModelKit.evaluate!(u, sys::MTKHomotopySystem, x, p = nothing)
63+
sys.f(u, x, sys.p)
64+
end
65+
66+
function ModelKit.evaluate_and_jacobian!(u, U, sys::MTKHomotopySystem, x, p = nothing)
67+
sys.f(u, x, sys.p)
68+
sys.jac(U, x, sys.p)
69+
end
70+
71+
function ModelingToolkit.HomotopyContinuationProblem(
72+
sys::NonlinearSystem, u0map, parammap; compile = :all, kwargs...)
73+
if !iscomplete(sys)
74+
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `HomotopyContinuationProblem`")
75+
end
76+
77+
dvs = unknowns(sys)
78+
ps = parameters(sys)
79+
eqs = equations(sys)
80+
81+
for eq in eqs
82+
if !is_polynomial(eq.lhs, dvs) || !is_polynomial(eq.rhs, dvs)
83+
error("Equation $eq is not a polynomial in the unknowns")
84+
end
85+
end
86+
87+
nlfn = NonlinearFunction(sys; jac = true)
88+
hvars = symbolics_to_hc.(dvs)
89+
90+
if has_index_cache(sys) && get_index_cache(sys) !== nothing
91+
u0, defs = get_u0(sys, u0map, parammap)
92+
check_eqs_u0(eqs, dvs, u0; kwargs...)
93+
p = MTKParameters(sys, parammap, u0map)
94+
else
95+
u0, p, defs = get_u0_p(sys, u0map, parammap; tofloat, use_union)
96+
check_eqs_u0(eqs, dvs, u0; kwargs...)
97+
end
98+
99+
mtkhsys = MTKHomotopySystem(nlfn.f, p, nlfn.jac, hvars, length(eqs))
100+
101+
prob = ModelingToolkit.HomotopyContinuationProblem{typeof(mtkhsys), typeof(u0)}(
102+
sys, mtkhsys, u0)
103+
end
104+
105+
function CommonSolve.solve(prob::ModelingToolkit.HomotopyContinuationProblem; kwargs...)
106+
sol = HomotopyContinuation.solve(prob.hcsys; kwargs...)
107+
rsols = HomotopyContinuation.real_solutions(sol)
108+
rsol = findmin(rsols) do val
109+
norm(prob.u0 - val)
110+
end
111+
return rsol
112+
end
113+
114+
end

src/ModelingToolkit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ using Reexport
5454
using RecursiveArrayTools
5555
import Graphs: SimpleDiGraph, add_edge!, incidence_matrix
5656
import BlockArrays: BlockedArray, Block, blocksize, blocksizes
57+
import CommonSolve
5758

5859
using RuntimeGeneratedFunctions
5960
using RuntimeGeneratedFunctions: drop_expr

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -598,3 +598,25 @@ function Base.:(==)(sys1::NonlinearSystem, sys2::NonlinearSystem)
598598
_eq_unordered(get_ps(sys1), get_ps(sys2)) &&
599599
all(s1 == s2 for (s1, s2) in zip(get_systems(sys1), get_systems(sys2)))
600600
end
601+
602+
struct HomotopyContinuationProblem{H, U}
603+
sys::NonlinearSystem
604+
hcsys::H
605+
u0::U
606+
end
607+
608+
function HomotopyContinuationProblem(args...; kwargs...)
609+
error("Requires HomotopyContinuationExt")
610+
end
611+
612+
SymbolicIndexingInterface.symbolic_container(p::HomotopyContinuationProblem) = p.sys
613+
SymbolicIndexingInterface.state_values(p::HomotopyContinuationProblem) = p.u0
614+
function SymbolicIndexingInterface.set_state!(p::HomotopyContinuationProblem, args...)
615+
set_state!(p.u0, args...)
616+
end
617+
function SymbolicIndexingInterface.parameter_values(p::HomotopyContinuationProblem)
618+
parameter_values(p.hcsys)
619+
end
620+
function SymbolicIndexingInterface.set_parameter!(p::HomotopyContinuationProblem, args...)
621+
set_parameter!(p.hcsys, args...)
622+
end

0 commit comments

Comments
 (0)