Skip to content

Commit 1211cb8

Browse files
refactor: use SciMLBase abstract types, return solution object
1 parent 52e3647 commit 1211cb8

File tree

2 files changed

+20
-10
lines changed

2 files changed

+20
-10
lines changed

ext/MTKHomotopyContinuationExt.jl

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module MTKHomotopyContinuationExt
22

33
using ModelingToolkit
4+
using ModelingToolkit.SciMLBase
45
using ModelingToolkit.Symbolics: unwrap
56
using ModelingToolkit.SymbolicIndexingInterface
67
using HomotopyContinuation
@@ -98,17 +99,26 @@ function ModelingToolkit.HomotopyContinuationProblem(
9899

99100
mtkhsys = MTKHomotopySystem(nlfn.f, p, nlfn.jac, hvars, length(eqs))
100101

101-
prob = ModelingToolkit.HomotopyContinuationProblem{typeof(mtkhsys), typeof(u0)}(
102-
sys, mtkhsys, u0)
102+
return ModelingToolkit.HomotopyContinuationProblem(u0, mtkhsys, sys)
103103
end
104104

105105
function CommonSolve.solve(prob::ModelingToolkit.HomotopyContinuationProblem; kwargs...)
106-
sol = HomotopyContinuation.solve(prob.hcsys; kwargs...)
107-
rsols = HomotopyContinuation.real_solutions(sol)
108-
rsol = findmin(rsols) do val
109-
norm(prob.u0 - val)
106+
sol = HomotopyContinuation.solve(prob.homotopy_continuation_system; kwargs...)
107+
realsols = HomotopyContinuation.results(sol; only_real = true)
108+
if isempty(realsols)
109+
u = state_values(prob)
110+
resid = prob.homotopy_continuation_system(u)
111+
retcode = SciMLBase.ReturnCode.ConvergenceFailure
112+
else
113+
distance, idx = findmin(realsols) do result
114+
norm(result.solution - state_values(prob))
115+
end
116+
u = real.(realsols[idx].solution)
117+
resid = prob.homotopy_continuation_system(u)
118+
retcode = SciMLBase.ReturnCode.Success
110119
end
111-
return rsol
120+
121+
return SciMLBase.build_solution(prob, :HomotopyContinuation, u, resid; retcode)
112122
end
113123

114124
end

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -599,10 +599,10 @@ function Base.:(==)(sys1::NonlinearSystem, sys2::NonlinearSystem)
599599
all(s1 == s2 for (s1, s2) in zip(get_systems(sys1), get_systems(sys2)))
600600
end
601601

602-
struct HomotopyContinuationProblem{H, U}
602+
struct HomotopyContinuationProblem{uType, H} <: SciMLBase.AbstractNonlinearProblem{uType, true}
603+
u0::uType
604+
homotopy_continuation_system::H
603605
sys::NonlinearSystem
604-
hcsys::H
605-
u0::U
606606
end
607607

608608
function HomotopyContinuationProblem(args...; kwargs...)

0 commit comments

Comments
 (0)