Skip to content

Commit da90444

Browse files
committed
keep observed separate for faster codegen and execution
1 parent b8c865b commit da90444

File tree

2 files changed

+61
-24
lines changed

2 files changed

+61
-24
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
3737
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3838
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
3939
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
40+
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
4041
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
4142

4243
[extensions]
@@ -75,6 +76,7 @@ StaticArrays = "1.9.4"
7576
SteadyStateDiffEq = "2.2.0"
7677
StyledStrings = "1.0.3"
7778
SymbolicIndexingInterface = "0.3.27"
79+
SymbolicUtils = "3.24"
7880
Symbolics = "6.19.0"
7981
TimerOutputs = "0.5.23"
8082
julia = "1.10"

ext/NetworkDynamicsMTKExt.jl

Lines changed: 59 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@ module NetworkDynamicsMTKExt
22

33
using ModelingToolkit: Symbolic, iscall, operation, arguments, build_function
44
using ModelingToolkit: ModelingToolkit, Equation, ODESystem, Differential
5-
using ModelingToolkit: full_equations, get_variables, structural_simplify, getname, unwrap
5+
using ModelingToolkit: equations, full_equations, get_variables, structural_simplify, getname, unwrap
66
using ModelingToolkit: parameters, unknowns, independent_variables, observed, defaults
77
using Symbolics: Symbolics, fixpoint_sub, substitute
88
using RecursiveArrayTools: RecursiveArrayTools
99
using ArgCheck: @argcheck
1010
using LinearAlgebra: Diagonal, I
11+
using SymbolicUtils.Code: Let, Assignment
12+
using OrderedCollections: OrderedDict
1113

1214
using NetworkDynamics: NetworkDynamics, set_metadata!,
1315
PureFeedForward, FeedForward, NoFeedForward, PureStateMap
@@ -299,8 +301,10 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
299301
params = setdiff(allparams, Set(allinputs))
300302

301303
# extract the main equations and observed equations
302-
eqs::Vector{Equation} = ModelingToolkit.subs_constants(full_equations(sys))
304+
eqs::Vector{Equation} = ModelingToolkit.subs_constants(equations(sys))
305+
obseqs_sorted::Vector{Equation} = ModelingToolkit.subs_constants(observed(sys))
303306
fix_metadata!(eqs, sys);
307+
fix_metadata!(obseqs_sorted, sys);
304308

305309
# assert the ordering of states and equations
306310
explicit_states = Symbolic[eq_type(eq)[2] for eq in eqs if !isnothing(eq_type(eq)[2])]
@@ -311,8 +315,8 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
311315
end
312316

313317
# check hat there are no rhs differentials in the equations
314-
if !isempty(rhs_differentials(eqs))
315-
diffs = rhs_differentials(eqs)
318+
if !isempty(rhs_differentials(vcat(eqs, obseqs_sorted)))
319+
diffs = rhs_differentials(vcat(eqs, obseqs_sorted))
316320
buf = IOBuffer()
317321
println(buf, "Equations contain differentials in their rhs: ", diffs)
318322
# for (i, eqs) in enumerate(eqs)
@@ -323,23 +327,17 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
323327
throw(ArgumentError(String(take!(buf))))
324328
end
325329

326-
# extract observed equations. They might depend on eachother so resolve them
327-
obs_subs = Dict(eq.lhs => eq.rhs for eq in observed(sys))
328-
obseqs = map(observed(sys)) do eq
329-
expanded_rhs = fixpoint_sub(eq.rhs, obs_subs)
330-
eq.lhs ~ ModelingToolkit.subs_constants(expanded_rhs)
331-
end
332-
fix_metadata!(obseqs, sys);
333330
# obs can only depend on parameters (including allinputs) or states
334-
obs_deps = _all_rhs_symbols(obseqs)
331+
obs_subs = OrderedDict(eq.lhs => eq.rhs for eq in obseqs_sorted)
332+
obs_deps = _all_rhs_symbols(fixpoint_sub(obseqs_sorted, obs_subs))
335333
if !(obs_deps Set(allparams) Set(states) independent_variables(sys))
336334
@warn "obs_deps !⊆ parameters ∪ unknowns. Difference: $(setdiff(obs_deps, Set(allparams) Set(states)))"
337335
end
338336

339337
# if some states shadow outputs (out ~ state in observed)
340338
# switch their names. I.e. prioritize use of name `out`
341339
renamings = Dict()
342-
for eq in obseqs
340+
for eq in obseqs_sorted
343341
if eq.lhs Set(alloutputs) && iscall(eq.rhs) &&
344342
operation(eq.rhs) isa Symbolics.BasicSymbolic && eq.rhs Set(states)
345343
verbose && @info "Encountered trivial equation $eq. Swap out $(eq.lhs) <=> $(eq.rhs) everywhere."
@@ -349,7 +347,8 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
349347
end
350348
if !isempty(renamings)
351349
eqs = map(eq -> substitute(eq, renamings), eqs)
352-
obseqs = map(eq -> substitute(eq, renamings), obseqs)
350+
obseqs_sorted = map(eq -> substitute(eq, renamings), obseqs_sorted)
351+
obs_subs = OrderedDict(eq.lhs => eq.rhs for eq in obseqs_sorted)
353352
states = map(s -> substitute(s, renamings), states)
354353
verbose && @info "New States:" states
355354
end
@@ -360,16 +359,16 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
360359
if out Set(states)
361360
push!(outeqs, out ~ out)
362361
else
363-
idx = findfirst(eq -> isequal(eq.lhs, out), obseqs)
362+
idx = findfirst(eq -> isequal(eq.lhs, out), obseqs_sorted)
364363
if isnothing(idx)
365364
throw(ArgumentError("Output $out was neither foundin states nor in observed equations."))
366365
end
367-
eq = obseqs[idx]
366+
eq = obseqs_sorted[idx]
368367
if !isempty(rhs_differentials(eq))
369368
println(obs_subs[out])
370369
throw(ArgumentError("Algebraic FF equation for output $out contains differentials in the RHS: $(rhs_differentials(eq))"))
371370
end
372-
deleteat!(obseqs, idx)
371+
deleteat!(obseqs_sorted, idx)
373372

374373
if ff_to_constraint && !isempty(get_variables(eq.rhs) allinputs)
375374
verbose && @info "Output $out would lead to FF in g, promote to state instead."
@@ -382,6 +381,9 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
382381
end
383382
end
384383

384+
# obseqs might have changed in block above
385+
obs_subs = OrderedDict(eq.lhs => eq.rhs for eq in obseqs_sorted)
386+
385387
# generate mass matrix (this might change the equations)
386388
mass_matrix = begin
387389
# equations of form o = f(...) have to be transformed to 0 = f(...) - o
@@ -399,26 +401,29 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
399401
end
400402

401403
iv = only(independent_variables(sys))
402-
out_deps = _all_rhs_symbols(outeqs)
404+
405+
out_deps = _all_rhs_symbols(fixpoint_sub(outeqs, obs_subs))
403406
fftype = _determine_fftype(out_deps, states, allinputs, params, iv)
404407

405408
# filter out unnecessary parameters
406-
var_deps = _all_rhs_symbols(eqs)
409+
var_deps = _all_rhs_symbols(fixpoint_sub(eqs, obs_subs))
407410
unused_params = Set(setdiff(params, (var_deps out_deps))) # do not exclud obs_deps
408411
if verbose && !isempty(unused_params)
409412
@info "Parameters $(unused_params) do not appear in equations of f and g and will be marked as unused."
410413
end
411414

412415
# TODO: explore Symbolcs/SymbolicUtils CSE
413416
# now generate the actual functions
414-
formulas = [eq.rhs for eq in eqs]
415-
if !isempty(formulas)
417+
if !isempty(eqs)
418+
formulas = _get_formulas(eqs, obs_subs)
416419
_, f_ip = build_function(formulas, states, inputss..., params, iv; cse=false, expression)
417420
else
421+
formulas = []
418422
f_ip = nothing
419423
end
420424

421-
gformulas = [eq.rhs for eq in outeqs]
425+
# find all observable assigments necessary for outeqs
426+
gformulas = _get_formulas(outeqs, obs_subs)
422427
gformargs = if fftype isa PureFeedForward
423428
(inputss..., params, iv)
424429
elseif fftype isa FeedForward
@@ -437,8 +442,8 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
437442
end
438443

439444
# and the observed functions
440-
obsstates = [eq.lhs for eq in obseqs]
441-
obsformulas = [eq.rhs for eq in obseqs]
445+
obsstates = [eq.lhs for eq in obseqs_sorted]
446+
obsformulas = [eq.rhs for eq in obseqs_sorted]
442447
_, obsf_ip = build_function(obsformulas, states, inputss..., params, iv; cse=false, expression)
443448

444449
return (;
@@ -459,6 +464,36 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
459464
)
460465
end
461466

467+
function _get_formulas(eqs, obs_subs)
468+
rhss = [eq.rhs for eq in eqs]
469+
obsdeps = _collect_deps_on_obs(rhss, obs_subs)
470+
if isempty(obsdeps)
471+
return rhss
472+
else
473+
# ensure that the ordering is still correct
474+
obs_assignments = [Assignment(k, v) for (k,v) in obs_subs if k obsdeps]
475+
return [Let(obs_assignments, rhss[1], false), rhss[2:end]...]
476+
end
477+
end
478+
function _collect_deps_on_obs(terms, obs_subs)
479+
deps = Set{Symbolic}()
480+
for term in terms
481+
_collect_deps_on_obs!(deps, obs_subs, term)
482+
end
483+
deps
484+
end
485+
function _collect_deps_on_obs!(deps, obs_subs, term)
486+
termdeps = get_variables(term)
487+
for sym in termdeps
488+
if haskey(obs_subs, sym)
489+
# check recursively whether the observed depends on other observed
490+
_collect_deps_on_obs!(deps, obs_subs, obs_subs[sym])
491+
push!(deps, sym)
492+
end
493+
end
494+
deps
495+
end
496+
462497
function _determine_fftype(deps, states, allinputs, params, t)
463498
if isempty(allinputs deps) # no ff path
464499
if isempty(params deps) && !(t deps) # no p nor t

0 commit comments

Comments
 (0)