Skip to content

Commit e0928ee

Browse files
committed
Add term handling in build_observed_function
1 parent 85a668f commit e0928ee

File tree

3 files changed

+31
-18
lines changed

3 files changed

+31
-18
lines changed

src/structural_transformation/StructuralTransformations.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ using ModelingToolkit: ODESystem, AbstractSystem,var_from_nested_derivative, Dif
2121
get_structure, defaults, InvalidSystemException,
2222
ExtraEquationsSystemException,
2323
ExtraVariablesSystemException,
24-
get_postprocess_fbody
24+
get_postprocess_fbody, vars!
2525

2626
using ModelingToolkit.BipartiteGraphs
2727
using LightGraphs

src/structural_transformation/codegen.jl

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -257,35 +257,48 @@ function find_solve_sequence(partitions, vars)
257257
end
258258

259259
function build_observed_function(
260-
sys, syms;
260+
sys, ts;
261261
expression=false,
262262
output_type=Array,
263263
checkbounds=true
264264
)
265265

266-
if (isscalar = !(syms isa Vector))
267-
syms = [syms]
266+
if (isscalar = !(ts isa AbstractVector))
267+
ts = [ts]
268268
end
269-
syms = value.(syms)
270-
syms_set = Set(syms)
269+
ts = Symbolics.scalarize.(value.(ts))
270+
271+
vars = Set()
272+
foreach(Base.Fix1(vars!, vars), ts)
273+
ivs = independent_variables(sys)
274+
dep_vars = collect(setdiff(vars, ivs))
275+
271276
s = structure(sys)
272277
@unpack partitions, fullvars, graph = s
273278
diffvars = map(i->fullvars[i], diffvars_range(s))
274279
algvars = map(i->fullvars[i], algvars_range(s))
275280

276-
required_algvars = Set(intersect(algvars, syms_set))
281+
required_algvars = Set(intersect(algvars, vars))
277282
obs = observed(sys)
278283
observed_idx = Dict(map(x->x.lhs, obs) .=> 1:length(obs))
279284
# FIXME: this is a rather rough estimate of dependencies.
280285
maxidx = 0
281-
for (i, s) in enumerate(syms)
286+
sts = Set(states(sys))
287+
for (i, s) in enumerate(dep_vars)
282288
idx = get(observed_idx, s, nothing)
283-
idx === nothing && continue
289+
if idx === nothing
290+
if !(s in sts)
291+
throw(ArgumentError("$s is either an observed nor a state variable."))
292+
end
293+
continue
294+
end
284295
idx > maxidx && (maxidx = idx)
285296
end
297+
vs = Set()
286298
for idx in 1:maxidx
287-
vs = vars(obs[idx].rhs)
299+
vars!(vs, obs[idx].rhs)
288300
union!(required_algvars, intersect(algvars, vs))
301+
empty!(vs)
289302
end
290303

291304
varidxs = findall(x->x in required_algvars, fullvars)
@@ -301,12 +314,11 @@ function build_observed_function(
301314
solves = []
302315
end
303316

304-
output = map(syms) do sym
305-
if sym in required_algvars
306-
sym
307-
else
308-
obs[observed_idx[sym]].rhs
309-
end
317+
subs = []
318+
for sym in vars
319+
eqidx = get(observed_idx, sym, nothing)
320+
eqidx === nothing && continue
321+
push!(subs, sym obs[eqidx].rhs)
310322
end
311323
pre = get_postprocess_fbody(sys)
312324

@@ -321,8 +333,9 @@ function build_observed_function(
321333
[
322334
collect(Iterators.flatten(solves))
323335
map(eq -> eq.lhseq.rhs, obs[1:maxidx])
336+
subs
324337
],
325-
isscalar ? output[1] : MakeArray(output, output_type)
338+
isscalar ? ts[1] : MakeArray(ts, output_type)
326339
))
327340
) |> Code.toexpr
328341

src/systems/diffeqs/odesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ function build_explicit_observed_function(
236236
ts = Symbolics.scalarize.(value.(ts))
237237

238238
vars = Set()
239-
syms = foreach(Base.Fix1(vars!, vars), ts)
239+
foreach(Base.Fix1(vars!, vars), ts)
240240
ivs = independent_variables(sys)
241241
dep_vars = collect(setdiff(vars, ivs))
242242

0 commit comments

Comments
 (0)