@@ -2,12 +2,14 @@ module NetworkDynamicsMTKExt
2
2
3
3
using ModelingToolkit: Symbolic, iscall, operation, arguments, build_function
4
4
using ModelingToolkit: ModelingToolkit, Equation, ODESystem, Differential
5
- using ModelingToolkit: full_equations, get_variables, structural_simplify, getname, unwrap
5
+ using ModelingToolkit: equations, full_equations, get_variables, structural_simplify, getname, unwrap
6
6
using ModelingToolkit: parameters, unknowns, independent_variables, observed, defaults
7
7
using Symbolics: Symbolics, fixpoint_sub, substitute
8
8
using RecursiveArrayTools: RecursiveArrayTools
9
9
using ArgCheck: @argcheck
10
10
using LinearAlgebra: Diagonal, I
11
+ using SymbolicUtils. Code: Let, Assignment
12
+ using OrderedCollections: OrderedDict
11
13
12
14
using NetworkDynamics: NetworkDynamics, set_metadata!,
13
15
PureFeedForward, FeedForward, NoFeedForward, PureStateMap
@@ -299,8 +301,10 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
299
301
params = setdiff (allparams, Set (allinputs))
300
302
301
303
# extract the main equations and observed equations
302
- eqs:: Vector{Equation} = ModelingToolkit. subs_constants (full_equations (sys))
304
+ eqs:: Vector{Equation} = ModelingToolkit. subs_constants (equations (sys))
305
+ obseqs_sorted:: Vector{Equation} = ModelingToolkit. subs_constants (observed (sys))
303
306
fix_metadata! (eqs, sys);
307
+ fix_metadata! (obseqs_sorted, sys);
304
308
305
309
# assert the ordering of states and equations
306
310
explicit_states = Symbolic[eq_type (eq)[2 ] for eq in eqs if ! isnothing (eq_type (eq)[2 ])]
@@ -311,8 +315,8 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
311
315
end
312
316
313
317
# check hat there are no rhs differentials in the equations
314
- if ! isempty (rhs_differentials (eqs))
315
- diffs = rhs_differentials (eqs)
318
+ if ! isempty (rhs_differentials (vcat ( eqs, obseqs_sorted) ))
319
+ diffs = rhs_differentials (vcat ( eqs, obseqs_sorted) )
316
320
buf = IOBuffer ()
317
321
println (buf, " Equations contain differentials in their rhs: " , diffs)
318
322
# for (i, eqs) in enumerate(eqs)
@@ -323,23 +327,17 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
323
327
throw (ArgumentError (String (take! (buf))))
324
328
end
325
329
326
- # extract observed equations. They might depend on eachother so resolve them
327
- obs_subs = Dict (eq. lhs => eq. rhs for eq in observed (sys))
328
- obseqs = map (observed (sys)) do eq
329
- expanded_rhs = fixpoint_sub (eq. rhs, obs_subs)
330
- eq. lhs ~ ModelingToolkit. subs_constants (expanded_rhs)
331
- end
332
- fix_metadata! (obseqs, sys);
333
330
# obs can only depend on parameters (including allinputs) or states
334
- obs_deps = _all_rhs_symbols (obseqs)
331
+ obs_subs = OrderedDict (eq. lhs => eq. rhs for eq in obseqs_sorted)
332
+ obs_deps = _all_rhs_symbols (fixpoint_sub (obseqs_sorted, obs_subs))
335
333
if ! (obs_deps ⊆ Set (allparams) ∪ Set (states) ∪ independent_variables (sys))
336
334
@warn " obs_deps !⊆ parameters ∪ unknowns. Difference: $(setdiff (obs_deps, Set (allparams) ∪ Set (states))) "
337
335
end
338
336
339
337
# if some states shadow outputs (out ~ state in observed)
340
338
# switch their names. I.e. prioritize use of name `out`
341
339
renamings = Dict ()
342
- for eq in obseqs
340
+ for eq in obseqs_sorted
343
341
if eq. lhs ∈ Set (alloutputs) && iscall (eq. rhs) &&
344
342
operation (eq. rhs) isa Symbolics. BasicSymbolic && eq. rhs ∈ Set (states)
345
343
verbose && @info " Encountered trivial equation $eq . Swap out $(eq. lhs) <=> $(eq. rhs) everywhere."
@@ -349,36 +347,37 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
349
347
end
350
348
if ! isempty (renamings)
351
349
eqs = map (eq -> substitute (eq, renamings), eqs)
352
- obseqs = map (eq -> substitute (eq, renamings), obseqs)
350
+ obseqs_sorted = map (eq -> substitute (eq, renamings), obseqs_sorted)
351
+ obs_subs = OrderedDict (eq. lhs => eq. rhs for eq in obseqs_sorted)
353
352
states = map (s -> substitute (s, renamings), states)
354
353
verbose && @info " New States:" states
355
354
end
356
355
357
- # find the output equations, this might remove them from obseqs!
356
+ # find the output equations, this might remove them from obseqs_sorted (obs_subs stays intact)
358
357
outeqs = Equation[]
359
358
for out in Iterators. flatten (outputss)
360
359
if out ∈ Set (states)
361
360
push! (outeqs, out ~ out)
362
- else
363
- idx = findfirst (eq -> isequal (eq. lhs, out), obseqs)
364
- if isnothing (idx)
365
- throw (ArgumentError (" Output $out was neither foundin states nor in observed equations." ))
366
- end
367
- eq = obseqs[idx]
368
- if ! isempty (rhs_differentials (eq))
369
- println (obs_subs[out])
370
- throw (ArgumentError (" Algebraic FF equation for output $out contains differentials in the RHS: $(rhs_differentials (eq)) " ))
371
- end
372
- deleteat! (obseqs, idx)
373
-
374
- if ff_to_constraint && ! isempty (get_variables (eq. rhs) ∩ allinputs)
361
+ elseif out ∈ keys (obs_subs)
362
+ # if its a observed, we need to check for ff behavior
363
+ fulleq = out ~ fixpoint_sub (obs_subs[out], obs_subs)
364
+ if ff_to_constraint && ! isempty (get_variables (fulleq. rhs) ∩ allinputs)
375
365
verbose && @info " Output $out would lead to FF in g, promote to state instead."
376
- push! (eqs, 0 ~ eq. lhs - eq. rhs)
377
- push! (states, eq. lhs)
378
- push! (outeqs, eq. lhs ~ eq. lhs)
379
- else
380
- push! (outeqs, eq)
366
+ # not observed anymore, delete from observed and put in equations
367
+ push! (eqs, 0 ~ out - obs_subs[out])
368
+ push! (states, out)
369
+ push! (outeqs, out ~ out)
370
+
371
+ deleteat! (obseqs_sorted, findfirst (eq -> isequal (eq. lhs, out), obseqs_sorted))
372
+ delete! (obs_subs, out)
373
+ else # "normal" observed state
374
+ push! (outeqs, out ~ obs_subs[out])
375
+ # delete from obs equations but *not* from obs_subs (otherwise can't be reference)
376
+ # in equations
377
+ deleteat! (obseqs_sorted, findfirst (eq -> isequal (eq. lhs, out), obseqs_sorted))
381
378
end
379
+ else
380
+ throw (ArgumentError (" Output $out was neither foundin states nor in observed equations." ))
382
381
end
383
382
end
384
383
@@ -399,26 +398,28 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
399
398
end
400
399
401
400
iv = only (independent_variables (sys))
402
- out_deps = _all_rhs_symbols (outeqs)
401
+
402
+ out_deps = _all_rhs_symbols (fixpoint_sub (outeqs, obs_subs))
403
403
fftype = _determine_fftype (out_deps, states, allinputs, params, iv)
404
404
405
405
# filter out unnecessary parameters
406
- var_deps = _all_rhs_symbols (eqs)
406
+ var_deps = _all_rhs_symbols (fixpoint_sub ( eqs, obs_subs) )
407
407
unused_params = Set (setdiff (params, (var_deps ∪ out_deps))) # do not exclud obs_deps
408
408
if verbose && ! isempty (unused_params)
409
409
@info " Parameters $(unused_params) do not appear in equations of f and g and will be marked as unused."
410
410
end
411
411
412
412
# TODO : explore Symbolcs/SymbolicUtils CSE
413
413
# now generate the actual functions
414
- formulas = [eq . rhs for eq in eqs]
415
- if ! isempty ( formulas)
414
+ if ! isempty ( eqs)
415
+ formulas = _get_formulas (eqs, obs_subs )
416
416
_, f_ip = build_function (formulas, states, inputss... , params, iv; cse= false , expression)
417
417
else
418
418
f_ip = nothing
419
419
end
420
420
421
- gformulas = [eq. rhs for eq in outeqs]
421
+ # find all observable assigments necessary for outeqs
422
+ gformulas = _get_formulas (outeqs, obs_subs)
422
423
gformargs = if fftype isa PureFeedForward
423
424
(inputss... , params, iv)
424
425
elseif fftype isa FeedForward
@@ -436,10 +437,13 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
436
437
MultipleOutputWrapper {fftype, length(outputss), typeof(_g_ip)} (_g_ip)
437
438
end
438
439
439
- # and the observed functions
440
- obsstates = [eq. lhs for eq in obseqs]
441
- obsformulas = [eq. rhs for eq in obseqs]
442
- _, obsf_ip = build_function (obsformulas, states, inputss... , params, iv; cse= false , expression)
440
+ obsstates = [eq. lhs for eq in obseqs_sorted]
441
+ if ! isempty (obsstates)
442
+ obsformulas = _get_formulas ([s ~ s for s in obsstates], obs_subs)
443
+ _, obsf_ip = build_function (obsformulas, states, inputss... , params, iv; cse= false , expression)
444
+ else
445
+ obsf_ip = nothing
446
+ end
443
447
444
448
return (;
445
449
f= f_ip, g= g_ip,
@@ -450,15 +454,60 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
450
454
obsstates,
451
455
fftype,
452
456
obsf = obsf_ip,
453
- equations= formulas ,
454
- outputeqs= Dict (Iterators . flatten (outputss) .=> gformulas) ,
455
- observed= Dict ( getname .(obsstates) .=> obsformulas) ,
457
+ equations= eqs ,
458
+ outputeqs= outeqs ,
459
+ observed= obseqs_sorted ,
456
460
odesystem= sys,
457
461
params,
458
462
unused_params
459
463
)
460
464
end
461
465
466
+ function _get_formulas (eqs, obs_subs)
467
+ # Bit hacky, were building a function like this,
468
+ # where all (necessary) obs and eqs are contained in the bgin block of the first output
469
+ # out[1] = begin
470
+ # obs1 = ...
471
+ # obs2 = ...
472
+ # ...
473
+ # state1 = ...
474
+ # state2 = ...
475
+ # ...
476
+ # state1 # ens up in out[1]
477
+ # end
478
+ # out[2] = state2
479
+ # ...
480
+ isempty (eqs) && return []
481
+ obsdeps = _collect_deps_on_obs ([eq. rhs for eq in eqs], obs_subs)
482
+ obs_assignments = [Assignment (k, v) for (k,v) in obs_subs if k ∈ obsdeps]
483
+
484
+ # implicit equations are not use via assigments, so we filter for e
485
+ eqs_assignments = [Assignment (eq. lhs, eq. rhs) for eq in eqs
486
+ if ! isequal (eq. lhs, eq. rhs) && ! isequal (eq. lhs, 0 )]
487
+ # since implicit eqs did not end up in assighmets, we use the rhs
488
+ out = [isequal (eq. lhs, 0 ) ? eq. rhs : eq. lhs for eq in eqs]
489
+
490
+ [Let (vcat (obs_assignments, eqs_assignments), out[1 ], false ), out[2 : end ]. .. ]
491
+ end
492
+ function _collect_deps_on_obs (terms, obs_subs)
493
+ deps = Set {Symbolic} ()
494
+ for term in terms
495
+ _collect_deps_on_obs! (deps, obs_subs, term)
496
+ end
497
+ deps
498
+ end
499
+ function _collect_deps_on_obs! (deps, obs_subs, term)
500
+ termdeps = get_variables (term)
501
+ for sym in termdeps
502
+ if haskey (obs_subs, sym)
503
+ # check recursively whether the observed depends on other observed
504
+ _collect_deps_on_obs! (deps, obs_subs, obs_subs[sym])
505
+ push! (deps, sym)
506
+ end
507
+ end
508
+ deps
509
+ end
510
+
462
511
function _determine_fftype (deps, states, allinputs, params, t)
463
512
if isempty (allinputs ∩ deps) # no ff path
464
513
if isempty (params ∩ deps) && ! (t ∈ deps) # no p nor t
0 commit comments