@@ -2,12 +2,14 @@ module NetworkDynamicsMTKExt
22
33using ModelingToolkit: Symbolic, iscall, operation, arguments, build_function
44using ModelingToolkit: ModelingToolkit, Equation, ODESystem, Differential
5- using ModelingToolkit: full_equations, get_variables, structural_simplify, getname, unwrap
5+ using ModelingToolkit: equations, full_equations, get_variables, structural_simplify, getname, unwrap
66using ModelingToolkit: parameters, unknowns, independent_variables, observed, defaults
77using Symbolics: Symbolics, fixpoint_sub, substitute
88using RecursiveArrayTools: RecursiveArrayTools
99using ArgCheck: @argcheck
1010using LinearAlgebra: Diagonal, I
11+ using SymbolicUtils. Code: Let, Assignment
12+ using OrderedCollections: OrderedDict
1113
1214using NetworkDynamics: NetworkDynamics, set_metadata!,
1315 PureFeedForward, FeedForward, NoFeedForward, PureStateMap
@@ -299,8 +301,10 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
299301 params = setdiff (allparams, Set (allinputs))
300302
301303 # extract the main equations and observed equations
302- eqs:: Vector{Equation} = ModelingToolkit. subs_constants (full_equations (sys))
304+ eqs:: Vector{Equation} = ModelingToolkit. subs_constants (equations (sys))
305+ obseqs_sorted:: Vector{Equation} = ModelingToolkit. subs_constants (observed (sys))
303306 fix_metadata! (eqs, sys);
307+ fix_metadata! (obseqs_sorted, sys);
304308
305309 # assert the ordering of states and equations
306310 explicit_states = Symbolic[eq_type (eq)[2 ] for eq in eqs if ! isnothing (eq_type (eq)[2 ])]
@@ -311,8 +315,8 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
311315 end
312316
313317 # check hat there are no rhs differentials in the equations
314- if ! isempty (rhs_differentials (eqs))
315- diffs = rhs_differentials (eqs)
318+ if ! isempty (rhs_differentials (vcat ( eqs, obseqs_sorted) ))
319+ diffs = rhs_differentials (vcat ( eqs, obseqs_sorted) )
316320 buf = IOBuffer ()
317321 println (buf, " Equations contain differentials in their rhs: " , diffs)
318322 # for (i, eqs) in enumerate(eqs)
@@ -323,23 +327,17 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
323327 throw (ArgumentError (String (take! (buf))))
324328 end
325329
326- # extract observed equations. They might depend on eachother so resolve them
327- obs_subs = Dict (eq. lhs => eq. rhs for eq in observed (sys))
328- obseqs = map (observed (sys)) do eq
329- expanded_rhs = fixpoint_sub (eq. rhs, obs_subs)
330- eq. lhs ~ ModelingToolkit. subs_constants (expanded_rhs)
331- end
332- fix_metadata! (obseqs, sys);
333330 # obs can only depend on parameters (including allinputs) or states
334- obs_deps = _all_rhs_symbols (obseqs)
331+ obs_subs = OrderedDict (eq. lhs => eq. rhs for eq in obseqs_sorted)
332+ obs_deps = _all_rhs_symbols (fixpoint_sub (obseqs_sorted, obs_subs))
335333 if ! (obs_deps ⊆ Set (allparams) ∪ Set (states) ∪ independent_variables (sys))
336334 @warn " obs_deps !⊆ parameters ∪ unknowns. Difference: $(setdiff (obs_deps, Set (allparams) ∪ Set (states))) "
337335 end
338336
339337 # if some states shadow outputs (out ~ state in observed)
340338 # switch their names. I.e. prioritize use of name `out`
341339 renamings = Dict ()
342- for eq in obseqs
340+ for eq in obseqs_sorted
343341 if eq. lhs ∈ Set (alloutputs) && iscall (eq. rhs) &&
344342 operation (eq. rhs) isa Symbolics. BasicSymbolic && eq. rhs ∈ Set (states)
345343 verbose && @info " Encountered trivial equation $eq . Swap out $(eq. lhs) <=> $(eq. rhs) everywhere."
@@ -349,36 +347,37 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
349347 end
350348 if ! isempty (renamings)
351349 eqs = map (eq -> substitute (eq, renamings), eqs)
352- obseqs = map (eq -> substitute (eq, renamings), obseqs)
350+ obseqs_sorted = map (eq -> substitute (eq, renamings), obseqs_sorted)
351+ obs_subs = OrderedDict (eq. lhs => eq. rhs for eq in obseqs_sorted)
353352 states = map (s -> substitute (s, renamings), states)
354353 verbose && @info " New States:" states
355354 end
356355
357- # find the output equations, this might remove them from obseqs!
356+ # find the output equations, this might remove them from obseqs_sorted (obs_subs stays intact)
358357 outeqs = Equation[]
359358 for out in Iterators. flatten (outputss)
360359 if out ∈ Set (states)
361360 push! (outeqs, out ~ out)
362- else
363- idx = findfirst (eq -> isequal (eq. lhs, out), obseqs)
364- if isnothing (idx)
365- throw (ArgumentError (" Output $out was neither foundin states nor in observed equations." ))
366- end
367- eq = obseqs[idx]
368- if ! isempty (rhs_differentials (eq))
369- println (obs_subs[out])
370- throw (ArgumentError (" Algebraic FF equation for output $out contains differentials in the RHS: $(rhs_differentials (eq)) " ))
371- end
372- deleteat! (obseqs, idx)
373-
374- if ff_to_constraint && ! isempty (get_variables (eq. rhs) ∩ allinputs)
361+ elseif out ∈ keys (obs_subs)
362+ # if its a observed, we need to check for ff behavior
363+ fulleq = out ~ fixpoint_sub (obs_subs[out], obs_subs)
364+ if ff_to_constraint && ! isempty (get_variables (fulleq. rhs) ∩ allinputs)
375365 verbose && @info " Output $out would lead to FF in g, promote to state instead."
376- push! (eqs, 0 ~ eq. lhs - eq. rhs)
377- push! (states, eq. lhs)
378- push! (outeqs, eq. lhs ~ eq. lhs)
379- else
380- push! (outeqs, eq)
366+ # not observed anymore, delete from observed and put in equations
367+ push! (eqs, 0 ~ out - obs_subs[out])
368+ push! (states, out)
369+ push! (outeqs, out ~ out)
370+
371+ deleteat! (obseqs_sorted, findfirst (eq -> isequal (eq. lhs, out), obseqs_sorted))
372+ delete! (obs_subs, out)
373+ else # "normal" observed state
374+ push! (outeqs, out ~ obs_subs[out])
375+ # delete from obs equations but *not* from obs_subs (otherwise can't be reference)
376+ # in equations
377+ deleteat! (obseqs_sorted, findfirst (eq -> isequal (eq. lhs, out), obseqs_sorted))
381378 end
379+ else
380+ throw (ArgumentError (" Output $out was neither foundin states nor in observed equations." ))
382381 end
383382 end
384383
@@ -399,26 +398,28 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
399398 end
400399
401400 iv = only (independent_variables (sys))
402- out_deps = _all_rhs_symbols (outeqs)
401+
402+ out_deps = _all_rhs_symbols (fixpoint_sub (outeqs, obs_subs))
403403 fftype = _determine_fftype (out_deps, states, allinputs, params, iv)
404404
405405 # filter out unnecessary parameters
406- var_deps = _all_rhs_symbols (eqs)
406+ var_deps = _all_rhs_symbols (fixpoint_sub ( eqs, obs_subs) )
407407 unused_params = Set (setdiff (params, (var_deps ∪ out_deps))) # do not exclud obs_deps
408408 if verbose && ! isempty (unused_params)
409409 @info " Parameters $(unused_params) do not appear in equations of f and g and will be marked as unused."
410410 end
411411
412412 # TODO : explore Symbolcs/SymbolicUtils CSE
413413 # now generate the actual functions
414- formulas = [eq . rhs for eq in eqs]
415- if ! isempty ( formulas)
414+ if ! isempty ( eqs)
415+ formulas = _get_formulas (eqs, obs_subs )
416416 _, f_ip = build_function (formulas, states, inputss... , params, iv; cse= false , expression)
417417 else
418418 f_ip = nothing
419419 end
420420
421- gformulas = [eq. rhs for eq in outeqs]
421+ # find all observable assigments necessary for outeqs
422+ gformulas = _get_formulas (outeqs, obs_subs)
422423 gformargs = if fftype isa PureFeedForward
423424 (inputss... , params, iv)
424425 elseif fftype isa FeedForward
@@ -436,10 +437,13 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
436437 MultipleOutputWrapper {fftype, length(outputss), typeof(_g_ip)} (_g_ip)
437438 end
438439
439- # and the observed functions
440- obsstates = [eq. lhs for eq in obseqs]
441- obsformulas = [eq. rhs for eq in obseqs]
442- _, obsf_ip = build_function (obsformulas, states, inputss... , params, iv; cse= false , expression)
440+ obsstates = [eq. lhs for eq in obseqs_sorted]
441+ if ! isempty (obsstates)
442+ obsformulas = _get_formulas ([s ~ s for s in obsstates], obs_subs)
443+ _, obsf_ip = build_function (obsformulas, states, inputss... , params, iv; cse= false , expression)
444+ else
445+ obsf_ip = nothing
446+ end
443447
444448 return (;
445449 f= f_ip, g= g_ip,
@@ -450,15 +454,60 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
450454 obsstates,
451455 fftype,
452456 obsf = obsf_ip,
453- equations= formulas ,
454- outputeqs= Dict (Iterators . flatten (outputss) .=> gformulas) ,
455- observed= Dict ( getname .(obsstates) .=> obsformulas) ,
457+ equations= eqs ,
458+ outputeqs= outeqs ,
459+ observed= obseqs_sorted ,
456460 odesystem= sys,
457461 params,
458462 unused_params
459463 )
460464end
461465
466+ function _get_formulas (eqs, obs_subs)
467+ # Bit hacky, were building a function like this,
468+ # where all (necessary) obs and eqs are contained in the bgin block of the first output
469+ # out[1] = begin
470+ # obs1 = ...
471+ # obs2 = ...
472+ # ...
473+ # state1 = ...
474+ # state2 = ...
475+ # ...
476+ # state1 # ens up in out[1]
477+ # end
478+ # out[2] = state2
479+ # ...
480+ isempty (eqs) && return []
481+ obsdeps = _collect_deps_on_obs ([eq. rhs for eq in eqs], obs_subs)
482+ obs_assignments = [Assignment (k, v) for (k,v) in obs_subs if k ∈ obsdeps]
483+
484+ # implicit equations are not use via assigments, so we filter for e
485+ eqs_assignments = [Assignment (eq. lhs, eq. rhs) for eq in eqs
486+ if ! isequal (eq. lhs, eq. rhs) && ! isequal (eq. lhs, 0 )]
487+ # since implicit eqs did not end up in assighmets, we use the rhs
488+ out = [isequal (eq. lhs, 0 ) ? eq. rhs : eq. lhs for eq in eqs]
489+
490+ [Let (vcat (obs_assignments, eqs_assignments), out[1 ], false ), out[2 : end ]. .. ]
491+ end
492+ function _collect_deps_on_obs (terms, obs_subs)
493+ deps = Set {Symbolic} ()
494+ for term in terms
495+ _collect_deps_on_obs! (deps, obs_subs, term)
496+ end
497+ deps
498+ end
499+ function _collect_deps_on_obs! (deps, obs_subs, term)
500+ termdeps = get_variables (term)
501+ for sym in termdeps
502+ if haskey (obs_subs, sym)
503+ # check recursively whether the observed depends on other observed
504+ _collect_deps_on_obs! (deps, obs_subs, obs_subs[sym])
505+ push! (deps, sym)
506+ end
507+ end
508+ deps
509+ end
510+
462511function _determine_fftype (deps, states, allinputs, params, t)
463512 if isempty (allinputs ∩ deps) # no ff path
464513 if isempty (params ∩ deps) && ! (t ∈ deps) # no p nor t
0 commit comments