Skip to content

Commit 8575674

Browse files
feat: initial implementation of SCCNonlinearProblem codegen
1 parent ac38df6 commit 8575674

File tree

6 files changed

+251
-148
lines changed

6 files changed

+251
-148
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ REPL = "1"
123123
RecursiveArrayTools = "3.26"
124124
Reexport = "0.2, 1"
125125
RuntimeGeneratedFunctions = "0.5.9"
126-
SciMLBase = "2.57.1"
126+
SciMLBase = "2.61"
127127
SciMLStructures = "1.0"
128128
Serialization = "1"
129129
Setfield = "0.7, 0.8, 1"

src/structural_transformation/bipartite_tearing/modia_tearing.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,15 @@ function tear_graph_block_modia!(var_eq_matching, ict, solvable_graph, eqs, vars
6262
return nothing
6363
end
6464

65+
function build_var_eq_matching(structure::SystemStructure, ::Type{U} = Unassigned;
66+
varfilter::F2 = v -> true, eqfilter::F3 = eq -> true) where {U, F2, F3}
67+
@unpack graph, solvable_graph = structure
68+
var_eq_matching = maximal_matching(graph, eqfilter, varfilter, U)
69+
matching_len = max(length(var_eq_matching),
70+
maximum(x -> x isa Int ? x : 0, var_eq_matching, init = 0))
71+
return complete(var_eq_matching, matching_len), matching_len
72+
end
73+
6574
function tear_graph_modia(structure::SystemStructure, isder::F = nothing,
6675
::Type{U} = Unassigned;
6776
varfilter::F2 = v -> true,
@@ -78,10 +87,7 @@ function tear_graph_modia(structure::SystemStructure, isder::F = nothing,
7887
# find them here [TODO: It would be good to have an explicit example of this.]
7988

8089
@unpack graph, solvable_graph = structure
81-
var_eq_matching = maximal_matching(graph, eqfilter, varfilter, U)
82-
matching_len = max(length(var_eq_matching),
83-
maximum(x -> x isa Int ? x : 0, var_eq_matching, init = 0))
84-
var_eq_matching = complete(var_eq_matching, matching_len)
90+
var_eq_matching, matching_len = build_var_eq_matching(structure, U; varfilter, eqfilter)
8591
full_var_eq_matching = copy(var_eq_matching)
8692
var_sccs = find_var_sccs(graph, var_eq_matching)
8793
vargraph = DiCMOBiGraph{true}(graph, 0, Matching(matching_len))

src/systems/abstractsystem.jl

Lines changed: 56 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -162,11 +162,12 @@ object.
162162
"""
163163
function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys),
164164
ps = parameters(sys); wrap_code = nothing, postprocess_fbody = nothing, states = nothing,
165-
expression = Val{true}, eval_expression = false, eval_module = @__MODULE__, kwargs...)
165+
expression = Val{true}, eval_expression = false, eval_module = @__MODULE__,
166+
cachesyms::Tuple = (), kwargs...)
166167
if !iscomplete(sys)
167168
error("A completed system is required. Call `complete` or `structural_simplify` on the system.")
168169
end
169-
p = reorder_parameters(sys, unwrap.(ps))
170+
p = (reorder_parameters(sys, unwrap.(ps))..., cachesyms...)
170171
isscalar = !(exprs isa AbstractArray)
171172
if wrap_code === nothing
172173
wrap_code = isscalar ? identity : (identity, identity)
@@ -187,7 +188,7 @@ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys
187188
postprocess_fbody,
188189
states,
189190
wrap_code = wrap_code .∘ wrap_mtkparameters(sys, isscalar) .∘
190-
wrap_array_vars(sys, exprs; dvs) .∘
191+
wrap_array_vars(sys, exprs; dvs, cachesyms) .∘
191192
wrap_parameter_dependencies(sys, isscalar),
192193
expression = Val{true}
193194
)
@@ -199,7 +200,7 @@ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys
199200
postprocess_fbody,
200201
states,
201202
wrap_code = wrap_code .∘ wrap_mtkparameters(sys, isscalar) .∘
202-
wrap_array_vars(sys, exprs; dvs) .∘
203+
wrap_array_vars(sys, exprs; dvs, cachesyms) .∘
203204
wrap_parameter_dependencies(sys, isscalar),
204205
expression = Val{true}
205206
)
@@ -231,116 +232,51 @@ end
231232

232233
function wrap_array_vars(
233234
sys::AbstractSystem, exprs; dvs = unknowns(sys), ps = parameters(sys),
234-
inputs = nothing, history = false)
235+
inputs = nothing, history = false, cachesyms::Tuple = ())
235236
isscalar = !(exprs isa AbstractArray)
236-
array_vars = Dict{Any, AbstractArray{Int}}()
237-
if dvs !== nothing
238-
for (j, x) in enumerate(dvs)
239-
if iscall(x) && operation(x) == getindex
240-
arg = arguments(x)[1]
241-
inds = get!(() -> Int[], array_vars, arg)
242-
push!(inds, j)
243-
end
244-
end
245-
for (k, inds) in array_vars
246-
if inds == (inds′ = inds[1]:inds[end])
247-
array_vars[k] = inds′
248-
end
249-
end
237+
var_to_arridxs = Dict()
250238

251-
uind = 1
252-
else
239+
if dvs === nothing
253240
uind = 0
254-
end
255-
# values are (indexes, index of buffer, size of parameter)
256-
array_parameters = Dict{Any, Tuple{AbstractArray{Int}, Int, Tuple{Vararg{Int}}}}()
257-
# If for some reason different elements of an array parameter are in different buffers
258-
other_array_parameters = Dict{Any, Any}()
259-
260-
hasinputs = inputs !== nothing
261-
input_vars = Dict{Any, AbstractArray{Int}}()
262-
if hasinputs
263-
for (j, x) in enumerate(inputs)
264-
if iscall(x) && operation(x) == getindex
265-
arg = arguments(x)[1]
266-
inds = get!(() -> Int[], input_vars, arg)
267-
push!(inds, j)
268-
end
269-
end
270-
for (k, inds) in input_vars
271-
if inds == (inds′ = inds[1]:inds[end])
272-
input_vars[k] = inds′
273-
end
274-
end
275-
end
276-
if has_index_cache(sys)
277-
ic = get_index_cache(sys)
278241
else
279-
ic = nothing
280-
end
281-
if ps isa Tuple && eltype(ps) <: AbstractArray
282-
ps = Iterators.flatten(ps)
283-
end
284-
for p in ps
285-
p = unwrap(p)
286-
if iscall(p) && operation(p) == getindex
287-
p = arguments(p)[1]
288-
end
289-
symtype(p) <: AbstractArray && Symbolics.shape(p) != Symbolics.Unknown() || continue
290-
scal = collect(p)
291-
# all scalarized variables are in `ps`
292-
any(isequal(p), ps) || all(x -> any(isequal(x), ps), scal) || continue
293-
(haskey(array_parameters, p) || haskey(other_array_parameters, p)) && continue
294-
295-
idx = parameter_index(sys, p)
296-
idx isa Int && continue
297-
if idx isa ParameterIndex
298-
if idx.portion != SciMLStructures.Tunable()
242+
uind = 1
243+
for (i, x) in enumerate(dvs)
244+
iscall(x) && operation(x) == getindex || continue
245+
arg = arguments(x)[1]
246+
inds = get!(() -> [], var_to_arridxs, arg)
247+
push!(inds, (uind, i))
248+
end
249+
end
250+
p_start = uind + 1 + (inputs !== nothing) + history
251+
input_ind = inputs === nothing ? -1 : (p_start - 1)
252+
rps = (reorder_parameters(sys, ps)..., cachesyms...)
253+
for sym in reduce(vcat, rps; init = [])
254+
iscall(sym) && operation(sym) == getindex || continue
255+
arg = arguments(sym)[1]
256+
if inputs !== nothing
257+
idx = findfirst(isequal(sym), inputs)
258+
if idx !== nothing
259+
inds = get!(() -> [], var_to_arridxs, arg)
260+
push!(inds, (input_ind, idx))
299261
continue
300262
end
301-
array_parameters[p] = (vec(idx.idx), 1, size(idx.idx))
302-
else
303-
# idx === nothing
304-
idxs = map(Base.Fix1(parameter_index, sys), scal)
305-
if first(idxs) isa ParameterIndex
306-
buffer_idxs = map(Base.Fix1(iterated_buffer_index, ic), idxs)
307-
if allequal(buffer_idxs)
308-
buffer_idx = first(buffer_idxs)
309-
if first(idxs).portion == SciMLStructures.Tunable()
310-
idxs = map(x -> x.idx, idxs)
311-
else
312-
idxs = map(x -> x.idx[end], idxs)
313-
end
314-
else
315-
other_array_parameters[p] = scal
316-
continue
317-
end
318-
else
319-
buffer_idx = 1
320-
end
321-
322-
sz = size(idxs)
323-
if vec(idxs) == idxs[begin]:idxs[end]
324-
idxs = idxs[begin]:idxs[end]
325-
elseif vec(idxs) == idxs[begin]:-1:idxs[end]
326-
idxs = idxs[begin]:-1:idxs[end]
327-
end
328-
idxs = vec(idxs)
329-
array_parameters[p] = (idxs, buffer_idx, sz)
330263
end
264+
bufferidx = findfirst(buf -> any(isequal(sym), buf), rps)
265+
idxinbuffer = findfirst(isequal(sym), rps[bufferidx])
266+
inds = get!(() -> [], var_to_arridxs, arg)
267+
push!(inds, (p_start + bufferidx - 1, idxinbuffer))
331268
end
332269

333-
inputind = if history
334-
uind + 2
335-
else
336-
uind + 1
337-
end
338-
params_offset = if history && hasinputs
339-
uind + 2
340-
elseif history || hasinputs
341-
uind + 1
342-
else
343-
uind
270+
viewsyms = Dict()
271+
splitsyms = Dict()
272+
for (arrsym, idxs) in var_to_arridxs
273+
length(idxs) == length(arrsym) || continue
274+
# allequal(first, idxs) is a 1.11 feature
275+
if allequal(Iterators.map(first, idxs))
276+
viewsyms[arrsym] = (first(first(idxs)), reshape(last.(idxs), size(arrsym)))
277+
else
278+
splitsyms[arrsym] = reshape(idxs, size(arrsym))
279+
end
344280
end
345281
if isscalar
346282
function (expr)
@@ -349,15 +285,11 @@ function wrap_array_vars(
349285
[],
350286
Let(
351287
vcat(
352-
[k :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
353-
[k :(view($(expr.args[inputind].name), $v))
354-
for (k, v) in input_vars],
355-
[k :(reshape(
356-
view($(expr.args[params_offset + buffer_idx].name), $idxs),
357-
$sz))
358-
for (k, (idxs, buffer_idx, sz)) in array_parameters],
359-
[k Code.MakeArray(v, symtype(k))
360-
for (k, v) in other_array_parameters]
288+
[sym :(view($(expr.args[i].name), $idxs))
289+
for (sym, (i, idxs)) in viewsyms],
290+
[sym
291+
MakeArray([expr.args[bufi].elems[vali] for (bufi, vali) in idxs],
292+
expr.args[idxs[1][1]]) for (sym, idxs) in splitsyms]
361293
),
362294
expr.body,
363295
false
@@ -371,15 +303,11 @@ function wrap_array_vars(
371303
[],
372304
Let(
373305
vcat(
374-
[k :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
375-
[k :(view($(expr.args[inputind].name), $v))
376-
for (k, v) in input_vars],
377-
[k :(reshape(
378-
view($(expr.args[params_offset + buffer_idx].name), $idxs),
379-
$sz))
380-
for (k, (idxs, buffer_idx, sz)) in array_parameters],
381-
[k Code.MakeArray(v, symtype(k))
382-
for (k, v) in other_array_parameters]
306+
[sym :(view($(expr.args[i].name), $idxs))
307+
for (sym, (i, idxs)) in viewsyms],
308+
[sym
309+
MakeArray([expr.args[bufi].elems[vali] for (bufi, vali) in idxs],
310+
expr.args[idxs[1][1]]) for (sym, idxs) in splitsyms]
383311
),
384312
expr.body,
385313
false
@@ -392,17 +320,11 @@ function wrap_array_vars(
392320
[],
393321
Let(
394322
vcat(
395-
[k :(view($(expr.args[uind + 1].name), $v))
396-
for (k, v) in array_vars],
397-
[k :(view($(expr.args[inputind + 1].name), $v))
398-
for (k, v) in input_vars],
399-
[k :(reshape(
400-
view($(expr.args[params_offset + buffer_idx + 1].name),
401-
$idxs),
402-
$sz))
403-
for (k, (idxs, buffer_idx, sz)) in array_parameters],
404-
[k Code.MakeArray(v, symtype(k))
405-
for (k, v) in other_array_parameters]
323+
[sym :(view($(expr.args[i + 1].name), $idxs))
324+
for (sym, (i, idxs)) in viewsyms],
325+
[sym MakeArray(
326+
[expr.args[bufi + 1].elems[vali] for (bufi, vali) in idxs],
327+
expr.args[idxs[1][1] + 1]) for (sym, idxs) in splitsyms]
406328
),
407329
expr.body,
408330
false

0 commit comments

Comments
 (0)