Skip to content

Commit 1079c7a

Browse files
committed
fix handling of outputs which are observed
1 parent d713c34 commit 1079c7a

File tree

1 file changed

+52
-37
lines changed

1 file changed

+52
-37
lines changed

ext/NetworkDynamicsMTKExt.jl

Lines changed: 52 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -353,37 +353,35 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
353353
verbose && @info "New States:" states
354354
end
355355

356-
# find the output equations, this might remove them from obseqs!
356+
# find the output equations, this might remove them from obseqs_sorted (obs_subs stays intact)
357357
outeqs = Equation[]
358358
for out in Iterators.flatten(outputss)
359359
if out Set(states)
360360
push!(outeqs, out ~ out)
361-
else
362-
idx = findfirst(eq -> isequal(eq.lhs, out), obseqs_sorted)
363-
if isnothing(idx)
364-
throw(ArgumentError("Output $out was neither foundin states nor in observed equations."))
365-
end
366-
eq = obseqs_sorted[idx]
367-
if !isempty(rhs_differentials(eq))
368-
println(obs_subs[out])
369-
throw(ArgumentError("Algebraic FF equation for output $out contains differentials in the RHS: $(rhs_differentials(eq))"))
370-
end
371-
deleteat!(obseqs_sorted, idx)
372-
373-
if ff_to_constraint && !isempty(get_variables(eq.rhs) allinputs)
361+
elseif out keys(obs_subs)
362+
# if its a observed, we need to check for ff behavior
363+
fulleq = out ~ fixpoint_sub(obs_subs[out], obs_subs)
364+
if ff_to_constraint && !isempty(get_variables(fulleq.rhs) allinputs)
374365
verbose && @info "Output $out would lead to FF in g, promote to state instead."
375-
push!(eqs, 0 ~ eq.lhs - eq.rhs)
366+
# not observed anymore, delete from observed and put in equations
367+
push!(eqs, 0 ~ out - obs_subs[out])
376368
push!(states, eq.lhs)
377-
push!(outeqs, eq.lhs ~ eq.lhs)
378-
else
379-
push!(outeqs, eq)
369+
push!(outeqs, out ~ out)
370+
371+
deleteat!(obseqs_sorted, findfirst(eq -> isequal(eq.lhs, out), obseqs_sorted))
372+
delete!(obs_subs, out)
373+
push!(outeqs, out ~ out)
374+
else # "normal" observed state
375+
push!(outeqs, out ~ obs_subs[out])
376+
# delete from obs equations but *not* from obs_subs (otherwise can't be reference)
377+
# in equations
378+
deleteat!(obseqs_sorted, findfirst(eq -> isequal(eq.lhs, out), obseqs_sorted))
380379
end
380+
else
381+
throw(ArgumentError("Output $out was neither foundin states nor in observed equations."))
381382
end
382383
end
383384

384-
# obseqs might have changed in block above
385-
obs_subs = OrderedDict(eq.lhs => eq.rhs for eq in obseqs_sorted)
386-
387385
# generate mass matrix (this might change the equations)
388386
mass_matrix = begin
389387
# equations of form o = f(...) have to be transformed to 0 = f(...) - o
@@ -418,7 +416,6 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
418416
formulas = _get_formulas(eqs, obs_subs)
419417
_, f_ip = build_function(formulas, states, inputss..., params, iv; cse=false, expression)
420418
else
421-
formulas = []
422419
f_ip = nothing
423420
end
424421

@@ -441,10 +438,13 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
441438
MultipleOutputWrapper{fftype, length(outputss), typeof(_g_ip)}(_g_ip)
442439
end
443440

444-
# and the observed functions
445441
obsstates = [eq.lhs for eq in obseqs_sorted]
446-
obsformulas = [eq.rhs for eq in obseqs_sorted]
447-
_, obsf_ip = build_function(obsformulas, states, inputss..., params, iv; cse=false, expression)
442+
if !isempty(obsstates)
443+
obsformulas = _get_formulas([s ~ s for s in obsstates], obs_subs)
444+
_, obsf_ip = build_function(obsformulas, states, inputss..., params, iv; cse=false, expression)
445+
else
446+
obsf_ip = nothing
447+
end
448448

449449
return (;
450450
f=f_ip, g=g_ip,
@@ -455,25 +455,40 @@ function generate_io_function(_sys, inputss::Tuple, outputss::Tuple;
455455
obsstates,
456456
fftype,
457457
obsf = obsf_ip,
458-
equations=formulas,
459-
outputeqs=Dict(Iterators.flatten(outputss) .=> gformulas),
460-
observed=Dict(getname.(obsstates) .=> obsformulas),
458+
equations=eqs,
459+
outputeqs=outeqs,
460+
observed=obseqs_sorted,
461461
odesystem=sys,
462462
params,
463463
unused_params
464464
)
465465
end
466466

467467
function _get_formulas(eqs, obs_subs)
468-
rhss = [eq.rhs for eq in eqs]
469-
obsdeps = _collect_deps_on_obs(rhss, obs_subs)
470-
if isempty(obsdeps)
471-
return rhss
472-
else
473-
# ensure that the ordering is still correct
474-
obs_assignments = [Assignment(k, v) for (k,v) in obs_subs if k obsdeps]
475-
return [Let(obs_assignments, rhss[1], false), rhss[2:end]...]
476-
end
468+
# Bit hacky, were building a function like this,
469+
# where all (necessary) obs and eqs are contained in the bgin block of the first output
470+
# out[1] = begin
471+
# obs1 = ...
472+
# obs2 = ...
473+
# ...
474+
# state1 = ...
475+
# state2 = ...
476+
# ...
477+
# state1 # ens up in out[1]
478+
# end
479+
# out[2] = state2
480+
# ...
481+
isempty(eqs) && return []
482+
obsdeps = _collect_deps_on_obs([eq.rhs for eq in eqs], obs_subs)
483+
obs_assignments = [Assignment(k, v) for (k,v) in obs_subs if k obsdeps]
484+
485+
# implicit equations are not use via assigments, so we filter for e
486+
eqs_assignments = [Assignment(eq.lhs, eq.rhs) for eq in eqs
487+
if !isequal(eq.lhs, eq.rhs) && !isequal(eq.lhs, 0)]
488+
# since implicit eqs did not end up in assighmets, we use the rhs
489+
out = [isequal(eq.lhs, 0) ? eq.rhs : eq.lhs for eq in eqs]
490+
491+
[Let(vcat(obs_assignments, eqs_assignments), out[1], false), out[2:end]...]
477492
end
478493
function _collect_deps_on_obs(terms, obs_subs)
479494
deps = Set{Symbolic}()

0 commit comments

Comments
 (0)