Skip to content

Commit aec18ba

Browse files
author
Brad Carman
committed
support for observables
1 parent 4b66189 commit aec18ba

File tree

3 files changed

+51
-4
lines changed

3 files changed

+51
-4
lines changed

src/systems/abstractsystem.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -767,25 +767,32 @@ function toexpr(sys::AbstractSystem)
767767
psname = gensym(:ps)
768768
ps = parameters(sys)
769769
push_vars!(stmt, psname, Symbol("@parameters"), ps)
770+
obs = observed(sys)
771+
obsvars = [o.lhs for o in obs]
772+
obsvarsname = gensym(:obs)
773+
push_vars!(stmt, obsvarsname, Symbol("@variables"), obsvars)
770774

771775
var2name = Dict{Any, Symbol}()
772-
for v in Iterators.flatten((sts, ps))
776+
for v in Iterators.flatten((sts, ps, obsvars))
773777
var2name[v] = getname(v)
774778
end
775779

776-
eqs_name = push_eqs!(stmt, equations(sys), var2name)
780+
eqs_name = push_eqs!(stmt, full_equations(sys), var2name)
777781
defs_name = push_defaults!(stmt, defaults(sys), var2name)
782+
obs_name = push_eqs!(stmt, obs, var2name)
778783

779784
if sys isa ODESystem
780785
iv = get_iv(sys)
781786
ivname = gensym(:iv)
782787
push!(stmt, :($ivname = (@variables $(getname(iv)))[1]))
783788
push!(stmt,
784789
:($ODESystem($eqs_name, $ivname, $stsname, $psname; defaults = $defs_name,
790+
observed = $obs_name,
785791
name = $name, checks = false)))
786792
elseif sys isa NonlinearSystem
787793
push!(stmt,
788794
:($NonlinearSystem($eqs_name, $stsname, $psname; defaults = $defs_name,
795+
observed = $obs_name,
789796
name = $name, checks = false)))
790797
end
791798

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,7 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
516516
sparse = false, simplify = false,
517517
steady_state = false,
518518
sparsity = false,
519+
observedfun_exp = nothing,
519520
kwargs...) where {iip}
520521
f_oop, f_iip = generate_function(sys, dvs, ps; expression = Val{true}, kwargs...)
521522

@@ -567,7 +568,8 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
567568
syms = $(Symbol.(states(sys))),
568569
indepsym = $(QuoteNode(Symbol(get_iv(sys)))),
569570
paramsyms = $(Symbol.(parameters(sys))),
570-
sparsity = $(sparsity ? jacobian_sparsity(sys) : nothing))
571+
sparsity = $(sparsity ? jacobian_sparsity(sys) : nothing),
572+
observed = $observedfun_exp)
571573
end
572574
!linenumbers ? striplines(ex) : ex
573575
end

test/serialization.jl

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using ModelingToolkit, SciMLBase, Serialization
1+
using ModelingToolkit, SciMLBase, Serialization, OrdinaryDiffEq
22

33
@parameters t
44
@variables x(t)
@@ -28,3 +28,41 @@ write(io, rc_model)
2828
str = String(take!(io))
2929
sys = include_string(@__MODULE__, str)
3030
@test sys == flatten(rc_model) # this actually kind of works, but the variables would have different identities.
31+
32+
# check answer
33+
ss = structural_simplify(rc_model)
34+
all_obs = [o.lhs for o in observed(ss)]
35+
prob = ODEProblem(ss, [], (0, 0.1))
36+
sol = solve(prob, ImplicitEuler())
37+
38+
## Check ODESystem with Observables ----------
39+
ss_exp = ModelingToolkit.toexpr(ss)
40+
ss_ = eval(ss_exp)
41+
prob_ = ODEProblem(ss_, [], (0, 0.1))
42+
sol_ = solve(prob_, ImplicitEuler())
43+
@test sol[all_obs] == sol_[all_obs]
44+
45+
## Check ODEProblemExpr with Observables -----------
46+
47+
# build the observable function expression
48+
obs_exps = []
49+
for var in all_obs
50+
f = ModelingToolkit.build_explicit_observed_function(ss, var; expression = true)
51+
sym = ModelingToolkit.getname(var) |> string
52+
ex = :(if name == Symbol($sym)
53+
return $f(u0, p, t)
54+
end)
55+
push!(obs_exps, ex)
56+
end
57+
# observedfun expression for ODEFunctionExpr
58+
observedfun_exp = :(function (var, u0, p, t)
59+
name = ModelingToolkit.getname(var)
60+
$(obs_exps...)
61+
end)
62+
63+
# ODEProblemExpr with observedfun_exp included
64+
probexpr = ODEProblemExpr{true}(ss, [], (0, 0.1); observedfun_exp);
65+
prob_obs = eval(probexpr)
66+
sol_obs = solve(prob_obs, ImplicitEuler())
67+
68+
@test sol_obs[all_obs] == sol[all_obs]

0 commit comments

Comments
 (0)