@@ -2,12 +2,14 @@ module NetworkDynamicsMTKExt
2
2
3
3
using ModelingToolkit: Symbolic, iscall, operation, arguments, build_function
4
4
using ModelingToolkit: ModelingToolkit, Equation, ODESystem, Differential
5
- using ModelingToolkit: full_equations, get_variables, structural_simplify, getname, unwrap
5
+ using ModelingToolkit: equations, full_equations, get_variables, structural_simplify, getname, unwrap
6
6
using ModelingToolkit: parameters, unknowns, independent_variables, observed, defaults
7
7
using Symbolics: Symbolics, fixpoint_sub, substitute
8
8
using RecursiveArrayTools: RecursiveArrayTools
9
9
using ArgCheck: @argcheck
10
10
using LinearAlgebra: Diagonal, I
11
+ using SymbolicUtils. Code: Let, Assignment
12
+ using OrderedCollections: OrderedDict
11
13
12
14
using NetworkDynamics: NetworkDynamics, set_metadata!,
13
15
PureFeedForward, FeedForward, NoFeedForward, PureStateMap
@@ -299,8 +301,10 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
299
301
params = setdiff (allparams, Set (allinputs))
300
302
301
303
# extract the main equations and observed equations
302
- eqs:: Vector{Equation} = ModelingToolkit. subs_constants (full_equations (sys))
304
+ eqs:: Vector{Equation} = ModelingToolkit. subs_constants (equations (sys))
305
+ obseqs_sorted:: Vector{Equation} = ModelingToolkit. subs_constants (observed (sys))
303
306
fix_metadata! (eqs, sys);
307
+ fix_metadata! (obseqs_sorted, sys);
304
308
305
309
# assert the ordering of states and equations
306
310
explicit_states = Symbolic[eq_type (eq)[2 ] for eq in eqs if ! isnothing (eq_type (eq)[2 ])]
@@ -311,8 +315,8 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
311
315
end
312
316
313
317
# check hat there are no rhs differentials in the equations
314
- if ! isempty (rhs_differentials (eqs))
315
- diffs = rhs_differentials (eqs)
318
+ if ! isempty (rhs_differentials (vcat ( eqs, obseqs_sorted) ))
319
+ diffs = rhs_differentials (vcat ( eqs, obseqs_sorted) )
316
320
buf = IOBuffer ()
317
321
println (buf, " Equations contain differentials in their rhs: " , diffs)
318
322
# for (i, eqs) in enumerate(eqs)
@@ -323,23 +327,17 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
323
327
throw (ArgumentError (String (take! (buf))))
324
328
end
325
329
326
- # extract observed equations. They might depend on eachother so resolve them
327
- obs_subs = Dict (eq. lhs => eq. rhs for eq in observed (sys))
328
- obseqs = map (observed (sys)) do eq
329
- expanded_rhs = fixpoint_sub (eq. rhs, obs_subs)
330
- eq. lhs ~ ModelingToolkit. subs_constants (expanded_rhs)
331
- end
332
- fix_metadata! (obseqs, sys);
333
330
# obs can only depend on parameters (including allinputs) or states
334
- obs_deps = _all_rhs_symbols (obseqs)
331
+ obs_subs = OrderedDict (eq. lhs => eq. rhs for eq in obseqs_sorted)
332
+ obs_deps = _all_rhs_symbols (fixpoint_sub (obseqs_sorted, obs_subs))
335
333
if ! (obs_deps ⊆ Set (allparams) ∪ Set (states) ∪ independent_variables (sys))
336
334
@warn " obs_deps !⊆ parameters ∪ unknowns. Difference: $(setdiff (obs_deps, Set (allparams) ∪ Set (states))) "
337
335
end
338
336
339
337
# if some states shadow outputs (out ~ state in observed)
340
338
# switch their names. I.e. prioritize use of name `out`
341
339
renamings = Dict ()
342
- for eq in obseqs
340
+ for eq in obseqs_sorted
343
341
if eq. lhs ∈ Set (alloutputs) && iscall (eq. rhs) &&
344
342
operation (eq. rhs) isa Symbolics. BasicSymbolic && eq. rhs ∈ Set (states)
345
343
verbose && @info " Encountered trivial equation $eq . Swap out $(eq. lhs) <=> $(eq. rhs) everywhere."
@@ -349,7 +347,8 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
349
347
end
350
348
if ! isempty (renamings)
351
349
eqs = map (eq -> substitute (eq, renamings), eqs)
352
- obseqs = map (eq -> substitute (eq, renamings), obseqs)
350
+ obseqs_sorted = map (eq -> substitute (eq, renamings), obseqs_sorted)
351
+ obs_subs = OrderedDict (eq. lhs => eq. rhs for eq in obseqs_sorted)
353
352
states = map (s -> substitute (s, renamings), states)
354
353
verbose && @info " New States:" states
355
354
end
@@ -360,16 +359,16 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
360
359
if out ∈ Set (states)
361
360
push! (outeqs, out ~ out)
362
361
else
363
- idx = findfirst (eq -> isequal (eq. lhs, out), obseqs )
362
+ idx = findfirst (eq -> isequal (eq. lhs, out), obseqs_sorted )
364
363
if isnothing (idx)
365
364
throw (ArgumentError (" Output $out was neither foundin states nor in observed equations." ))
366
365
end
367
- eq = obseqs [idx]
366
+ eq = obseqs_sorted [idx]
368
367
if ! isempty (rhs_differentials (eq))
369
368
println (obs_subs[out])
370
369
throw (ArgumentError (" Algebraic FF equation for output $out contains differentials in the RHS: $(rhs_differentials (eq)) " ))
371
370
end
372
- deleteat! (obseqs , idx)
371
+ deleteat! (obseqs_sorted , idx)
373
372
374
373
if ff_to_constraint && ! isempty (get_variables (eq. rhs) ∩ allinputs)
375
374
verbose && @info " Output $out would lead to FF in g, promote to state instead."
@@ -382,6 +381,9 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
382
381
end
383
382
end
384
383
384
+ # obseqs might have changed in block above
385
+ obs_subs = OrderedDict (eq. lhs => eq. rhs for eq in obseqs_sorted)
386
+
385
387
# generate mass matrix (this might change the equations)
386
388
mass_matrix = begin
387
389
# equations of form o = f(...) have to be transformed to 0 = f(...) - o
@@ -399,26 +401,29 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
399
401
end
400
402
401
403
iv = only (independent_variables (sys))
402
- out_deps = _all_rhs_symbols (outeqs)
404
+
405
+ out_deps = _all_rhs_symbols (fixpoint_sub (outeqs, obs_subs))
403
406
fftype = _determine_fftype (out_deps, states, allinputs, params, iv)
404
407
405
408
# filter out unnecessary parameters
406
- var_deps = _all_rhs_symbols (eqs)
409
+ var_deps = _all_rhs_symbols (fixpoint_sub ( eqs, obs_subs) )
407
410
unused_params = Set (setdiff (params, (var_deps ∪ out_deps))) # do not exclud obs_deps
408
411
if verbose && ! isempty (unused_params)
409
412
@info " Parameters $(unused_params) do not appear in equations of f and g and will be marked as unused."
410
413
end
411
414
412
415
# TODO : explore Symbolcs/SymbolicUtils CSE
413
416
# now generate the actual functions
414
- formulas = [eq . rhs for eq in eqs]
415
- if ! isempty ( formulas)
417
+ if ! isempty ( eqs)
418
+ formulas = _get_formulas (eqs, obs_subs )
416
419
_, f_ip = build_function (formulas, states, inputss... , params, iv; cse= false , expression)
417
420
else
421
+ formulas = []
418
422
f_ip = nothing
419
423
end
420
424
421
- gformulas = [eq. rhs for eq in outeqs]
425
+ # find all observable assigments necessary for outeqs
426
+ gformulas = _get_formulas (outeqs, obs_subs)
422
427
gformargs = if fftype isa PureFeedForward
423
428
(inputss... , params, iv)
424
429
elseif fftype isa FeedForward
@@ -437,8 +442,8 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
437
442
end
438
443
439
444
# and the observed functions
440
- obsstates = [eq. lhs for eq in obseqs ]
441
- obsformulas = [eq. rhs for eq in obseqs ]
445
+ obsstates = [eq. lhs for eq in obseqs_sorted ]
446
+ obsformulas = [eq. rhs for eq in obseqs_sorted ]
442
447
_, obsf_ip = build_function (obsformulas, states, inputss... , params, iv; cse= false , expression)
443
448
444
449
return (;
@@ -459,6 +464,36 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
459
464
)
460
465
end
461
466
467
+ function _get_formulas (eqs, obs_subs)
468
+ rhss = [eq. rhs for eq in eqs]
469
+ obsdeps = _collect_deps_on_obs (rhss, obs_subs)
470
+ if isempty (obsdeps)
471
+ return rhss
472
+ else
473
+ # ensure that the ordering is still correct
474
+ obs_assignments = [Assignment (k, v) for (k,v) in obs_subs if k ∈ obsdeps]
475
+ return [Let (obs_assignments, rhss[1 ], false ), rhss[2 : end ]. .. ]
476
+ end
477
+ end
478
+ function _collect_deps_on_obs (terms, obs_subs)
479
+ deps = Set {Symbolic} ()
480
+ for term in terms
481
+ _collect_deps_on_obs! (deps, obs_subs, term)
482
+ end
483
+ deps
484
+ end
485
+ function _collect_deps_on_obs! (deps, obs_subs, term)
486
+ termdeps = get_variables (term)
487
+ for sym in termdeps
488
+ if haskey (obs_subs, sym)
489
+ # check recursively whether the observed depends on other observed
490
+ _collect_deps_on_obs! (deps, obs_subs, obs_subs[sym])
491
+ push! (deps, sym)
492
+ end
493
+ end
494
+ deps
495
+ end
496
+
462
497
function _determine_fftype (deps, states, allinputs, params, t)
463
498
if isempty (allinputs ∩ deps) # no ff path
464
499
if isempty (params ∩ deps) && ! (t ∈ deps) # no p nor t
0 commit comments