|
| 1 | +module MTKHomotopyContinuationExt |
| 2 | + |
| 3 | +using ModelingToolkit |
| 4 | +using ModelingToolkit.SciMLBase |
| 5 | +using ModelingToolkit.Symbolics: unwrap |
| 6 | +using ModelingToolkit.SymbolicIndexingInterface |
| 7 | +using HomotopyContinuation |
| 8 | +using ModelingToolkit: iscomplete, parameters, has_index_cache, get_index_cache, get_u0, |
| 9 | + get_u0_p, check_eqs_u0, CommonSolve |
| 10 | + |
| 11 | +const MTK = ModelingToolkit |
| 12 | + |
| 13 | +function contains_variable(x, wrt) |
| 14 | + any(isequal(x), wrt) && return true |
| 15 | + iscall(x) || return false |
| 16 | + return any(y -> contains_variable(y, wrt), arguments(x)) |
| 17 | +end |
| 18 | + |
| 19 | +function is_polynomial(x, wrt) |
| 20 | + x = unwrap(x) |
| 21 | + symbolic_type(x) == NotSymbolic() && return true |
| 22 | + iscall(x) || return true |
| 23 | + contains_variable(x, wrt) || return true |
| 24 | + any(isequal(x), wrt) && return true |
| 25 | + |
| 26 | + if operation(x) in (*, +, -) |
| 27 | + return all(y -> is_polynomial(y, wrt), arguments(x)) |
| 28 | + end |
| 29 | + if operation(x) == (^) |
| 30 | + b, p = arguments(x) |
| 31 | + return is_polynomial(b, wrt) && !contains_variable(p, wrt) |
| 32 | + end |
| 33 | + return false |
| 34 | +end |
| 35 | + |
| 36 | +function symbolics_to_hc(expr) |
| 37 | + if iscall(expr) |
| 38 | + if operation(expr) == getindex |
| 39 | + args = arguments(expr) |
| 40 | + return ModelKit.Variable(getname(args[1]), args[2:end]...) |
| 41 | + else |
| 42 | + return operation(expr)(symbolics_to_hc.(arguments(expr))...) |
| 43 | + end |
| 44 | + elseif symbolic_type(expr) == NotSymbolic() |
| 45 | + return expr |
| 46 | + else |
| 47 | + return ModelKit.Variable(getname(expr)) |
| 48 | + end |
| 49 | +end |
| 50 | + |
| 51 | +struct MTKHomotopySystem{F, P, J, V} <: HomotopyContinuation.AbstractSystem |
| 52 | + f::F |
| 53 | + p::P |
| 54 | + jac::J |
| 55 | + vars::V |
| 56 | + nexprs::Int |
| 57 | +end |
| 58 | + |
| 59 | +Base.size(sys::MTKHomotopySystem) = (sys.nexprs, length(sys.vars)) |
| 60 | +ModelKit.variables(sys::MTKHomotopySystem) = sys.vars |
| 61 | + |
| 62 | +function (sys::MTKHomotopySystem)(x, p = nothing) |
| 63 | + sys.f(x, sys.p) |
| 64 | +end |
| 65 | + |
| 66 | +function ModelKit.evaluate!(u, sys::MTKHomotopySystem, x, p = nothing) |
| 67 | + sys.f(u, x, sys.p) |
| 68 | +end |
| 69 | + |
| 70 | +function ModelKit.evaluate_and_jacobian!(u, U, sys::MTKHomotopySystem, x, p = nothing) |
| 71 | + sys.f(u, x, sys.p) |
| 72 | + sys.jac(U, x, sys.p) |
| 73 | +end |
| 74 | + |
| 75 | +SymbolicIndexingInterface.parameter_values(s::MTKHomotopySystem) = s.p |
| 76 | + |
| 77 | +function MTK.HomotopyContinuationProblem( |
| 78 | + sys::NonlinearSystem, u0map, parammap; compile = :all, eval_expression = false, eval_module = ModelingToolkit, kwargs...) |
| 79 | + if !iscomplete(sys) |
| 80 | + error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `HomotopyContinuationProblem`") |
| 81 | + end |
| 82 | + |
| 83 | + dvs = unknowns(sys) |
| 84 | + eqs = equations(sys) |
| 85 | + |
| 86 | + for eq in eqs |
| 87 | + if !is_polynomial(eq.lhs, dvs) || !is_polynomial(eq.rhs, dvs) |
| 88 | + error("Equation $eq is not a polynomial in the unknowns") |
| 89 | + end |
| 90 | + end |
| 91 | + |
| 92 | + nlfn, u0, p = MTK.process_SciMLProblem(NonlinearFunction{true}, sys, u0map, parammap; |
| 93 | + jac = true, eval_expression, eval_module) |
| 94 | + |
| 95 | + hvars = symbolics_to_hc.(dvs) |
| 96 | + mtkhsys = MTKHomotopySystem(nlfn.f, p, nlfn.jac, hvars, length(eqs)) |
| 97 | + |
| 98 | + obsfn = MTK.ObservedFunctionCache(sys; eval_expression, eval_module) |
| 99 | + |
| 100 | + return MTK.HomotopyContinuationProblem(u0, mtkhsys, sys, obsfn) |
| 101 | +end |
| 102 | + |
| 103 | +function CommonSolve.solve(prob::MTK.HomotopyContinuationProblem; kwargs...) |
| 104 | + sol = HomotopyContinuation.solve(prob.homotopy_continuation_system; kwargs...) |
| 105 | + realsols = HomotopyContinuation.results(sol; only_real = true) |
| 106 | + if isempty(realsols) |
| 107 | + u = state_values(prob) |
| 108 | + resid = prob.homotopy_continuation_system(u) |
| 109 | + retcode = SciMLBase.ReturnCode.ConvergenceFailure |
| 110 | + else |
| 111 | + distance, idx = findmin(realsols) do result |
| 112 | + norm(result.solution - state_values(prob)) |
| 113 | + end |
| 114 | + u = real.(realsols[idx].solution) |
| 115 | + resid = prob.homotopy_continuation_system(u) |
| 116 | + retcode = SciMLBase.ReturnCode.Success |
| 117 | + end |
| 118 | + |
| 119 | + return SciMLBase.build_solution( |
| 120 | + prob, :HomotopyContinuation, u, resid; retcode, original = sol) |
| 121 | +end |
| 122 | + |
| 123 | +end |
0 commit comments