Skip to content

Commit d2880c9

Browse files
committed
Merge branch 'master' into kf/state_selection
2 parents 49da181 + 12b2efe commit d2880c9

File tree

10 files changed

+289
-107
lines changed

10 files changed

+289
-107
lines changed

.github/workflows/Downstream.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@ on:
77

88
jobs:
99
test:
10-
name: ${{ matrix.package.repo }}/${{ matrix.package.group }}
10+
name: ${{ matrix.package.repo }}/${{ matrix.package.group }}/${{ matrix.julia-version }}
1111
runs-on: ${{ matrix.os }}
1212
env:
1313
GROUP: ${{ matrix.package.group }}
1414
strategy:
1515
fail-fast: false
1616
matrix:
17-
julia-version: [1]
17+
julia-version: [1,1.6]
1818
os: [ubuntu-latest]
1919
package:
2020
- {user: SciML, repo: SciMLBase.jl, group: Downstream}

.github/workflows/ci.yml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,18 @@ on:
99
jobs:
1010
test:
1111
runs-on: ubuntu-latest
12+
strategy:
13+
matrix:
14+
group:
15+
- All
16+
version:
17+
- '1'
18+
- '1.6'
1219
steps:
1320
- uses: actions/checkout@v2
1421
- uses: julia-actions/setup-julia@v1
1522
with:
16-
version: 1
23+
version: ${{ matrix.version }}
1724
- uses: actions/cache@v1
1825
env:
1926
cache-name: cache-artifacts

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ DocStringExtensions = "0.7, 0.8"
5757
DomainSets = "0.5"
5858
Graphs = "1.4"
5959
IfElse = "0.1"
60-
JuliaFormatter = "0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19"
60+
JuliaFormatter = "0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.20"
6161
LabelledArrays = "1.3"
6262
Latexify = "0.11, 0.12, 0.13, 0.14, 0.15"
6363
MacroTools = "0.5"
@@ -72,11 +72,11 @@ SciMLBase = "1.3"
7272
Setfield = "0.7, 0.8"
7373
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
7474
StaticArrays = "0.10, 0.11, 0.12, 1.0"
75-
SymbolicUtils = "0.18"
75+
SymbolicUtils = "0.18, 0.19"
7676
Symbolics = "4.0.0"
7777
UnPack = "0.1, 1.0"
7878
Unitful = "1.1"
79-
julia = "1.2"
79+
julia = "1.6"
8080

8181
[extras]
8282
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"

docs/src/mtkitize_tutorials/modelingtoolkitize.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,5 @@ sys = modelingtoolkitize(prob)
2727
Using this, we can symbolically build the Jacobian and then rebuild the ODEProblem:
2828

2929
```julia
30-
jac = eval(ModelingToolkit.generate_jacobian(sys)[2])
31-
f = ODEFunction(rober, jac=jac)
32-
prob_jac = ODEProblem(f,[1.0,0.0,0.0],(0.0,1e5),(0.04,3e7,1e4))
30+
prob_jac = ODEProblem(sys,[],(0.0,1e5),jac=true)
3331
```

src/structural_transformation/codegen.jl

Lines changed: 93 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
using LinearAlgebra
22

3+
using ModelingToolkit: isdifferenceeq, has_continuous_events, generate_rootfinding_callback, generate_difference_cb, merge_cb
4+
35
const MAX_INLINE_NLSOLVE_SIZE = 8
46

5-
function torn_system_jacobian_sparsity(state, var_eq_matching, var_sccs)
7+
function torn_system_jacobian_sparsity(state, var_eq_matching, var_sccs, nlsolve_scc_idxs, eqs_idxs, states_idxs)
68
fullvars = state.fullvars
79
graph = state.structure.graph
810

@@ -40,55 +42,53 @@ function torn_system_jacobian_sparsity(state, var_eq_matching, var_sccs)
4042
# from previous partitions. Hence, we can build the dependency chain as we
4143
# traverse the partitions.
4244

43-
# `avars2dvars` maps a algebraic variable to its differential variable
44-
# dependencies.
45-
avars2dvars = Dict{Int,Set{Int}}()
46-
c = 0
47-
for scc in var_sccs
48-
v_residual = scc
49-
e_residual = [var_eq_matching[c] for c in v_residual if var_eq_matching[c] !== unassigned]
50-
# initialization
51-
for tvar in v_residual
52-
avars2dvars[tvar] = Set{Int}()
45+
var_rename = ones(Int64, ndsts(graph))
46+
nlsolve_vars = Int[]
47+
for i in nlsolve_scc_idxs, c in var_sccs[i]
48+
append!(nlsolve_vars, c)
49+
for v in c
50+
var_rename[v] = 0
5351
end
54-
for teq in e_residual
55-
c += 1
56-
for var in 𝑠neighbors(graph, teq)
57-
# Skip the tearing variables in the current partition, because
58-
# we are computing them from all the other states.
59-
Graphs.insorted(var, v_residual) && continue
60-
deps = get(avars2dvars, var, nothing)
61-
if deps === nothing # differential variable
62-
@assert !isalgvar(state.structure, var)
63-
for tvar in v_residual
64-
push!(avars2dvars[tvar], var)
65-
end
66-
else # tearing variable from previous partitions
67-
@assert isalgvar(state.structure, var)
68-
for tvar in v_residual
69-
union!(avars2dvars[tvar], avars2dvars[var])
70-
end
71-
end
52+
end
53+
masked_cumsum!(var_rename)
54+
55+
dig = DiCMOBiGraph{true}(graph, var_eq_matching)
56+
57+
fused_var_deps = map(1:ndsts(graph)) do v
58+
BitSet(v′ for v′ in neighborhood(dig, v, Inf; dir=:in) if var_rename[v′] != 0)
59+
end
60+
61+
for scc in var_sccs[nlsolve_scc_idxs]
62+
if length(scc) >= 2
63+
deps = fused_var_deps[scc[1]]
64+
for c in 2:length(scc)
65+
union!(deps, fused_var_deps[c])
66+
fused_var_deps[c] = deps
7267
end
7368
end
7469
end
7570

76-
dvrange = diffvars_range(state.structure)
77-
dvar2idx = Dict(v=>i for (i, v) in enumerate(dvrange))
71+
var2idx = Dict{Int,Int}(v => i for (i, v) in enumerate(states_idxs))
72+
eqs2idx = Dict{Int,Int}(v => i for (i, v) in enumerate(eqs_idxs))
73+
nlsolve_vars_set = BitSet(nlsolve_vars)
74+
7875
I = Int[]; J = Int[]
79-
eqidx = 0
80-
aeqs = algeqs(state.structure)
76+
s = state.structure
8177
for ieq in 𝑠vertices(graph)
82-
ieq in aeqs && continue
83-
eqidx += 1
78+
nieq = get(eqs2idx, ieq, 0)
79+
nieq == 0 && continue
8480
for ivar in 𝑠neighbors(graph, ieq)
85-
if isdiffvar(state.structure, ivar)
86-
push!(I, eqidx)
87-
push!(J, dvar2idx[ivar])
88-
elseif isalgvar(state.structure, ivar)
89-
for dvar in avars2dvars[ivar]
90-
push!(I, eqidx)
91-
push!(J, dvar2idx[dvar])
81+
isdervar(s, ivar) && continue
82+
if var_rename[ivar] != 0
83+
push!(I, nieq)
84+
push!(J, var2idx[ivar])
85+
else
86+
for dvar in fused_var_deps[ivar]
87+
isdervar(s, dvar) && continue
88+
niv = get(var2idx, dvar, 0)
89+
niv == 0 && continue
90+
push!(I, nieq)
91+
push!(J, niv)
9292
end
9393
end
9494
end
@@ -122,7 +122,17 @@ function gen_nlsolve(eqs, vars, u0map::AbstractDict; checkbounds=true)
122122
params = setdiff(allvars, vars) # these are not the subject of the root finding
123123

124124
# splatting to tighten the type
125-
u0 = [map(var->get(u0map, var, 1e-3), vars)...]
125+
u0 = []
126+
for v in vars
127+
v in keys(u0map) || (push!(u0, 1e-3); continue)
128+
u = substitute(v, u0map)
129+
for i in 1:length(u0map)
130+
u = substitute(u, u0map)
131+
u isa Number && (push!(u0, u); break)
132+
end
133+
u isa Number || error("$v doesn't have a default.")
134+
end
135+
u0 = [u0...]
126136
# specialize on the scalar case
127137
isscalar = length(u0) == 1
128138
u0 = isscalar ? u0[1] : SVector(u0...)
@@ -167,8 +177,11 @@ function build_torn_function(
167177
max_inlining_size = something(max_inlining_size, MAX_INLINE_NLSOLVE_SIZE)
168178
rhss = []
169179
eqs = equations(sys)
170-
for eq in eqs
171-
isdiffeq(eq) && push!(rhss, eq.rhs)
180+
eqs_idxs = Int[]
181+
for (i, eq) in enumerate(eqs)
182+
isdiffeq(eq) || continue
183+
push!(eqs_idxs, i)
184+
push!(rhss, eq.rhs)
172185
end
173186

174187
state = TearingState(sys)
@@ -179,23 +192,26 @@ function build_torn_function(
179192
toporder = topological_sort_by_dfs(condensed_graph)
180193
var_sccs = var_sccs[toporder]
181194

182-
states = map(i->fullvars[i], diffvars_range(state.structure))
183-
mass_matrix_diag = ones(length(states))
195+
states_idxs = collect(diffvars_range(state.structure))
196+
mass_matrix_diag = ones(length(states_idxs))
184197
torn_expr = []
185198
defs = defaults(sys)
199+
nlsolve_scc_idxs = Int[]
186200

187201
needs_extending = false
188-
for scc in var_sccs
189-
torn_vars = [fullvars[var] for var in scc if var_eq_matching[var] !== unassigned]
190-
torn_eqs = [eqs[var_eq_matching[var]] for var in scc if var_eq_matching[var] !== unassigned]
191-
isempty(torn_eqs) && continue
192-
if length(torn_eqs) <= max_inlining_size
193-
append!(torn_expr, gen_nlsolve(torn_eqs, torn_vars, defs, checkbounds=checkbounds))
202+
for (i, scc) in enumerate(var_sccs)
203+
torn_vars_idxs = Int[var for var in scc if var_eq_matching[var] !== unassigned]
204+
torn_eqs_idxs = [var_eq_matching[var] for var in torn_vars_idxs]
205+
isempty(torn_eqs_idxs) && continue
206+
if length(torn_eqs_idxs) <= max_inlining_size
207+
append!(torn_expr, gen_nlsolve(eqs[torn_eqs_idxs], fullvars[torn_vars_idxs], defs, checkbounds=checkbounds))
208+
push!(nlsolve_scc_idxs, i)
194209
else
195210
needs_extending = true
196-
append!(rhss, map(x->x.rhs, torn_eqs))
197-
append!(states, torn_vars)
198-
append!(mass_matrix_diag, zeros(length(torn_eqs)))
211+
append!(eqs_idxs, torn_eqs_idxs)
212+
append!(rhss, map(x->x.rhs, eqs[torn_eqs_idxs]))
213+
append!(states_idxs, torn_vars_idxs)
214+
append!(mass_matrix_diag, zeros(length(torn_eqs_idxs)))
199215
end
200216
end
201217

@@ -208,7 +224,8 @@ function build_torn_function(
208224
rhss
209225
)
210226

211-
syms = map(Symbol, states)
227+
states = fullvars[states_idxs]
228+
syms = map(Symbol, states_idxs)
212229
pre = get_postprocess_fbody(sys)
213230

214231
expr = SymbolicUtils.Code.toexpr(
@@ -240,7 +257,7 @@ function build_torn_function(
240257

241258
ODEFunction{true}(
242259
@RuntimeGeneratedFunction(expr),
243-
sparsity = torn_system_jacobian_sparsity(state, var_eq_matching, var_sccs),
260+
sparsity = jacobian_sparsity ? torn_system_jacobian_sparsity(state, var_eq_matching, var_sccs, nlsolve_scc_idxs, eqs_idxs, states_idxs) : nothing,
244261
syms = syms,
245262
observed = observedfun,
246263
mass_matrix = mass_matrix,
@@ -375,14 +392,29 @@ function ODAEProblem{iip}(
375392
u0map,
376393
tspan,
377394
parammap=DiffEqBase.NullParameters();
378-
kw...
395+
callback = nothing,
396+
kwargs...
379397
) where {iip}
380-
fun, dvs = build_torn_function(sys; kw...)
398+
fun, dvs = build_torn_function(sys; kwargs...)
381399
ps = parameters(sys)
382400
defs = defaults(sys)
383401

384402
u0 = ModelingToolkit.varmap_to_vars(u0map, dvs; defaults=defs)
385403
p = ModelingToolkit.varmap_to_vars(parammap, ps; defaults=defs)
386404

387-
ODEProblem{iip}(fun, u0, tspan, p; kw...)
405+
has_difference = any(isdifferenceeq, equations(sys))
406+
if has_continuous_events(sys)
407+
event_cb = generate_rootfinding_callback(sys; kwargs...)
408+
else
409+
event_cb = nothing
410+
end
411+
difference_cb = has_difference ? generate_difference_cb(sys; kwargs...) : nothing
412+
cb = merge_cb(event_cb, difference_cb)
413+
cb = merge_cb(cb, callback)
414+
415+
if cb === nothing
416+
ODEProblem{iip}(fun, u0, tspan, p; kwargs...)
417+
else
418+
ODEProblem{iip}(fun, u0, tspan, p; callback=cb, kwargs...)
419+
end
388420
end

src/structural_transformation/tearing.jl

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,19 +38,24 @@ function contract_variables(graph::BipartiteGraph, var_eq_matching::Matching, el
3838
[var_rename[v′] for v′ in neighborhood(dig, v, Inf; dir=:in) if var_rename[v′] != 0]
3939
end
4040

41-
new_fadjlist = Vector{Int}[
42-
let new_list = Vector{Int}()
43-
for v in graph.fadjlist[i]
44-
if var_rename[v] != 0
45-
push!(new_list, var_rename[v])
46-
else
47-
append!(new_list, var_deps[v])
41+
nelim = length(eliminated_variables)
42+
newgraph = BipartiteGraph(nsrcs(graph) - nelim, ndsts(graph) - nelim)
43+
for e in 𝑠vertices(graph)
44+
ne = eq_rename[e]
45+
ne == 0 && continue
46+
for v in 𝑠neighbors(graph, e)
47+
newvar = var_rename[v]
48+
if newvar != 0
49+
add_edge!(newgraph, ne, newvar)
50+
else
51+
for nv in var_deps[v]
52+
add_edge!(newgraph, ne, nv)
4853
end
4954
end
50-
new_list
51-
end for i = 1:nsrcs(graph) if eq_rename[i] != 0]
55+
end
56+
end
5257

53-
return BipartiteGraph(new_fadjlist, ndsts(graph) - length(eliminated_variables))
58+
return newgraph
5459
end
5560

5661
"""

0 commit comments

Comments
 (0)