Skip to content

Commit 018cfa0

Browse files
authored
Merge pull request #1414 from SciML/myb/nosub
Optimized lowering of torn systems
2 parents ec795f2 + 79c9bab commit 018cfa0

22 files changed

+416
-159
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ export Differential, expand_derivatives, @derivatives
181181
export Equation, ConstrainedEquation
182182
export Term, Sym
183183
export SymScope, LocalScope, ParentScope, GlobalScope
184-
export independent_variables, independent_variable, states, parameters, equations, controls, observed, structure
184+
export independent_variables, independent_variable, states, parameters, equations, controls, observed, structure, full_equations
185185
export structural_simplify, expand_connections
186186
export DiscreteSystem, DiscreteProblem
187187

src/bipartite_graph.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,6 @@ function Graphs.incidence_matrix(g::BipartiteGraph, val=true)
397397
S = sparse(I, J, val, nsrcs(g), ndsts(g))
398398
end
399399

400-
401400
"""
402401
struct DiCMOBiGraph
403402

src/structural_transformation/StructuralTransformations.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@ using ModelingToolkit: ODESystem, AbstractSystem,var_from_nested_derivative, Dif
1414
states, equations, vars, Symbolic, diff2term, value,
1515
operation, arguments, Sym, Term, simplify, solve_for,
1616
isdiffeq, isdifferential, isinput,
17+
empty_substitutions, get_substitutions,
1718
get_structure, get_iv, independent_variables,
18-
get_structure, defaults, InvalidSystemException,
19+
has_structure, defaults, InvalidSystemException,
1920
ExtraEquationsSystemException,
2021
ExtraVariablesSystemException,
2122
get_postprocess_fbody, vars!,
22-
IncrementalCycleTracker, add_edge_checked!, topological_sort
23+
IncrementalCycleTracker, add_edge_checked!, topological_sort,
24+
invalidate_cache!, Substitutions
2325

2426
using ModelingToolkit.BipartiteGraphs
2527
import .BipartiteGraphs: invview
@@ -37,8 +39,11 @@ using SparseArrays
3739
using NonlinearSolve
3840

3941
export tearing, dae_index_lowering, check_consistency
42+
export tearing_assignments, tearing_substitution
4043
export build_torn_function, build_observed_function, ODAEProblem
4144
export sorted_incidence_matrix
45+
export torn_system_jacobian_sparsity
46+
export full_equations
4247

4348
include("utils.jl")
4449
include("pantelides.jl")

src/structural_transformation/codegen.jl

Lines changed: 127 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using ModelingToolkit: isdifferenceeq, has_continuous_events, generate_rootfindi
44

55
const MAX_INLINE_NLSOLVE_SIZE = 8
66

7-
function torn_system_jacobian_sparsity(sys, var_eq_matching, var_sccs, nlsolve_scc_idxs, eqs_idxs, states_idxs)
7+
function torn_system_with_nlsolve_jacobian_sparsity(sys, var_eq_matching, var_sccs, nlsolve_scc_idxs, eqs_idxs, states_idxs)
88
s = structure(sys)
99
@unpack fullvars, graph = s
1010

@@ -95,30 +95,71 @@ function torn_system_jacobian_sparsity(sys, var_eq_matching, var_sccs, nlsolve_s
9595
sparse(I, J, true)
9696
end
9797

98-
"""
99-
exprs = gen_nlsolve(eqs::Vector{Equation}, vars::Vector, u0map::Dict; checkbounds = true)
100-
101-
Generate `SymbolicUtils` expressions for a root-finding function based on `eqs`,
102-
as well as a call to the root-finding solver.
103-
104-
`exprs` is a two element vector
105-
```
106-
exprs = [fname = f, numerical_nlsolve(fname, ...)]
107-
```
108-
109-
# Arguments:
110-
- `eqs`: Equations to find roots of.
111-
- `vars`: ???
112-
- `u0map`: A `Dict` which maps variables in `eqs` to values, e.g., `defaults(sys)` if `eqs = equations(sys)`.
113-
- `checkbounds`: Apply bounds checking in the generated code.
114-
"""
115-
function gen_nlsolve(eqs, vars, u0map::AbstractDict; checkbounds=true)
98+
function gen_nlsolve!(is_not_prepended_assignment, eqs, vars, u0map::AbstractDict, assignments, (deps, invdeps), var2assignment; checkbounds=true)
11699
isempty(vars) && throw(ArgumentError("vars may not be empty"))
117100
length(eqs) == length(vars) || throw(ArgumentError("vars must be of the same length as the number of equations to find the roots of"))
118101
rhss = map(x->x.rhs, eqs)
119102
# We use `vars` instead of `graph` to capture parameters, too.
120-
allvars = unique(collect(Iterators.flatten(map(ModelingToolkit.vars, rhss))))
121-
params = setdiff(allvars, vars) # these are not the subject of the root finding
103+
paramset = ModelingToolkit.vars(r for r in rhss)
104+
105+
# Compute necessary assignments for the nlsolve expr
106+
init_assignments = [var2assignment[p] for p in paramset if haskey(var2assignment, p)]
107+
tmp = [init_assignments]
108+
# `deps[init_assignments]` gives the dependency of `init_assignments`
109+
while true
110+
next_assignments = reduce(vcat, deps[init_assignments])
111+
isempty(next_assignments) && break
112+
init_assignments = next_assignments
113+
push!(tmp, init_assignments)
114+
end
115+
needed_assignments_idxs = reduce(vcat, unique(reverse(tmp)))
116+
needed_assignments = assignments[needed_assignments_idxs]
117+
118+
# Compute `params`. They are like enclosed variables
119+
rhsvars = [ModelingToolkit.vars(r.rhs) for r in needed_assignments]
120+
vars_set = Set(vars)
121+
outer_set = BitSet()
122+
inner_set = BitSet()
123+
for (i, vs) in enumerate(rhsvars)
124+
j = needed_assignments_idxs[i]
125+
if isdisjoint(vars_set, vs)
126+
push!(outer_set, j)
127+
else
128+
push!(inner_set, j)
129+
end
130+
end
131+
init_refine = BitSet()
132+
for i in inner_set
133+
union!(init_refine, invdeps[i])
134+
end
135+
intersect!(init_refine, outer_set)
136+
setdiff!(outer_set, init_refine)
137+
union!(inner_set, init_refine)
138+
139+
next_refine = BitSet()
140+
while true
141+
for i in init_refine
142+
id = invdeps[i]
143+
isempty(id) && break
144+
union!(next_refine, id)
145+
end
146+
intersect!(next_refine, outer_set)
147+
isempty(next_refine) && break
148+
setdiff!(outer_set, next_refine)
149+
union!(inner_set, next_refine)
150+
151+
init_refine, next_refine = next_refine, init_refine
152+
empty!(next_refine)
153+
end
154+
global2local = Dict(j=>i for (i, j) in enumerate(needed_assignments_idxs))
155+
inner_idxs = [global2local[i] for i in collect(inner_set)]
156+
outer_idxs = [global2local[i] for i in collect(outer_set)]
157+
extravars = reduce(union!, rhsvars[inner_idxs], init=Set())
158+
union!(paramset, extravars)
159+
setdiff!(paramset, vars)
160+
setdiff!(paramset, [needed_assignments[i].lhs for i in inner_idxs])
161+
union!(paramset, [needed_assignments[i].lhs for i in outer_idxs])
162+
params = collect(paramset)
122163

123164
# splatting to tighten the type
124165
u0 = []
@@ -144,7 +185,11 @@ function gen_nlsolve(eqs, vars, u0map::AbstractDict; checkbounds=true)
144185
DestructuredArgs(params, inbounds=!checkbounds)
145186
],
146187
[],
147-
isscalar ? rhss[1] : MakeArray(rhss, SVector)
188+
Let(
189+
needed_assignments[inner_idxs],
190+
isscalar ? rhss[1] : MakeArray(rhss, SVector),
191+
false
192+
)
148193
) |> SymbolicUtils.Code.toexpr
149194

150195
# solver call contains code to call the root-finding solver on the function f
@@ -158,10 +203,21 @@ function gen_nlsolve(eqs, vars, u0map::AbstractDict; checkbounds=true)
158203
)
159204
end)
160205

161-
[
162-
fname @RuntimeGeneratedFunction(f)
163-
DestructuredArgs(vars, inbounds=!checkbounds) solver_call
164-
]
206+
preassignments = []
207+
for i in outer_idxs
208+
ii = needed_assignments_idxs[i]
209+
is_not_prepended_assignment[ii] || continue
210+
is_not_prepended_assignment[ii] = false
211+
push!(preassignments, assignments[ii])
212+
end
213+
214+
nlsolve_expr = Assignment[
215+
preassignments
216+
fname @RuntimeGeneratedFunction(f)
217+
DestructuredArgs(vars, inbounds=!checkbounds) solver_call
218+
]
219+
220+
nlsolve_expr
165221
end
166222

167223
function build_torn_function(
@@ -193,18 +249,30 @@ function build_torn_function(
193249

194250
states_idxs = collect(diffvars_range(s))
195251
mass_matrix_diag = ones(length(states_idxs))
196-
torn_expr = []
252+
253+
assignments, deps, sol_states = tearing_assignments(sys)
254+
invdeps = map(_->BitSet(), deps)
255+
for (i, d) in enumerate(deps)
256+
for a in d
257+
push!(invdeps[a], i)
258+
end
259+
end
260+
var2assignment = Dict{Any,Int}(eq.lhs => i for (i, eq) in enumerate(assignments))
261+
is_not_prepended_assignment = trues(length(assignments))
262+
263+
torn_expr = Assignment[]
264+
197265
defs = defaults(sys)
198266
nlsolve_scc_idxs = Int[]
199267

200268
needs_extending = false
201-
for (i, scc) in enumerate(var_sccs)
202-
#torn_vars = [s.fullvars[var] for var in scc if var_eq_matching[var] !== unassigned]
269+
@views for (i, scc) in enumerate(var_sccs)
203270
torn_vars_idxs = Int[var for var in scc if var_eq_matching[var] !== unassigned]
204271
torn_eqs_idxs = [var_eq_matching[var] for var in torn_vars_idxs]
205272
isempty(torn_eqs_idxs) && continue
206273
if length(torn_eqs_idxs) <= max_inlining_size
207-
append!(torn_expr, gen_nlsolve(eqs[torn_eqs_idxs], s.fullvars[torn_vars_idxs], defs, checkbounds=checkbounds))
274+
nlsolve_expr = gen_nlsolve!(is_not_prepended_assignment, eqs[torn_eqs_idxs], s.fullvars[torn_vars_idxs], defs, assignments, (deps, invdeps), var2assignment, checkbounds=checkbounds)
275+
append!(torn_expr, nlsolve_expr)
208276
push!(nlsolve_scc_idxs, i)
209277
else
210278
needs_extending = true
@@ -226,6 +294,7 @@ function build_torn_function(
226294

227295
states = s.fullvars[states_idxs]
228296
syms = map(Symbol, states_idxs)
297+
229298
pre = get_postprocess_fbody(sys)
230299

231300
expr = SymbolicUtils.Code.toexpr(
@@ -238,26 +307,31 @@ function build_torn_function(
238307
],
239308
[],
240309
pre(Let(
241-
torn_expr,
242-
funbody
310+
[torn_expr; assignments[is_not_prepended_assignment]],
311+
funbody,
312+
false
243313
))
244-
)
314+
),
315+
sol_states
245316
)
246317
if expression
247318
expr, states
248319
else
249-
observedfun = let sys = sys, dict = Dict()
320+
observedfun = let sys=sys, dict=Dict(), assignments=assignments, deps=(deps, invdeps), sol_states=sol_states, var2assignment=var2assignment
250321
function generated_observed(obsvar, u, p, t)
251322
obs = get!(dict, value(obsvar)) do
252-
build_observed_function(sys, obsvar, var_eq_matching, var_sccs, checkbounds=checkbounds)
323+
build_observed_function(sys, obsvar, var_eq_matching, var_sccs,
324+
assignments, deps, sol_states, var2assignment,
325+
checkbounds=checkbounds,
326+
)
253327
end
254328
obs(u, p, t)
255329
end
256330
end
257331

258332
ODEFunction{true}(
259333
@RuntimeGeneratedFunction(expr),
260-
sparsity = jacobian_sparsity ? torn_system_jacobian_sparsity(sys, var_eq_matching, var_sccs, nlsolve_scc_idxs, eqs_idxs, states_idxs) : nothing,
334+
sparsity = jacobian_sparsity ? torn_system_with_nlsolve_jacobian_sparsity(sys, var_eq_matching, var_sccs, nlsolve_scc_idxs, eqs_idxs, states_idxs) : nothing,
261335
syms = syms,
262336
observed = observedfun,
263337
mass_matrix = mass_matrix,
@@ -283,12 +357,17 @@ function find_solve_sequence(sccs, vars)
283357
end
284358

285359
function build_observed_function(
286-
sys, ts, var_eq_matching, var_sccs;
360+
sys, ts, var_eq_matching, var_sccs,
361+
assignments,
362+
deps,
363+
sol_states,
364+
var2assignment;
287365
expression=false,
288366
output_type=Array,
289-
checkbounds=true
367+
checkbounds=true,
290368
)
291369

370+
is_not_prepended_assignment = trues(length(assignments))
292371
if (isscalar = !(ts isa AbstractVector))
293372
ts = [ts]
294373
end
@@ -335,7 +414,11 @@ function build_observed_function(
335414
torn_eqs = map(i->map(v->eqs[var_eq_matching[v]], var_sccs[i]), subset)
336415
torn_vars = map(i->map(v->fullvars[v], var_sccs[i]), subset)
337416
u0map = defaults(sys)
338-
solves = gen_nlsolve.(torn_eqs, torn_vars, (u0map,); checkbounds=checkbounds)
417+
assignments = copy(assignments)
418+
solves = map(zip(torn_eqs, torn_vars)) do (eqs, vars)
419+
gen_nlsolve!(is_not_prepended_assignment, eqs, vars,
420+
u0map, assignments, deps, var2assignment; checkbounds=checkbounds)
421+
end
339422
else
340423
solves = []
341424
end
@@ -348,7 +431,7 @@ function build_observed_function(
348431
end
349432
pre = get_postprocess_fbody(sys)
350433

351-
ex = Func(
434+
ex = Code.toexpr(Func(
352435
[
353436
DestructuredArgs(diffvars, inbounds=!checkbounds)
354437
DestructuredArgs(parameters(sys), inbounds=!checkbounds)
@@ -360,10 +443,12 @@ function build_observed_function(
360443
collect(Iterators.flatten(solves))
361444
map(eq -> eq.lhseq.rhs, obs[1:maxidx])
362445
subs
446+
assignments[is_not_prepended_assignment]
363447
],
364-
isscalar ? ts[1] : MakeArray(ts, output_type)
448+
isscalar ? ts[1] : MakeArray(ts, output_type),
449+
false
365450
))
366-
) |> Code.toexpr
451+
), sol_states)
367452

368453
expression ? ex : @RuntimeGeneratedFunction(ex)
369454
end

src/structural_transformation/pantelides.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,5 +149,5 @@ function dae_index_lowering(sys::ODESystem; kwargs...)
149149
s = get_structure(sys)
150150
(s isa SystemStructure) || (sys = initialize_system_structure(sys))
151151
sys, var_eq_matching, eq_to_diff = pantelides!(sys; kwargs...)
152-
return pantelides_reassemble(sys, eq_to_diff, var_eq_matching)
152+
return invalidate_cache!(pantelides_reassemble(sys, eq_to_diff, var_eq_matching))
153153
end

0 commit comments

Comments
 (0)