Skip to content

Commit 3f6f058

Browse files
feat: generate SciMLBase.ODE_NLProbData
1 parent 4d65458 commit 3f6f058

File tree

2 files changed

+57
-0
lines changed

2 files changed

+57
-0
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ include("problems/docs.jl")
176176
include("systems/codegen.jl")
177177
include("systems/problem_utils.jl")
178178
include("linearization.jl")
179+
include("systems/solver_nlprob.jl")
179180

180181
include("problems/compatibility.jl")
181182
include("problems/odeproblem.jl")

src/systems/solver_nlprob.jl

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
function generate_ode_nlprobdata(sys::System, u0, p, mm = calculate_massmatrix(sys))
2+
nlsys, outer_tmp, inner_tmp = inner_nlsystem(sys, mm)
3+
state = ProblemState(; u = u0, p)
4+
op = Dict()
5+
op[ODE_GAMMA] = one(eltype(u0))
6+
op[ODE_C] = zero(eltype(u0))
7+
op[outer_tmp] = zeros(eltype(u0), size(outer_tmp))
8+
op[inner_tmp] = zeros(eltype(u0), size(inner_tmp))
9+
for v in [unknowns(nlsys); parameters(nlsys)]
10+
haskey(op, v) && continue
11+
op[v] = getsym(sys, v)(state)
12+
end
13+
nlprob = NonlinearProblem(nlsys, op; build_initializeprob = false)
14+
set_gamma_c = setsym(nlsys, (ODE_GAMMA, ODE_C))
15+
set_outer_tmp = setsym(nlsys, outer_tmp)
16+
set_inner_tmp = setsym(nlsys, inner_tmp)
17+
nlprobmap = getsym(nlsys, unknowns(sys))
18+
19+
return SciMLBase.ODE_NLProbData(nlprob, nothing, set_gamma_c, set_outer_tmp, set_inner_tmp, nlprobmap)
20+
end
21+
22+
const ODE_GAMMA = only(@parameters γₘₜₖ)
23+
const ODE_C = only(@parameters cₘₜₖ)
24+
25+
function get_outer_tmp(n::Int)
26+
only(@parameters outer_tmpₘₜₖ[1:n])
27+
end
28+
29+
function get_inner_tmp(n::Int)
30+
only(@parameters inner_tmpₘₜₖ[1:n])
31+
end
32+
33+
function inner_nlsystem(sys::System, mm)
34+
dvs = unknowns(sys)
35+
eqs = full_equations(sys)
36+
t = get_iv(sys)
37+
N = length(dvs)
38+
@assert length(eqs) == N
39+
@assert mm == I || size(mm) == (N, N)
40+
rhss = [eq.rhs for eq in eqs]
41+
gamma = ODE_GAMMA
42+
c = ODE_C
43+
outer_tmp = get_outer_tmp(N)
44+
inner_tmp = get_inner_tmp(N)
45+
46+
subrules = Dict([v => v + inner_tmp[i] for (i, v) in enumerate(dvs)])
47+
subrules[t] = t + c
48+
new_rhss = map(Base.Fix2(fast_substitute, subrules), rhss)
49+
new_rhss = mm * dvs - gamma .* new_rhss .+ collect(outer_tmp)
50+
new_eqs = [0 ~ rhs for rhs in new_rhss]
51+
52+
new_dvs = unknowns(sys)
53+
new_ps = [parameters(sys); [gamma, c, inner_tmp, outer_tmp]]
54+
nlsys = mtkcompile(System(new_eqs, new_dvs, new_ps; name = :nlsys); split = is_split(sys))
55+
return nlsys, outer_tmp, inner_tmp
56+
end

0 commit comments

Comments
 (0)