|
| 1 | +module MTKHomotopyContinuationExt |
| 2 | + |
| 3 | +using ModelingToolkit |
| 4 | +using ModelingToolkit.SciMLBase |
| 5 | +using ModelingToolkit.Symbolics: unwrap, symtype |
| 6 | +using ModelingToolkit.SymbolicIndexingInterface |
| 7 | +using ModelingToolkit.DocStringExtensions |
| 8 | +using HomotopyContinuation |
| 9 | +using ModelingToolkit: iscomplete, parameters, has_index_cache, get_index_cache, get_u0, |
| 10 | + get_u0_p, check_eqs_u0, CommonSolve |
| 11 | + |
| 12 | +const MTK = ModelingToolkit |
| 13 | + |
| 14 | +function contains_variable(x, wrt) |
| 15 | + any(y -> occursin(y, x), wrt) |
| 16 | +end |
| 17 | + |
| 18 | +""" |
| 19 | +$(TYPEDSIGNATURES) |
| 20 | +
|
| 21 | +Check if `x` is polynomial with respect to the variables in `wrt`. |
| 22 | +""" |
| 23 | +function is_polynomial(x, wrt) |
| 24 | + x = unwrap(x) |
| 25 | + symbolic_type(x) == NotSymbolic() && return true |
| 26 | + iscall(x) || return true |
| 27 | + contains_variable(x, wrt) || return true |
| 28 | + any(isequal(x), wrt) && return true |
| 29 | + |
| 30 | + if operation(x) in (*, +, -) |
| 31 | + return all(y -> is_polynomial(y, wrt), arguments(x)) |
| 32 | + end |
| 33 | + if operation(x) == (^) |
| 34 | + b, p = arguments(x) |
| 35 | + is_pow_integer = symtype(p) <: Integer |
| 36 | + if !is_pow_integer |
| 37 | + if symbolic_type(p) == NotSymbolic() |
| 38 | + @warn "In $x: Exponent $p is not an integer" |
| 39 | + else |
| 40 | + @warn "In $x: Exponent $p is not an integer. Use `@parameters p::Integer` to declare integer parameters." |
| 41 | + end |
| 42 | + end |
| 43 | + exponent_has_unknowns = contains_variable(p, wrt) |
| 44 | + if exponent_has_unknowns |
| 45 | + @warn "In $x: Exponent $p cannot contain unknowns of the system." |
| 46 | + end |
| 47 | + base_polynomial = is_polynomial(b, wrt) |
| 48 | + if !base_polynomial |
| 49 | + @warn "In $x: Base is not a polynomial" |
| 50 | + end |
| 51 | + return base_polynomial && !exponent_has_unknowns && is_pow_integer |
| 52 | + end |
| 53 | + @warn "In $x: Unrecognized operation $(operation(x)). Allowed polynomial operations are `*, +, -, ^`" |
| 54 | + return false |
| 55 | +end |
| 56 | + |
| 57 | +""" |
| 58 | +$(TYPEDSIGNATURES) |
| 59 | +
|
| 60 | +Convert `expr` from a symbolics expression to one that uses `HomotopyContinuation.ModelKit`. |
| 61 | +""" |
| 62 | +function symbolics_to_hc(expr) |
| 63 | + if iscall(expr) |
| 64 | + if operation(expr) == getindex |
| 65 | + args = arguments(expr) |
| 66 | + return ModelKit.Variable(getname(args[1]), args[2:end]...) |
| 67 | + else |
| 68 | + return operation(expr)(symbolics_to_hc.(arguments(expr))...) |
| 69 | + end |
| 70 | + elseif symbolic_type(expr) == NotSymbolic() |
| 71 | + return expr |
| 72 | + else |
| 73 | + return ModelKit.Variable(getname(expr)) |
| 74 | + end |
| 75 | +end |
| 76 | + |
| 77 | +""" |
| 78 | +$(TYPEDEF) |
| 79 | +
|
| 80 | +A subtype of `HomotopyContinuation.AbstractSystem` used to solve `HomotopyContinuationProblem`s. |
| 81 | +""" |
| 82 | +struct MTKHomotopySystem{F, P, J, V} <: HomotopyContinuation.AbstractSystem |
| 83 | + """ |
| 84 | + The generated function for the residual of the polynomial system. In-place. |
| 85 | + """ |
| 86 | + f::F |
| 87 | + """ |
| 88 | + The parameter object. |
| 89 | + """ |
| 90 | + p::P |
| 91 | + """ |
| 92 | + The generated function for the jacobian of the polynomial system. In-place. |
| 93 | + """ |
| 94 | + jac::J |
| 95 | + """ |
| 96 | + The `HomotopyContinuation.ModelKit.Variable` representation of the unknowns of |
| 97 | + the system. |
| 98 | + """ |
| 99 | + vars::V |
| 100 | + """ |
| 101 | + The number of polynomials in the system. Must also be equal to `length(vars)`. |
| 102 | + """ |
| 103 | + nexprs::Int |
| 104 | +end |
| 105 | + |
| 106 | +Base.size(sys::MTKHomotopySystem) = (sys.nexprs, length(sys.vars)) |
| 107 | +ModelKit.variables(sys::MTKHomotopySystem) = sys.vars |
| 108 | + |
| 109 | +function (sys::MTKHomotopySystem)(x, p = nothing) |
| 110 | + sys.f(x, sys.p) |
| 111 | +end |
| 112 | + |
| 113 | +function ModelKit.evaluate!(u, sys::MTKHomotopySystem, x, p = nothing) |
| 114 | + sys.f(u, x, sys.p) |
| 115 | +end |
| 116 | + |
| 117 | +function ModelKit.evaluate_and_jacobian!(u, U, sys::MTKHomotopySystem, x, p = nothing) |
| 118 | + sys.f(u, x, sys.p) |
| 119 | + sys.jac(U, x, sys.p) |
| 120 | +end |
| 121 | + |
| 122 | +SymbolicIndexingInterface.parameter_values(s::MTKHomotopySystem) = s.p |
| 123 | + |
| 124 | +""" |
| 125 | + $(TYPEDSIGNATURES) |
| 126 | +
|
| 127 | +Create a `HomotopyContinuationProblem` from a `NonlinearSystem` with polynomial equations. |
| 128 | +The problem will be solved by HomotopyContinuation.jl. The resultant `NonlinearSolution` |
| 129 | +will contain the polynomial root closest to the point specified by `u0map` (if real roots |
| 130 | +exist for the system). |
| 131 | +""" |
| 132 | +function MTK.HomotopyContinuationProblem( |
| 133 | + sys::NonlinearSystem, u0map, parammap = nothing; eval_expression = false, |
| 134 | + eval_module = ModelingToolkit, kwargs...) |
| 135 | + if !iscomplete(sys) |
| 136 | + error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `HomotopyContinuationProblem`") |
| 137 | + end |
| 138 | + |
| 139 | + dvs = unknowns(sys) |
| 140 | + eqs = equations(sys) |
| 141 | + |
| 142 | + for eq in eqs |
| 143 | + if !is_polynomial(eq.lhs, dvs) || !is_polynomial(eq.rhs, dvs) |
| 144 | + error("Equation $eq is not a polynomial in the unknowns. See warnings for further details.") |
| 145 | + end |
| 146 | + end |
| 147 | + |
| 148 | + nlfn, u0, p = MTK.process_SciMLProblem(NonlinearFunction{true}, sys, u0map, parammap; |
| 149 | + jac = true, eval_expression, eval_module) |
| 150 | + |
| 151 | + hvars = symbolics_to_hc.(dvs) |
| 152 | + mtkhsys = MTKHomotopySystem(nlfn.f, p, nlfn.jac, hvars, length(eqs)) |
| 153 | + |
| 154 | + obsfn = MTK.ObservedFunctionCache(sys; eval_expression, eval_module) |
| 155 | + |
| 156 | + return MTK.HomotopyContinuationProblem(u0, mtkhsys, sys, obsfn) |
| 157 | +end |
| 158 | + |
| 159 | +""" |
| 160 | +$(TYPEDSIGNATURES) |
| 161 | +
|
| 162 | +Solve a `HomotopyContinuationProblem`. Ignores the algorithm passed to it, and always |
| 163 | +uses `HomotopyContinuation.jl`. All keyword arguments are forwarded to |
| 164 | +`HomotopyContinuation.solve`. The original solution as returned by `HomotopyContinuation.jl` |
| 165 | +will be available in the `.original` field of the returned `NonlinearSolution`. |
| 166 | +
|
| 167 | +All keyword arguments have their default values in HomotopyContinuation.jl, except |
| 168 | +`show_progress` which defaults to `false`. |
| 169 | +""" |
| 170 | +function CommonSolve.solve(prob::MTK.HomotopyContinuationProblem, |
| 171 | + alg = nothing; show_progress = false, kwargs...) |
| 172 | + sol = HomotopyContinuation.solve( |
| 173 | + prob.homotopy_continuation_system; show_progress, kwargs...) |
| 174 | + realsols = HomotopyContinuation.results(sol; only_real = true) |
| 175 | + if isempty(realsols) |
| 176 | + u = state_values(prob) |
| 177 | + resid = prob.homotopy_continuation_system(u) |
| 178 | + retcode = SciMLBase.ReturnCode.ConvergenceFailure |
| 179 | + else |
| 180 | + distance, idx = findmin(realsols) do result |
| 181 | + norm(result.solution - state_values(prob)) |
| 182 | + end |
| 183 | + u = real.(realsols[idx].solution) |
| 184 | + resid = prob.homotopy_continuation_system(u) |
| 185 | + retcode = SciMLBase.ReturnCode.Success |
| 186 | + end |
| 187 | + |
| 188 | + return SciMLBase.build_solution( |
| 189 | + prob, :HomotopyContinuation, u, resid; retcode, original = sol) |
| 190 | +end |
| 191 | + |
| 192 | +end |
0 commit comments