Skip to content

Commit 2c315ee

Browse files
committed
Add observedfun
1 parent defaeb7 commit 2c315ee

File tree

2 files changed

+43
-0
lines changed

2 files changed

+43
-0
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,15 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
162162

163163
_M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0',M)
164164

165+
observedfun = let sys = sys, dict = Dict()
166+
function generated_observed(obsvar, u, p, t)
167+
obs = get!(dict, value(obsvar)) do
168+
build_explicit_observed_function(sys, obsvar)
169+
end
170+
obs(u, p, t)
171+
end
172+
end
173+
165174
ODEFunction{iip}(
166175
f,
167176
jac = _jac === nothing ? nothing : _jac,
@@ -170,6 +179,7 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
170179
jac_prototype = sparse ? similar(get_jac(sys)[],Float64) : nothing,
171180
syms = Symbol.(states(sys)),
172181
indepsym = Symbol(independent_variable(sys)),
182+
observed = observedfun,
173183
)
174184
end
175185

src/systems/diffeqs/odesystem.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,3 +230,36 @@ function flatten(sys::ODESystem)
230230
end
231231

232232
ODESystem(eq::Equation, args...; kwargs...) = ODESystem([eq], args...; kwargs...)
233+
234+
"""
235+
$(SIGNATURES)
236+
237+
Build the observed function assuming the observed equations are all explicit,
238+
i.e. there are no cycles or dependencies.
239+
"""
240+
function build_explicit_observed_function(
241+
sys, syms;
242+
expression=false,
243+
output_type=Array)
244+
245+
if (isscalar = !(syms isa Vector))
246+
syms = [syms]
247+
end
248+
syms = value.(syms)
249+
250+
obs = observed(sys)
251+
observed_idx = Dict(map(x->x.lhs, obs) .=> 1:length(obs))
252+
output = map(sym->obs[observed_idx[sym]].rhs, syms)
253+
254+
ex = Func(
255+
[
256+
DestructuredArgs(states(sys))
257+
DestructuredArgs(parameters(sys))
258+
independent_variable(sys)
259+
],
260+
[],
261+
isscalar ? output[1] : MakeArray(output, output_type)
262+
) |> toexpr
263+
264+
expression ? ex : @RuntimeGeneratedFunction(ex)
265+
end

0 commit comments

Comments
 (0)