@@ -353,37 +353,35 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
353353 verbose && @info " New States:" states
354354 end
355355
356- # find the output equations, this might remove them from obseqs!
356+ # find the output equations, this might remove them from obseqs_sorted (obs_subs stays intact)
357357 outeqs = Equation[]
358358 for out in Iterators. flatten (outputss)
359359 if out ∈ Set (states)
360360 push! (outeqs, out ~ out)
361- else
362- idx = findfirst (eq -> isequal (eq. lhs, out), obseqs_sorted)
363- if isnothing (idx)
364- throw (ArgumentError (" Output $out was neither foundin states nor in observed equations." ))
365- end
366- eq = obseqs_sorted[idx]
367- if ! isempty (rhs_differentials (eq))
368- println (obs_subs[out])
369- throw (ArgumentError (" Algebraic FF equation for output $out contains differentials in the RHS: $(rhs_differentials (eq)) " ))
370- end
371- deleteat! (obseqs_sorted, idx)
372-
373- if ff_to_constraint && ! isempty (get_variables (eq. rhs) ∩ allinputs)
361+ elseif out ∈ keys (obs_subs)
362+ # if its a observed, we need to check for ff behavior
363+ fulleq = out ~ fixpoint_sub (obs_subs[out], obs_subs)
364+ if ff_to_constraint && ! isempty (get_variables (fulleq. rhs) ∩ allinputs)
374365 verbose && @info " Output $out would lead to FF in g, promote to state instead."
375- push! (eqs, 0 ~ eq. lhs - eq. rhs)
366+ # not observed anymore, delete from observed and put in equations
367+ push! (eqs, 0 ~ out - obs_subs[out])
376368 push! (states, eq. lhs)
377- push! (outeqs, eq. lhs ~ eq. lhs)
378- else
379- push! (outeqs, eq)
369+ push! (outeqs, out ~ out)
370+
371+ deleteat! (obseqs_sorted, findfirst (eq -> isequal (eq. lhs, out), obseqs_sorted))
372+ delete! (obs_subs, out)
373+ push! (outeqs, out ~ out)
374+ else # "normal" observed state
375+ push! (outeqs, out ~ obs_subs[out])
376+ # delete from obs equations but *not* from obs_subs (otherwise can't be reference)
377+ # in equations
378+ deleteat! (obseqs_sorted, findfirst (eq -> isequal (eq. lhs, out), obseqs_sorted))
380379 end
380+ else
381+ throw (ArgumentError (" Output $out was neither foundin states nor in observed equations." ))
381382 end
382383 end
383384
384- # obseqs might have changed in block above
385- obs_subs = OrderedDict (eq. lhs => eq. rhs for eq in obseqs_sorted)
386-
387385 # generate mass matrix (this might change the equations)
388386 mass_matrix = begin
389387 # equations of form o = f(...) have to be transformed to 0 = f(...) - o
@@ -418,7 +416,6 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
418416 formulas = _get_formulas (eqs, obs_subs)
419417 _, f_ip = build_function (formulas, states, inputss... , params, iv; cse= false , expression)
420418 else
421- formulas = []
422419 f_ip = nothing
423420 end
424421
@@ -441,10 +438,13 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
441438 MultipleOutputWrapper {fftype, length(outputss), typeof(_g_ip)} (_g_ip)
442439 end
443440
444- # and the observed functions
445441 obsstates = [eq. lhs for eq in obseqs_sorted]
446- obsformulas = [eq. rhs for eq in obseqs_sorted]
447- _, obsf_ip = build_function (obsformulas, states, inputss... , params, iv; cse= false , expression)
442+ if ! isempty (obsstates)
443+ obsformulas = _get_formulas ([s ~ s for s in obsstates], obs_subs)
444+ _, obsf_ip = build_function (obsformulas, states, inputss... , params, iv; cse= false , expression)
445+ else
446+ obsf_ip = nothing
447+ end
448448
449449 return (;
450450 f= f_ip, g= g_ip,
@@ -455,25 +455,40 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
455455 obsstates,
456456 fftype,
457457 obsf = obsf_ip,
458- equations= formulas ,
459- outputeqs= Dict (Iterators . flatten (outputss) .=> gformulas) ,
460- observed= Dict ( getname .(obsstates) .=> obsformulas) ,
458+ equations= eqs ,
459+ outputeqs= outeqs ,
460+ observed= obseqs_sorted ,
461461 odesystem= sys,
462462 params,
463463 unused_params
464464 )
465465end
466466
467467function _get_formulas (eqs, obs_subs)
468- rhss = [eq. rhs for eq in eqs]
469- obsdeps = _collect_deps_on_obs (rhss, obs_subs)
470- if isempty (obsdeps)
471- return rhss
472- else
473- # ensure that the ordering is still correct
474- obs_assignments = [Assignment (k, v) for (k,v) in obs_subs if k ∈ obsdeps]
475- return [Let (obs_assignments, rhss[1 ], false ), rhss[2 : end ]. .. ]
476- end
468+ # Bit hacky, were building a function like this,
469+ # where all (necessary) obs and eqs are contained in the bgin block of the first output
470+ # out[1] = begin
471+ # obs1 = ...
472+ # obs2 = ...
473+ # ...
474+ # state1 = ...
475+ # state2 = ...
476+ # ...
477+ # state1 # ens up in out[1]
478+ # end
479+ # out[2] = state2
480+ # ...
481+ isempty (eqs) && return []
482+ obsdeps = _collect_deps_on_obs ([eq. rhs for eq in eqs], obs_subs)
483+ obs_assignments = [Assignment (k, v) for (k,v) in obs_subs if k ∈ obsdeps]
484+
485+ # implicit equations are not use via assigments, so we filter for e
486+ eqs_assignments = [Assignment (eq. lhs, eq. rhs) for eq in eqs
487+ if ! isequal (eq. lhs, eq. rhs) && ! isequal (eq. lhs, 0 )]
488+ # since implicit eqs did not end up in assighmets, we use the rhs
489+ out = [isequal (eq. lhs, 0 ) ? eq. rhs : eq. lhs for eq in eqs]
490+
491+ [Let (vcat (obs_assignments, eqs_assignments), out[1 ], false ), out[2 : end ]. .. ]
477492end
478493function _collect_deps_on_obs (terms, obs_subs)
479494 deps = Set {Symbolic} ()
0 commit comments