| 
 | 1 | +module MTKHomotopyContinuationExt  | 
 | 2 | + | 
 | 3 | +using ModelingToolkit  | 
 | 4 | +using ModelingToolkit.Symbolics: unwrap  | 
 | 5 | +using ModelingToolkit.SymbolicIndexingInterface  | 
 | 6 | +using HomotopyContinuation  | 
 | 7 | +using ModelingToolkit: iscomplete, parameters, has_index_cache, get_index_cache, get_u0,  | 
 | 8 | +                       get_u0_p, check_eqs_u0, CommonSolve  | 
 | 9 | + | 
 | 10 | +function contains_variable(x, wrt)  | 
 | 11 | +    any(isequal(x), wrt) && return true  | 
 | 12 | +    istree(x) || return false  | 
 | 13 | +    return any(y -> contains_variable(y, wrt), arguments(x))  | 
 | 14 | +end  | 
 | 15 | + | 
 | 16 | +function is_polynomial(x, wrt)  | 
 | 17 | +    x = unwrap(x)  | 
 | 18 | +    symbolic_type(x) == NotSymbolic() && return true  | 
 | 19 | +    istree(x) || return true  | 
 | 20 | +    contains_variable(x, wrt) || return true  | 
 | 21 | + | 
 | 22 | +    if operation(x) in (*, +, -)  | 
 | 23 | +        return all(y -> is_polynomial(y, wrt), arguments(x))  | 
 | 24 | +    end  | 
 | 25 | +    if operation(x) == (^)  | 
 | 26 | +        b, p = arguments(x)  | 
 | 27 | +        return is_polynomial(b, wrt) && !contains_variable(p, wrt)  | 
 | 28 | +    end  | 
 | 29 | +    return false  | 
 | 30 | +end  | 
 | 31 | + | 
 | 32 | +function symbolics_to_hc(expr)  | 
 | 33 | +    if iscall(expr)  | 
 | 34 | +        if operation(expr) == getindex  | 
 | 35 | +            args = arguments(expr)  | 
 | 36 | +            return ModelKit.Variable(getname(args[1]), args[2:end]...)  | 
 | 37 | +        else  | 
 | 38 | +            return operation(expr)(symbolics_to_hc.(arguments(expr))...)  | 
 | 39 | +        end  | 
 | 40 | +    elseif symbolic_type(expr) == NotSymbolic()  | 
 | 41 | +        return expr  | 
 | 42 | +    else  | 
 | 43 | +        return ModelKit.Variable(getname(expr))  | 
 | 44 | +    end  | 
 | 45 | +end  | 
 | 46 | + | 
 | 47 | +struct MTKHomotopySystem{F, P, J, V} <: HomotopyContinuation.AbstractSystem  | 
 | 48 | +    f::F  | 
 | 49 | +    p::P  | 
 | 50 | +    jac::J  | 
 | 51 | +    vars::V  | 
 | 52 | +    nexprs::Int  | 
 | 53 | +end  | 
 | 54 | + | 
 | 55 | +Base.size(sys::MTKHomotopySystem) = (sys.nexprs, length(sys.vars))  | 
 | 56 | +ModelKit.variables(sys::MTKHomotopySystem) = sys.vars  | 
 | 57 | + | 
 | 58 | +function (sys::MTKHomotopySystem)(x, p = nothing)  | 
 | 59 | +    sys.f(x, sys.p)  | 
 | 60 | +end  | 
 | 61 | + | 
 | 62 | +function ModelKit.evaluate!(u, sys::MTKHomotopySystem, x, p = nothing)  | 
 | 63 | +    sys.f(u, x, sys.p)  | 
 | 64 | +end  | 
 | 65 | + | 
 | 66 | +function ModelKit.evaluate_and_jacobian!(u, U, sys::MTKHomotopySystem, x, p = nothing)  | 
 | 67 | +    sys.f(u, x, sys.p)  | 
 | 68 | +    sys.jac(U, x, sys.p)  | 
 | 69 | +end  | 
 | 70 | + | 
 | 71 | +function ModelingToolkit.HomotopyContinuationProblem(  | 
 | 72 | +        sys::NonlinearSystem, u0map, parammap; compile = :all, kwargs...)  | 
 | 73 | +    if !iscomplete(sys)  | 
 | 74 | +        error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `HomotopyContinuationProblem`")  | 
 | 75 | +    end  | 
 | 76 | + | 
 | 77 | +    dvs = unknowns(sys)  | 
 | 78 | +    ps = parameters(sys)  | 
 | 79 | +    eqs = equations(sys)  | 
 | 80 | + | 
 | 81 | +    for eq in eqs  | 
 | 82 | +        if !is_polynomial(eq.lhs, dvs) || !is_polynomial(eq.rhs, dvs)  | 
 | 83 | +            error("Equation $eq is not a polynomial in the unknowns")  | 
 | 84 | +        end  | 
 | 85 | +    end  | 
 | 86 | + | 
 | 87 | +    nlfn = NonlinearFunction(sys; jac = true)  | 
 | 88 | +    hvars = symbolics_to_hc.(dvs)  | 
 | 89 | + | 
 | 90 | +    if has_index_cache(sys) && get_index_cache(sys) !== nothing  | 
 | 91 | +        u0, defs = get_u0(sys, u0map, parammap)  | 
 | 92 | +        check_eqs_u0(eqs, dvs, u0; kwargs...)  | 
 | 93 | +        p = MTKParameters(sys, parammap, u0map)  | 
 | 94 | +    else  | 
 | 95 | +        u0, p, defs = get_u0_p(sys, u0map, parammap; tofloat, use_union)  | 
 | 96 | +        check_eqs_u0(eqs, dvs, u0; kwargs...)  | 
 | 97 | +    end  | 
 | 98 | + | 
 | 99 | +    mtkhsys = MTKHomotopySystem(nlfn.f, p, nlfn.jac, hvars, length(eqs))  | 
 | 100 | + | 
 | 101 | +    prob = ModelingToolkit.HomotopyContinuationProblem{typeof(mtkhsys), typeof(u0)}(  | 
 | 102 | +        sys, mtkhsys, u0)  | 
 | 103 | +end  | 
 | 104 | + | 
 | 105 | +function CommonSolve.solve(prob::ModelingToolkit.HomotopyContinuationProblem; kwargs...)  | 
 | 106 | +    sol = HomotopyContinuation.solve(prob.hcsys; kwargs...)  | 
 | 107 | +    rsols = HomotopyContinuation.real_solutions(sol)  | 
 | 108 | +    rsol = findmin(rsols) do val  | 
 | 109 | +        norm(prob.u0 - val)  | 
 | 110 | +    end  | 
 | 111 | +    return rsol  | 
 | 112 | +end  | 
 | 113 | + | 
 | 114 | +end  | 
0 commit comments