|  | 
|  | 1 | +function generate_ode_nlprobdata(sys::System, u0, p, mm = calculate_massmatrix(sys)) | 
|  | 2 | +    nlsys, outer_tmp, inner_tmp = inner_nlsystem(sys, mm) | 
|  | 3 | +    state = ProblemState(; u = u0, p) | 
|  | 4 | +    op = Dict() | 
|  | 5 | +    op[ODE_GAMMA] = one(eltype(u0)) | 
|  | 6 | +    op[ODE_C] = zero(eltype(u0)) | 
|  | 7 | +    op[outer_tmp] = zeros(eltype(u0), size(outer_tmp)) | 
|  | 8 | +    op[inner_tmp] = zeros(eltype(u0), size(inner_tmp)) | 
|  | 9 | +    for v in [unknowns(nlsys); parameters(nlsys)] | 
|  | 10 | +        haskey(op, v) && continue | 
|  | 11 | +        op[v] = getsym(sys, v)(state) | 
|  | 12 | +    end | 
|  | 13 | +    nlprob = NonlinearProblem(nlsys, op; build_initializeprob = false) | 
|  | 14 | +    set_gamma_c = setsym(nlsys, (ODE_GAMMA, ODE_C)) | 
|  | 15 | +    set_outer_tmp = setsym(nlsys, outer_tmp) | 
|  | 16 | +    set_inner_tmp = setsym(nlsys, inner_tmp) | 
|  | 17 | +    nlprobmap = getsym(nlsys, unknowns(sys)) | 
|  | 18 | + | 
|  | 19 | +    return SciMLBase.ODE_NLProbData(nlprob, nothing, set_gamma_c, set_outer_tmp, set_inner_tmp, nlprobmap) | 
|  | 20 | +end | 
|  | 21 | + | 
|  | 22 | +const ODE_GAMMA = only(@parameters γₘₜₖ) | 
|  | 23 | +const ODE_C = only(@parameters cₘₜₖ) | 
|  | 24 | + | 
|  | 25 | +function get_outer_tmp(n::Int) | 
|  | 26 | +    only(@parameters outer_tmpₘₜₖ[1:n]) | 
|  | 27 | +end | 
|  | 28 | + | 
|  | 29 | +function get_inner_tmp(n::Int) | 
|  | 30 | +    only(@parameters inner_tmpₘₜₖ[1:n]) | 
|  | 31 | +end | 
|  | 32 | + | 
|  | 33 | +function inner_nlsystem(sys::System, mm) | 
|  | 34 | +    dvs = unknowns(sys) | 
|  | 35 | +    eqs = full_equations(sys) | 
|  | 36 | +    t = get_iv(sys) | 
|  | 37 | +    N = length(dvs) | 
|  | 38 | +    @assert length(eqs) == N | 
|  | 39 | +    @assert mm == I || size(mm) == (N, N) | 
|  | 40 | +    rhss = [eq.rhs for eq in eqs] | 
|  | 41 | +    gamma = ODE_GAMMA | 
|  | 42 | +    c = ODE_C | 
|  | 43 | +    outer_tmp = get_outer_tmp(N) | 
|  | 44 | +    inner_tmp = get_inner_tmp(N) | 
|  | 45 | + | 
|  | 46 | +    subrules = Dict([v => v + inner_tmp[i] for (i, v) in enumerate(dvs)]) | 
|  | 47 | +    subrules[t] = t + c | 
|  | 48 | +    new_rhss = map(Base.Fix2(fast_substitute, subrules), rhss) | 
|  | 49 | +    new_rhss = mm * dvs - gamma .* new_rhss .+ collect(outer_tmp) | 
|  | 50 | +    new_eqs = [0 ~ rhs for rhs in new_rhss] | 
|  | 51 | + | 
|  | 52 | +    new_dvs = unknowns(sys) | 
|  | 53 | +    new_ps = [parameters(sys); [gamma, c, inner_tmp, outer_tmp]] | 
|  | 54 | +    nlsys = mtkcompile(System(new_eqs, new_dvs, new_ps; name = :nlsys); split = is_split(sys)) | 
|  | 55 | +    return nlsys, outer_tmp, inner_tmp | 
|  | 56 | +end | 
0 commit comments