Skip to content

Commit 73e9746

Browse files
committed
Fix ODAEProblem
1 parent 036f74f commit 73e9746

File tree

4 files changed

+70
-57
lines changed

4 files changed

+70
-57
lines changed

src/structural_transformation/codegen.jl

Lines changed: 53 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -95,41 +95,26 @@ function torn_system_jacobian_sparsity(sys, var_eq_matching, var_sccs, nlsolve_s
9595
sparse(I, J, true)
9696
end
9797

98-
"""
99-
exprs = gen_nlsolve(eqs::Vector{Equation}, vars::Vector, u0map::Dict; checkbounds = true, assignments)
100-
101-
Generate `SymbolicUtils` expressions for a root-finding function based on `eqs`,
102-
as well as a call to the root-finding solver.
103-
104-
`exprs` is a two element vector
105-
```
106-
exprs = [fname = f, numerical_nlsolve(fname, ...)]
107-
```
108-
109-
# Arguments:
110-
- `eqs`: Equations to find roots of.
111-
- `vars`: ???
112-
- `u0map`: A `Dict` which maps variables in `eqs` to values, e.g., `defaults(sys)` if `eqs = equations(sys)`.
113-
- `checkbounds`: Apply bounds checking in the generated code.
114-
"""
115-
function gen_nlsolve(eqs, vars, u0map::AbstractDict; checkbounds=true, assignments)
98+
function gen_nlsolve(eqs, vars, u0map::AbstractDict, assignments, deps, var2assignment; checkbounds=true)
11699
isempty(vars) && throw(ArgumentError("vars may not be empty"))
117100
length(eqs) == length(vars) || throw(ArgumentError("vars must be of the same length as the number of equations to find the roots of"))
118101
rhss = map(x->x.rhs, eqs)
119102
# We use `vars` instead of `graph` to capture parameters, too.
120-
allvars = Set(Iterators.flatten(ModelingToolkit.vars(r) for r in rhss))
121-
vars_set = Set(vars)
122-
params = setdiff(allvars, vars_set) # these are not the subject of the root finding
123-
#needed_assignments = filter(a->a.lhs in params, assignments)
124-
needed_assignments = assignments[1:findlast(a->a.lhs in params, assignments)]
125-
params = setdiff(params, [a.lhs for a in needed_assignments])
126-
@show needed_assignments, params
127-
for a in needed_assignments
128-
ModelingToolkit.vars!(params, a.rhs)
103+
paramset = Set(Iterators.flatten(ModelingToolkit.vars(r) for r in rhss))
104+
105+
init_assignments = [var2assignment[p] for p in paramset if haskey(var2assignment, p)]
106+
tmp = [init_assignments]
107+
# `deps[init_assignments]` gives the dependency of `init_assignments`
108+
while (next_assignments = reduce(vcat, deps[init_assignments]); !isempty(next_assignments))
109+
init_assignments = next_assignments
110+
push!(tmp, init_assignments)
129111
end
130-
params = setdiff(params, vars_set) # these are not the subject of the root finding
131-
@show params
132-
# inductor1₊v, inductor2₊v
112+
needed_assignments = mapreduce(i->assignments[i], vcat, reverse(tmp))
113+
extravars = Set(Iterators.flatten(ModelingToolkit.vars(r.rhs) for r in needed_assignments))
114+
union!(paramset, extravars)
115+
# these are not the subject of the root finding
116+
setdiff!(paramset, vars); setdiff!(paramset, map(a->a.lhs, needed_assignments))
117+
params = collect(paramset)
133118

134119
# splatting to tighten the type
135120
u0 = []
@@ -152,10 +137,13 @@ function gen_nlsolve(eqs, vars, u0map::AbstractDict; checkbounds=true, assignmen
152137
f = Func(
153138
[
154139
DestructuredArgs(vars, inbounds=!checkbounds)
155-
DestructuredArgs(collect(params), inbounds=!checkbounds)
140+
DestructuredArgs(params, inbounds=!checkbounds)
156141
],
157142
[],
158-
Let(needed_assignments, isscalar ? rhss[1] : MakeArray(rhss, SVector))
143+
Let(
144+
needed_assignments,
145+
isscalar ? rhss[1] : MakeArray(rhss, SVector)
146+
)
159147
) |> SymbolicUtils.Code.toexpr
160148

161149
# solver call contains code to call the root-finding solver on the function f
@@ -169,10 +157,12 @@ function gen_nlsolve(eqs, vars, u0map::AbstractDict; checkbounds=true, assignmen
169157
)
170158
end)
171159

172-
[
173-
fname @RuntimeGeneratedFunction(f)
174-
DestructuredArgs(vars, inbounds=!checkbounds) solver_call
175-
]
160+
nlsolve_expr = Assignment[
161+
fname @RuntimeGeneratedFunction(f)
162+
DestructuredArgs(vars, inbounds=!checkbounds) solver_call
163+
]
164+
165+
nlsolve_expr
176166
end
177167

178168
function build_torn_function(
@@ -205,20 +195,33 @@ function build_torn_function(
205195
states_idxs = collect(diffvars_range(s))
206196
mass_matrix_diag = ones(length(states_idxs))
207197

208-
assignments, bf_states = tearing_assignments(sys)
209-
torn_expr = []
198+
assignments, deps, bf_states = tearing_assignments(sys)
199+
var2assignment = Dict{Any,Int}(eq.lhs => i for (i, eq) in enumerate(assignments))
200+
201+
torn_expr = Assignment[]
210202

211203
defs = defaults(sys)
212204
nlsolve_scc_idxs = Int[]
213205

214206
needs_extending = false
215207
@views for (i, scc) in enumerate(var_sccs)
216-
#torn_vars = [s.fullvars[var] for var in scc if var_eq_matching[var] !== unassigned]
217208
torn_vars_idxs = Int[var for var in scc if var_eq_matching[var] !== unassigned]
218209
torn_eqs_idxs = [var_eq_matching[var] for var in torn_vars_idxs]
219210
isempty(torn_eqs_idxs) && continue
220211
if length(torn_eqs_idxs) <= max_inlining_size
221-
append!(torn_expr, gen_nlsolve(eqs[torn_eqs_idxs], s.fullvars[torn_vars_idxs], defs, checkbounds=checkbounds, assignments=assignments))
212+
nlsolve_expr = gen_nlsolve(eqs[torn_eqs_idxs], s.fullvars[torn_vars_idxs], defs, assignments, deps, var2assignment, checkbounds=checkbounds)
213+
#=
214+
# a temporary vector that we need to reverse to get the correct
215+
# dependency evaluation order.
216+
local_deps = Vector{Int}[]
217+
init_deps = [var2assignment[p] for p in params if haskey(var2assignment, p)]
218+
push!(local_deps, init_deps)
219+
while (next_deps = reduce(vcat, deps[init_deps]); !isempty(next_deps))
220+
init_deps = next_deps
221+
push!(local_deps, init_deps)
222+
end
223+
=#
224+
append!(torn_expr, nlsolve_expr)
222225
push!(nlsolve_scc_idxs, i)
223226
else
224227
needs_extending = true
@@ -262,13 +265,13 @@ function build_torn_function(
262265
if expression
263266
expr, states
264267
else
265-
observedfun = let sys=sys, dict=Dict(), assignments=assignments, bf_states=bf_states
268+
observedfun = let sys=sys, dict=Dict(), assignments=assignments, deps=deps, bf_states=bf_states, var2assignment=var2assignment
266269
function generated_observed(obsvar, u, p, t)
267270
obs = get!(dict, value(obsvar)) do
268271
build_observed_function(sys, obsvar, var_eq_matching, var_sccs,
272+
assignments, deps, bf_states, var2assignment,
269273
checkbounds=checkbounds,
270-
assignments=assignments,
271-
bf_states=bf_states)
274+
)
272275
end
273276
obs(u, p, t)
274277
end
@@ -302,12 +305,14 @@ function find_solve_sequence(sccs, vars)
302305
end
303306

304307
function build_observed_function(
305-
sys, ts, var_eq_matching, var_sccs;
308+
sys, ts, var_eq_matching, var_sccs,
309+
assignments,
310+
deps,
311+
bf_states,
312+
var2assignment;
306313
expression=false,
307314
output_type=Array,
308315
checkbounds=true,
309-
assignments,
310-
bf_states,
311316
)
312317

313318
if (isscalar = !(ts isa AbstractVector))
@@ -356,7 +361,8 @@ function build_observed_function(
356361
torn_eqs = map(i->map(v->eqs[var_eq_matching[v]], var_sccs[i]), subset)
357362
torn_vars = map(i->map(v->fullvars[v], var_sccs[i]), subset)
358363
u0map = defaults(sys)
359-
solves = gen_nlsolve.(torn_eqs, torn_vars, (u0map,); checkbounds=checkbounds)
364+
assignments = copy(assignments)
365+
solves = gen_nlsolve.(torn_eqs, torn_vars, (u0map,), (assignments,), (deps,), (var2assignment,); checkbounds=checkbounds)
360366
else
361367
solves = []
362368
end

src/structural_transformation/symbolics_tearing.jl

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ function substitution_graph(graph, slist, dlist, var_eq_matching)
2626
newmatching[iv] = ie
2727
end
2828

29-
return newgraph, newmatching
29+
return DiCMOBiGraph{true}(newgraph, complete(newmatching))
3030
end
3131

3232
function tearing_sub(expr, dict, s)
@@ -36,7 +36,7 @@ end
3636

3737
function tearing_substitution(sys::AbstractSystem; simplify=false)
3838
empty_substitutions(sys) && return sys
39-
subs = get_substitutions(sys)
39+
subs, = get_substitutions(sys)
4040
solved = Dict(eq.lhs => eq.rhs for eq in subs)
4141
neweqs = map(equations(sys)) do eq
4242
if isdiffeq(eq)
@@ -61,13 +61,14 @@ end
6161
function tearing_assignments(sys::AbstractSystem)
6262
if empty_substitutions(sys)
6363
assignments = []
64+
deps = Int[]
6465
bf_states = Code.LazyState()
6566
else
66-
subs = get_substitutions(sys)
67+
subs, deps = get_substitutions(sys)
6768
assignments = [Assignment(eq.lhs, eq.rhs) for eq in subs]
6869
bf_states = Code.NameState(Dict(eq.lhs => Symbol(eq.lhs) for eq in subs))
6970
end
70-
return assignments, bf_states
71+
return assignments, deps, bf_states
7172
end
7273

7374
function solve_equation(eq, var, simplify)
@@ -102,9 +103,15 @@ function tearing_reassemble(sys, var_eq_matching; simplify=false)
102103
is_solvable(ieq, iv) || continue
103104
push!(solved_equations, ieq); push!(solved_variables, iv)
104105
end
105-
subgraph, submatching = substitution_graph(graph, solved_equations, solved_variables, var_eq_matching)
106-
toporder = topological_sort_by_dfs(DiCMOBiGraph{true}(subgraph, complete(submatching)))
107-
substitutions = Equation[solve_equation(eqs[solved_equations[i]], fullvars[solved_variables[i]], simplify) for i in toporder]
106+
subgraph = substitution_graph(graph, solved_equations, solved_variables, var_eq_matching)
107+
toporder = topological_sort_by_dfs(subgraph)
108+
substitutions = [solve_equation(
109+
eqs[solved_equations[i]],
110+
fullvars[solved_variables[i]],
111+
simplify
112+
) for i in toporder]
113+
invtoporder = invperm(toporder)
114+
deps = [[invtoporder[n] for n in neighborhood(subgraph, j, Inf, dir=:in) if n!=j] for (i, j) in enumerate(toporder)]
108115

109116
# Rewrite remaining equations in terms of solved variables
110117

@@ -128,7 +135,7 @@ function tearing_reassemble(sys, var_eq_matching; simplify=false)
128135
@set! sys.eqs = neweqs
129136
@set! sys.states = [s.fullvars[idx] for idx in 1:length(s.fullvars) if !isdervar(s, idx)]
130137
@set! sys.observed = [observed(sys); substitutions]
131-
@set! sys.substitutions = substitutions
138+
@set! sys.substitutions = substitutions, deps
132139
return sys
133140
end
134141

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ function generate_function(
103103
bf_states = Code.LazyState()
104104
pre = has_difference ? (ex -> ex) : get_postprocess_fbody(sys)
105105
else
106-
subs = get_substitutions(sys)
106+
subs, = get_substitutions(sys)
107107
bf_states = Code.NameState(Dict(eq.lhs => Symbol(eq.lhs) for eq in subs))
108108
if has_difference
109109
pre = ex -> Let(Assignment[Assignment(eq.lhs, eq.rhs) for eq in subs], ex)

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -430,5 +430,5 @@ isarray(x) = x isa AbstractArray || x isa Symbolics.Arr
430430
function empty_substitutions(sys)
431431
has_substitutions(sys) || return true
432432
subs = get_substitutions(sys)
433-
isnothing(subs) || isempty(subs)
433+
isnothing(subs) || isempty(last(subs))
434434
end

0 commit comments

Comments
 (0)