Skip to content

Commit fc510a9

Browse files
authored
Merge pull request #234 from JuliaDynamics/hw/cse_observed
2 parents b8c865b + 5628896 commit fc510a9

File tree

2 files changed

+99
-46
lines changed

2 files changed

+99
-46
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1717
Mixers = "2a8e4939-dab8-5edc-8f64-72a8776f13de"
1818
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1919
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
20+
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
2021
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
2122
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
2223
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
@@ -37,11 +38,12 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
3738
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3839
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
3940
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
41+
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
4042
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
4143

4244
[extensions]
4345
NetworkDynamicsCUDAExt = ["CUDA", "Adapt"]
44-
NetworkDynamicsMTKExt = ["ModelingToolkit", "Symbolics"]
46+
NetworkDynamicsMTKExt = ["ModelingToolkit", "SymbolicUtils", "Symbolics"]
4547
NetworkDynamicsSymbolicsExt = ["Symbolics", "MacroTools"]
4648

4749
[compat]
@@ -62,6 +64,7 @@ Mixers = "0.1.2"
6264
ModelingToolkit = "9.67"
6365
NNlib = "0.9.13"
6466
NonlinearSolve = "4"
67+
OrderedCollections = "1.8.0"
6568
Polyester = "0.7.12"
6669
PreallocationTools = "0.4.23"
6770
PrecompileTools = "1.2.1"
@@ -75,6 +78,7 @@ StaticArrays = "1.9.4"
7578
SteadyStateDiffEq = "2.2.0"
7679
StyledStrings = "1.0.3"
7780
SymbolicIndexingInterface = "0.3.27"
81+
SymbolicUtils = "3.24"
7882
Symbolics = "6.19.0"
7983
TimerOutputs = "0.5.23"
8084
julia = "1.10"

ext/NetworkDynamicsMTKExt.jl

Lines changed: 94 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@ module NetworkDynamicsMTKExt
22

33
using ModelingToolkit: Symbolic, iscall, operation, arguments, build_function
44
using ModelingToolkit: ModelingToolkit, Equation, ODESystem, Differential
5-
using ModelingToolkit: full_equations, get_variables, structural_simplify, getname, unwrap
5+
using ModelingToolkit: equations, full_equations, get_variables, structural_simplify, getname, unwrap
66
using ModelingToolkit: parameters, unknowns, independent_variables, observed, defaults
77
using Symbolics: Symbolics, fixpoint_sub, substitute
88
using RecursiveArrayTools: RecursiveArrayTools
99
using ArgCheck: @argcheck
1010
using LinearAlgebra: Diagonal, I
11+
using SymbolicUtils.Code: Let, Assignment
12+
using OrderedCollections: OrderedDict
1113

1214
using NetworkDynamics: NetworkDynamics, set_metadata!,
1315
PureFeedForward, FeedForward, NoFeedForward, PureStateMap
@@ -299,8 +301,10 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
299301
params = setdiff(allparams, Set(allinputs))
300302

301303
# extract the main equations and observed equations
302-
eqs::Vector{Equation} = ModelingToolkit.subs_constants(full_equations(sys))
304+
eqs::Vector{Equation} = ModelingToolkit.subs_constants(equations(sys))
305+
obseqs_sorted::Vector{Equation} = ModelingToolkit.subs_constants(observed(sys))
303306
fix_metadata!(eqs, sys);
307+
fix_metadata!(obseqs_sorted, sys);
304308

305309
# assert the ordering of states and equations
306310
explicit_states = Symbolic[eq_type(eq)[2] for eq in eqs if !isnothing(eq_type(eq)[2])]
@@ -311,8 +315,8 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
311315
end
312316

313317
# check hat there are no rhs differentials in the equations
314-
if !isempty(rhs_differentials(eqs))
315-
diffs = rhs_differentials(eqs)
318+
if !isempty(rhs_differentials(vcat(eqs, obseqs_sorted)))
319+
diffs = rhs_differentials(vcat(eqs, obseqs_sorted))
316320
buf = IOBuffer()
317321
println(buf, "Equations contain differentials in their rhs: ", diffs)
318322
# for (i, eqs) in enumerate(eqs)
@@ -323,23 +327,17 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
323327
throw(ArgumentError(String(take!(buf))))
324328
end
325329

326-
# extract observed equations. They might depend on eachother so resolve them
327-
obs_subs = Dict(eq.lhs => eq.rhs for eq in observed(sys))
328-
obseqs = map(observed(sys)) do eq
329-
expanded_rhs = fixpoint_sub(eq.rhs, obs_subs)
330-
eq.lhs ~ ModelingToolkit.subs_constants(expanded_rhs)
331-
end
332-
fix_metadata!(obseqs, sys);
333330
# obs can only depend on parameters (including allinputs) or states
334-
obs_deps = _all_rhs_symbols(obseqs)
331+
obs_subs = OrderedDict(eq.lhs => eq.rhs for eq in obseqs_sorted)
332+
obs_deps = _all_rhs_symbols(fixpoint_sub(obseqs_sorted, obs_subs))
335333
if !(obs_deps Set(allparams) Set(states) independent_variables(sys))
336334
@warn "obs_deps !⊆ parameters ∪ unknowns. Difference: $(setdiff(obs_deps, Set(allparams) Set(states)))"
337335
end
338336

339337
# if some states shadow outputs (out ~ state in observed)
340338
# switch their names. I.e. prioritize use of name `out`
341339
renamings = Dict()
342-
for eq in obseqs
340+
for eq in obseqs_sorted
343341
if eq.lhs Set(alloutputs) && iscall(eq.rhs) &&
344342
operation(eq.rhs) isa Symbolics.BasicSymbolic && eq.rhs Set(states)
345343
verbose && @info "Encountered trivial equation $eq. Swap out $(eq.lhs) <=> $(eq.rhs) everywhere."
@@ -349,36 +347,37 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
349347
end
350348
if !isempty(renamings)
351349
eqs = map(eq -> substitute(eq, renamings), eqs)
352-
obseqs = map(eq -> substitute(eq, renamings), obseqs)
350+
obseqs_sorted = map(eq -> substitute(eq, renamings), obseqs_sorted)
351+
obs_subs = OrderedDict(eq.lhs => eq.rhs for eq in obseqs_sorted)
353352
states = map(s -> substitute(s, renamings), states)
354353
verbose && @info "New States:" states
355354
end
356355

357-
# find the output equations, this might remove them from obseqs!
356+
# find the output equations, this might remove them from obseqs_sorted (obs_subs stays intact)
358357
outeqs = Equation[]
359358
for out in Iterators.flatten(outputss)
360359
if out Set(states)
361360
push!(outeqs, out ~ out)
362-
else
363-
idx = findfirst(eq -> isequal(eq.lhs, out), obseqs)
364-
if isnothing(idx)
365-
throw(ArgumentError("Output $out was neither foundin states nor in observed equations."))
366-
end
367-
eq = obseqs[idx]
368-
if !isempty(rhs_differentials(eq))
369-
println(obs_subs[out])
370-
throw(ArgumentError("Algebraic FF equation for output $out contains differentials in the RHS: $(rhs_differentials(eq))"))
371-
end
372-
deleteat!(obseqs, idx)
373-
374-
if ff_to_constraint && !isempty(get_variables(eq.rhs) allinputs)
361+
elseif out keys(obs_subs)
362+
# if its a observed, we need to check for ff behavior
363+
fulleq = out ~ fixpoint_sub(obs_subs[out], obs_subs)
364+
if ff_to_constraint && !isempty(get_variables(fulleq.rhs) allinputs)
375365
verbose && @info "Output $out would lead to FF in g, promote to state instead."
376-
push!(eqs, 0 ~ eq.lhs - eq.rhs)
377-
push!(states, eq.lhs)
378-
push!(outeqs, eq.lhs ~ eq.lhs)
379-
else
380-
push!(outeqs, eq)
366+
# not observed anymore, delete from observed and put in equations
367+
push!(eqs, 0 ~ out - obs_subs[out])
368+
push!(states, out)
369+
push!(outeqs, out ~ out)
370+
371+
deleteat!(obseqs_sorted, findfirst(eq -> isequal(eq.lhs, out), obseqs_sorted))
372+
delete!(obs_subs, out)
373+
else # "normal" observed state
374+
push!(outeqs, out ~ obs_subs[out])
375+
# delete from obs equations but *not* from obs_subs (otherwise can't be reference)
376+
# in equations
377+
deleteat!(obseqs_sorted, findfirst(eq -> isequal(eq.lhs, out), obseqs_sorted))
381378
end
379+
else
380+
throw(ArgumentError("Output $out was neither foundin states nor in observed equations."))
382381
end
383382
end
384383

@@ -399,26 +398,28 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
399398
end
400399

401400
iv = only(independent_variables(sys))
402-
out_deps = _all_rhs_symbols(outeqs)
401+
402+
out_deps = _all_rhs_symbols(fixpoint_sub(outeqs, obs_subs))
403403
fftype = _determine_fftype(out_deps, states, allinputs, params, iv)
404404

405405
# filter out unnecessary parameters
406-
var_deps = _all_rhs_symbols(eqs)
406+
var_deps = _all_rhs_symbols(fixpoint_sub(eqs, obs_subs))
407407
unused_params = Set(setdiff(params, (var_deps out_deps))) # do not exclud obs_deps
408408
if verbose && !isempty(unused_params)
409409
@info "Parameters $(unused_params) do not appear in equations of f and g and will be marked as unused."
410410
end
411411

412412
# TODO: explore Symbolcs/SymbolicUtils CSE
413413
# now generate the actual functions
414-
formulas = [eq.rhs for eq in eqs]
415-
if !isempty(formulas)
414+
if !isempty(eqs)
415+
formulas = _get_formulas(eqs, obs_subs)
416416
_, f_ip = build_function(formulas, states, inputss..., params, iv; cse=false, expression)
417417
else
418418
f_ip = nothing
419419
end
420420

421-
gformulas = [eq.rhs for eq in outeqs]
421+
# find all observable assigments necessary for outeqs
422+
gformulas = _get_formulas(outeqs, obs_subs)
422423
gformargs = if fftype isa PureFeedForward
423424
(inputss..., params, iv)
424425
elseif fftype isa FeedForward
@@ -436,10 +437,13 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
436437
MultipleOutputWrapper{fftype, length(outputss), typeof(_g_ip)}(_g_ip)
437438
end
438439

439-
# and the observed functions
440-
obsstates = [eq.lhs for eq in obseqs]
441-
obsformulas = [eq.rhs for eq in obseqs]
442-
_, obsf_ip = build_function(obsformulas, states, inputss..., params, iv; cse=false, expression)
440+
obsstates = [eq.lhs for eq in obseqs_sorted]
441+
if !isempty(obsstates)
442+
obsformulas = _get_formulas([s ~ s for s in obsstates], obs_subs)
443+
_, obsf_ip = build_function(obsformulas, states, inputss..., params, iv; cse=false, expression)
444+
else
445+
obsf_ip = nothing
446+
end
443447

444448
return (;
445449
f=f_ip, g=g_ip,
@@ -450,15 +454,60 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
450454
obsstates,
451455
fftype,
452456
obsf = obsf_ip,
453-
equations=formulas,
454-
outputeqs=Dict(Iterators.flatten(outputss) .=> gformulas),
455-
observed=Dict(getname.(obsstates) .=> obsformulas),
457+
equations=eqs,
458+
outputeqs=outeqs,
459+
observed=obseqs_sorted,
456460
odesystem=sys,
457461
params,
458462
unused_params
459463
)
460464
end
461465

466+
function _get_formulas(eqs, obs_subs)
467+
# Bit hacky, were building a function like this,
468+
# where all (necessary) obs and eqs are contained in the bgin block of the first output
469+
# out[1] = begin
470+
# obs1 = ...
471+
# obs2 = ...
472+
# ...
473+
# state1 = ...
474+
# state2 = ...
475+
# ...
476+
# state1 # ens up in out[1]
477+
# end
478+
# out[2] = state2
479+
# ...
480+
isempty(eqs) && return []
481+
obsdeps = _collect_deps_on_obs([eq.rhs for eq in eqs], obs_subs)
482+
obs_assignments = [Assignment(k, v) for (k,v) in obs_subs if k obsdeps]
483+
484+
# implicit equations are not use via assigments, so we filter for e
485+
eqs_assignments = [Assignment(eq.lhs, eq.rhs) for eq in eqs
486+
if !isequal(eq.lhs, eq.rhs) && !isequal(eq.lhs, 0)]
487+
# since implicit eqs did not end up in assighmets, we use the rhs
488+
out = [isequal(eq.lhs, 0) ? eq.rhs : eq.lhs for eq in eqs]
489+
490+
[Let(vcat(obs_assignments, eqs_assignments), out[1], false), out[2:end]...]
491+
end
492+
function _collect_deps_on_obs(terms, obs_subs)
493+
deps = Set{Symbolic}()
494+
for term in terms
495+
_collect_deps_on_obs!(deps, obs_subs, term)
496+
end
497+
deps
498+
end
499+
function _collect_deps_on_obs!(deps, obs_subs, term)
500+
termdeps = get_variables(term)
501+
for sym in termdeps
502+
if haskey(obs_subs, sym)
503+
# check recursively whether the observed depends on other observed
504+
_collect_deps_on_obs!(deps, obs_subs, obs_subs[sym])
505+
push!(deps, sym)
506+
end
507+
end
508+
deps
509+
end
510+
462511
function _determine_fftype(deps, states, allinputs, params, t)
463512
if isempty(allinputs deps) # no ff path
464513
if isempty(params deps) && !(t deps) # no p nor t

0 commit comments

Comments
 (0)