Skip to content

Commit 7205307

Browse files
committed
Remove old comments and allow solution indexing let sol[2x[1]]
1 parent f2abbbf commit 7205307

File tree

2 files changed

+25
-41
lines changed

2 files changed

+25
-41
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 2 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
261261

262262
obs = observed(sys)
263263
observedfun = if steady_state
264-
isempty(obs) ? SciMLBase.DEFAULT_OBSERVED_NO_TIME : let sys = sys, dict = Dict()
264+
let sys = sys, dict = Dict()
265265
function generated_observed(obsvar, u, p, t=Inf)
266266
obs = get!(dict, value(obsvar)) do
267267
build_explicit_observed_function(sys, obsvar)
@@ -270,7 +270,7 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
270270
end
271271
end
272272
else
273-
isempty(obs) ? SciMLBase.DEFAULT_OBSERVED : let sys = sys, dict = Dict()
273+
let sys = sys, dict = Dict()
274274
function generated_observed(obsvar, u, p, t)
275275
obs = get!(dict, value(obsvar)) do
276276
build_explicit_observed_function(sys, obsvar; checkbounds=checkbounds)
@@ -338,16 +338,7 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
338338
# TODO: Jacobian sparsity / sparse Jacobian / dense Jacobian
339339

340340
#=
341-
observedfun = let sys = sys, dict = Dict()
342341
# TODO: We don't have enought information to reconstruct arbitrary state
343-
# in general from `(u, p, t)`, e.g. `a ~ D(x)`.
344-
function generated_observed(obsvar, u, p, t)
345-
obs = get!(dict, value(obsvar)) do
346-
build_explicit_observed_function(sys, obsvar)
347-
end
348-
obs(u, p, t)
349-
end
350-
end
351342
=#
352343

353344
DAEFunction{iip}(
@@ -394,23 +385,6 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
394385
f_oop, f_iip = generate_function(sys, dvs, ps; expression=Val{true}, kwargs...)
395386

396387
dict = Dict()
397-
#=
398-
observedfun = if steady_state
399-
:(function generated_observed(obsvar, u, p, t=Inf)
400-
obs = get!($dict, value(obsvar)) do
401-
build_explicit_observed_function($sys, obsvar)
402-
end
403-
obs(u, p, t)
404-
end)
405-
else
406-
:(function generated_observed(obsvar, u, p, t)
407-
obs = get!($dict, value(obsvar)) do
408-
build_explicit_observed_function($sys, obsvar)
409-
end
410-
obs(u, p, t)
411-
end)
412-
end
413-
=#
414388

415389
fsym = gensym(:f)
416390
_f = :($fsym = $ODEFunctionClosure($f_oop, $f_iip))

src/systems/diffeqs/odesystem.jl

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -225,42 +225,52 @@ Build the observed function assuming the observed equations are all explicit,
225225
i.e. there are no cycles.
226226
"""
227227
function build_explicit_observed_function(
228-
sys, syms;
228+
sys, ts;
229229
expression=false,
230230
output_type=Array,
231231
checkbounds=true)
232232

233-
if (isscalar = !(syms isa Vector))
234-
syms = [syms]
233+
if (isscalar = !(ts isa AbstractVector))
234+
ts = [ts]
235235
end
236-
syms = value.(syms)
236+
ts = Symbolics.scalarize.(value.(ts))
237+
238+
vars = Set()
239+
syms = foreach(Base.Fix1(vars!, vars), ts)
240+
ivs = independent_variables(sys)
241+
dep_vars = collect(setdiff(vars, ivs))
237242

238243
obs = observed(sys)
244+
sts = Set(states(sys))
239245
observed_idx = Dict(map(x->x.lhs, obs) .=> 1:length(obs))
240-
output = similar(syms, Any)
241-
# FIXME: this is a rather rough estimate of dependencies.
246+
247+
# FIXME: This is a rather rough estimate of dependencies. We assume
248+
# the expression depends on everything before the `maxidx`.
242249
maxidx = 0
243-
for (i, s) in enumerate(syms)
250+
for (i, s) in enumerate(dep_vars)
244251
idx = get(observed_idx, s, nothing)
245-
idx === nothing && throw(ArgumentError("$s is not an observed variable."))
252+
if idx === nothing
253+
if !(s in sts)
254+
throw(ArgumentError("$s is either an observed nor a state variable."))
255+
end
256+
continue
257+
end
246258
idx > maxidx && (maxidx = idx)
247-
output[i] = obs[idx].rhs
248259
end
260+
obsexprs = map(eq -> eq.lhseq.rhs, obs[1:maxidx])
249261

250262
dvs = DestructuredArgs(states(sys), inbounds=!checkbounds)
251263
ps = DestructuredArgs(parameters(sys), inbounds=!checkbounds)
252-
ivs = independent_variables(sys)
253264
args = [dvs, ps, ivs...]
254265
pre = get_postprocess_fbody(sys)
255266

256267
ex = Func(
257268
args, [],
258269
pre(Let(
259-
map(eq -> eq.lhseq.rhs, obs[1:maxidx]),
260-
isscalar ? output[1] : MakeArray(output, output_type)
270+
obsexprs,
271+
isscalar ? ts[1] : MakeArray(ts, output_type)
261272
))
262273
) |> toexpr
263-
264274
expression ? ex : @RuntimeGeneratedFunction(ex)
265275
end
266276

0 commit comments

Comments
 (0)