@@ -2,12 +2,14 @@ module NetworkDynamicsMTKExt
22
33using ModelingToolkit: Symbolic, iscall, operation, arguments, build_function
44using ModelingToolkit: ModelingToolkit, Equation, ODESystem, Differential
5- using ModelingToolkit: full_equations, get_variables, structural_simplify, getname, unwrap
5+ using ModelingToolkit: equations, full_equations, get_variables, structural_simplify, getname, unwrap
66using ModelingToolkit: parameters, unknowns, independent_variables, observed, defaults
77using Symbolics: Symbolics, fixpoint_sub, substitute
88using RecursiveArrayTools: RecursiveArrayTools
99using ArgCheck: @argcheck
1010using LinearAlgebra: Diagonal, I
11+ using SymbolicUtils. Code: Let, Assignment
12+ using OrderedCollections: OrderedDict
1113
1214using NetworkDynamics: NetworkDynamics, set_metadata!,
1315 PureFeedForward, FeedForward, NoFeedForward, PureStateMap
@@ -299,8 +301,10 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
299301 params = setdiff (allparams, Set (allinputs))
300302
301303 # extract the main equations and observed equations
302- eqs:: Vector{Equation} = ModelingToolkit. subs_constants (full_equations (sys))
304+ eqs:: Vector{Equation} = ModelingToolkit. subs_constants (equations (sys))
305+ obseqs_sorted:: Vector{Equation} = ModelingToolkit. subs_constants (observed (sys))
303306 fix_metadata! (eqs, sys);
307+ fix_metadata! (obseqs_sorted, sys);
304308
305309 # assert the ordering of states and equations
306310 explicit_states = Symbolic[eq_type (eq)[2 ] for eq in eqs if ! isnothing (eq_type (eq)[2 ])]
@@ -311,8 +315,8 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
311315 end
312316
313317 # check hat there are no rhs differentials in the equations
314- if ! isempty (rhs_differentials (eqs))
315- diffs = rhs_differentials (eqs)
318+ if ! isempty (rhs_differentials (vcat ( eqs, obseqs_sorted) ))
319+ diffs = rhs_differentials (vcat ( eqs, obseqs_sorted) )
316320 buf = IOBuffer ()
317321 println (buf, " Equations contain differentials in their rhs: " , diffs)
318322 # for (i, eqs) in enumerate(eqs)
@@ -323,23 +327,17 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
323327 throw (ArgumentError (String (take! (buf))))
324328 end
325329
326- # extract observed equations. They might depend on eachother so resolve them
327- obs_subs = Dict (eq. lhs => eq. rhs for eq in observed (sys))
328- obseqs = map (observed (sys)) do eq
329- expanded_rhs = fixpoint_sub (eq. rhs, obs_subs)
330- eq. lhs ~ ModelingToolkit. subs_constants (expanded_rhs)
331- end
332- fix_metadata! (obseqs, sys);
333330 # obs can only depend on parameters (including allinputs) or states
334- obs_deps = _all_rhs_symbols (obseqs)
331+ obs_subs = OrderedDict (eq. lhs => eq. rhs for eq in obseqs_sorted)
332+ obs_deps = _all_rhs_symbols (fixpoint_sub (obseqs_sorted, obs_subs))
335333 if ! (obs_deps ⊆ Set (allparams) ∪ Set (states) ∪ independent_variables (sys))
336334 @warn " obs_deps !⊆ parameters ∪ unknowns. Difference: $(setdiff (obs_deps, Set (allparams) ∪ Set (states))) "
337335 end
338336
339337 # if some states shadow outputs (out ~ state in observed)
340338 # switch their names. I.e. prioritize use of name `out`
341339 renamings = Dict ()
342- for eq in obseqs
340+ for eq in obseqs_sorted
343341 if eq. lhs ∈ Set (alloutputs) && iscall (eq. rhs) &&
344342 operation (eq. rhs) isa Symbolics. BasicSymbolic && eq. rhs ∈ Set (states)
345343 verbose && @info " Encountered trivial equation $eq . Swap out $(eq. lhs) <=> $(eq. rhs) everywhere."
@@ -349,7 +347,8 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
349347 end
350348 if ! isempty (renamings)
351349 eqs = map (eq -> substitute (eq, renamings), eqs)
352- obseqs = map (eq -> substitute (eq, renamings), obseqs)
350+ obseqs_sorted = map (eq -> substitute (eq, renamings), obseqs_sorted)
351+ obs_subs = OrderedDict (eq. lhs => eq. rhs for eq in obseqs_sorted)
353352 states = map (s -> substitute (s, renamings), states)
354353 verbose && @info " New States:" states
355354 end
@@ -360,16 +359,16 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
360359 if out ∈ Set (states)
361360 push! (outeqs, out ~ out)
362361 else
363- idx = findfirst (eq -> isequal (eq. lhs, out), obseqs )
362+ idx = findfirst (eq -> isequal (eq. lhs, out), obseqs_sorted )
364363 if isnothing (idx)
365364 throw (ArgumentError (" Output $out was neither foundin states nor in observed equations." ))
366365 end
367- eq = obseqs [idx]
366+ eq = obseqs_sorted [idx]
368367 if ! isempty (rhs_differentials (eq))
369368 println (obs_subs[out])
370369 throw (ArgumentError (" Algebraic FF equation for output $out contains differentials in the RHS: $(rhs_differentials (eq)) " ))
371370 end
372- deleteat! (obseqs , idx)
371+ deleteat! (obseqs_sorted , idx)
373372
374373 if ff_to_constraint && ! isempty (get_variables (eq. rhs) ∩ allinputs)
375374 verbose && @info " Output $out would lead to FF in g, promote to state instead."
@@ -382,6 +381,9 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
382381 end
383382 end
384383
384+ # obseqs might have changed in block above
385+ obs_subs = OrderedDict (eq. lhs => eq. rhs for eq in obseqs_sorted)
386+
385387 # generate mass matrix (this might change the equations)
386388 mass_matrix = begin
387389 # equations of form o = f(...) have to be transformed to 0 = f(...) - o
@@ -399,26 +401,29 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
399401 end
400402
401403 iv = only (independent_variables (sys))
402- out_deps = _all_rhs_symbols (outeqs)
404+
405+ out_deps = _all_rhs_symbols (fixpoint_sub (outeqs, obs_subs))
403406 fftype = _determine_fftype (out_deps, states, allinputs, params, iv)
404407
405408 # filter out unnecessary parameters
406- var_deps = _all_rhs_symbols (eqs)
409+ var_deps = _all_rhs_symbols (fixpoint_sub ( eqs, obs_subs) )
407410 unused_params = Set (setdiff (params, (var_deps ∪ out_deps))) # do not exclud obs_deps
408411 if verbose && ! isempty (unused_params)
409412 @info " Parameters $(unused_params) do not appear in equations of f and g and will be marked as unused."
410413 end
411414
412415 # TODO : explore Symbolcs/SymbolicUtils CSE
413416 # now generate the actual functions
414- formulas = [eq . rhs for eq in eqs]
415- if ! isempty ( formulas)
417+ if ! isempty ( eqs)
418+ formulas = _get_formulas (eqs, obs_subs )
416419 _, f_ip = build_function (formulas, states, inputss... , params, iv; cse= false , expression)
417420 else
421+ formulas = []
418422 f_ip = nothing
419423 end
420424
421- gformulas = [eq. rhs for eq in outeqs]
425+ # find all observable assigments necessary for outeqs
426+ gformulas = _get_formulas (outeqs, obs_subs)
422427 gformargs = if fftype isa PureFeedForward
423428 (inputss... , params, iv)
424429 elseif fftype isa FeedForward
@@ -437,8 +442,8 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
437442 end
438443
439444 # and the observed functions
440- obsstates = [eq. lhs for eq in obseqs ]
441- obsformulas = [eq. rhs for eq in obseqs ]
445+ obsstates = [eq. lhs for eq in obseqs_sorted ]
446+ obsformulas = [eq. rhs for eq in obseqs_sorted ]
442447 _, obsf_ip = build_function (obsformulas, states, inputss... , params, iv; cse= false , expression)
443448
444449 return (;
@@ -459,6 +464,36 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
459464 )
460465end
461466
467+ function _get_formulas (eqs, obs_subs)
468+ rhss = [eq. rhs for eq in eqs]
469+ obsdeps = _collect_deps_on_obs (rhss, obs_subs)
470+ if isempty (obsdeps)
471+ return rhss
472+ else
473+ # ensure that the ordering is still correct
474+ obs_assignments = [Assignment (k, v) for (k,v) in obs_subs if k ∈ obsdeps]
475+ return [Let (obs_assignments, rhss[1 ], false ), rhss[2 : end ]. .. ]
476+ end
477+ end
478+ function _collect_deps_on_obs (terms, obs_subs)
479+ deps = Set {Symbolic} ()
480+ for term in terms
481+ _collect_deps_on_obs! (deps, obs_subs, term)
482+ end
483+ deps
484+ end
485+ function _collect_deps_on_obs! (deps, obs_subs, term)
486+ termdeps = get_variables (term)
487+ for sym in termdeps
488+ if haskey (obs_subs, sym)
489+ # check recursively whether the observed depends on other observed
490+ _collect_deps_on_obs! (deps, obs_subs, obs_subs[sym])
491+ push! (deps, sym)
492+ end
493+ end
494+ deps
495+ end
496+
462497function _determine_fftype (deps, states, allinputs, params, t)
463498 if isempty (allinputs ∩ deps) # no ff path
464499 if isempty (params ∩ deps) && ! (t ∈ deps) # no p nor t
0 commit comments