Skip to content

Commit 26ecd5e

Browse files
authored
Merge pull request #1315 from SciML/myb/odaeobs
Add term handling in build_observed_function
2 parents 85a668f + 503798e commit 26ecd5e

File tree

4 files changed

+35
-21
lines changed

4 files changed

+35
-21
lines changed

src/structural_transformation/StructuralTransformations.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ using ModelingToolkit: ODESystem, AbstractSystem,var_from_nested_derivative, Dif
2121
get_structure, defaults, InvalidSystemException,
2222
ExtraEquationsSystemException,
2323
ExtraVariablesSystemException,
24-
get_postprocess_fbody
24+
get_postprocess_fbody, vars!
2525

2626
using ModelingToolkit.BipartiteGraphs
2727
using LightGraphs

src/structural_transformation/codegen.jl

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -257,35 +257,48 @@ function find_solve_sequence(partitions, vars)
257257
end
258258

259259
function build_observed_function(
260-
sys, syms;
260+
sys, ts;
261261
expression=false,
262262
output_type=Array,
263263
checkbounds=true
264264
)
265265

266-
if (isscalar = !(syms isa Vector))
267-
syms = [syms]
266+
if (isscalar = !(ts isa AbstractVector))
267+
ts = [ts]
268268
end
269-
syms = value.(syms)
270-
syms_set = Set(syms)
269+
ts = Symbolics.scalarize.(value.(ts))
270+
271+
vars = Set()
272+
foreach(Base.Fix1(vars!, vars), ts)
273+
ivs = independent_variables(sys)
274+
dep_vars = collect(setdiff(vars, ivs))
275+
271276
s = structure(sys)
272277
@unpack partitions, fullvars, graph = s
273278
diffvars = map(i->fullvars[i], diffvars_range(s))
274279
algvars = map(i->fullvars[i], algvars_range(s))
275280

276-
required_algvars = Set(intersect(algvars, syms_set))
281+
required_algvars = Set(intersect(algvars, vars))
277282
obs = observed(sys)
278283
observed_idx = Dict(map(x->x.lhs, obs) .=> 1:length(obs))
279284
# FIXME: this is a rather rough estimate of dependencies.
280285
maxidx = 0
281-
for (i, s) in enumerate(syms)
286+
sts = Set(states(sys))
287+
for (i, s) in enumerate(dep_vars)
282288
idx = get(observed_idx, s, nothing)
283-
idx === nothing && continue
289+
if idx === nothing
290+
if !(s in sts)
291+
throw(ArgumentError("$s is either an observed nor a state variable."))
292+
end
293+
continue
294+
end
284295
idx > maxidx && (maxidx = idx)
285296
end
297+
vs = Set()
286298
for idx in 1:maxidx
287-
vs = vars(obs[idx].rhs)
299+
vars!(vs, obs[idx].rhs)
288300
union!(required_algvars, intersect(algvars, vs))
301+
empty!(vs)
289302
end
290303

291304
varidxs = findall(x->x in required_algvars, fullvars)
@@ -301,12 +314,11 @@ function build_observed_function(
301314
solves = []
302315
end
303316

304-
output = map(syms) do sym
305-
if sym in required_algvars
306-
sym
307-
else
308-
obs[observed_idx[sym]].rhs
309-
end
317+
subs = []
318+
for sym in vars
319+
eqidx = get(observed_idx, sym, nothing)
320+
eqidx === nothing && continue
321+
push!(subs, sym obs[eqidx].rhs)
310322
end
311323
pre = get_postprocess_fbody(sys)
312324

@@ -321,8 +333,9 @@ function build_observed_function(
321333
[
322334
collect(Iterators.flatten(solves))
323335
map(eq -> eq.lhseq.rhs, obs[1:maxidx])
336+
subs
324337
],
325-
isscalar ? output[1] : MakeArray(output, output_type)
338+
isscalar ? ts[1] : MakeArray(ts, output_type)
326339
))
327340
) |> Code.toexpr
328341

src/systems/diffeqs/odesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ function build_explicit_observed_function(
236236
ts = Symbolics.scalarize.(value.(ts))
237237

238238
vars = Set()
239-
syms = foreach(Base.Fix1(vars!, vars), ts)
239+
foreach(Base.Fix1(vars!, vars), ts)
240240
ivs = independent_variables(sys)
241241
dep_vars = collect(setdiff(vars, ivs))
242242

test/structural_transformation/tearing.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,9 +175,10 @@ sol2 = solve(ODEProblem{false}(
175175
), Tsit5(), tstops=sol1.t, adaptive=false)
176176
@test Array(sol1) Array(sol2) atol=1e-5
177177

178-
obs = build_observed_function(newdaesys, [z, y])
179-
@test map(u -> u[2], obs.(sol1.u, pr, sol1.t)) == first.(sol1.u)
180-
@test map(u -> sin(u[1]), obs.(sol1.u, pr, sol1.t)) + first.(sol1.u) pr[1]*sol1.t atol=1e-5
178+
@test sol1[x] == first.(sol1.u)
179+
@test sol1[y] == first.(sol1.u)
180+
@test sin.(sol1[z]) .+ sol1[y] pr[1] * sol1.t atol=1e-5
181+
@test sol1[sin(z) + y] sin.(sol1[z]) .+ sol1[y] rtol=1e-12
181182

182183
@test sol1[y, :] == sol1[x, :]
183184
@test (@. sin(sol1[z, :]) + sol1[y, :]) pr * sol1.t atol=1e-5

0 commit comments

Comments
 (0)