Skip to content

Commit 4a67046

Browse files
fix: detect observed variables and dependent parameters dependent on discrete parameters
1 parent 295edb6 commit 4a67046

File tree

4 files changed

+76
-33
lines changed

4 files changed

+76
-33
lines changed

src/systems/abstractsystem.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -730,7 +730,7 @@ end
730730
function has_observed_with_lhs(sys, sym)
731731
has_observed(sys) || return false
732732
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
733-
return any(isequal(sym), ic.observed_syms)
733+
return haskey(ic.observed_syms_to_timeseries, sym)
734734
else
735735
return any(isequal(sym), [eq.lhs for eq in observed(sys)])
736736
end
@@ -752,11 +752,16 @@ for traitT in [
752752
allsyms = vars(sym; op = Symbolics.Operator)
753753
for s in allsyms
754754
s = unwrap(s)
755-
if is_variable(sys, s) || is_independent_variable(sys, s) ||
756-
has_observed_with_lhs(sys, s)
755+
if is_variable(sys, s) || is_independent_variable(sys, s)
757756
push!(ts_idxs, ContinuousTimeseries())
758757
elseif is_timeseries_parameter(sys, s)
759758
push!(ts_idxs, timeseries_parameter_index(sys, s).timeseries_idx)
759+
elseif has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
760+
if (ts = get(ic.observed_syms_to_timeseries, s, nothing)) !== nothing
761+
union!(ts_idxs, ts)
762+
elseif (ts = get(ic.dependent_pars_to_timeseries, s, nothing)) !== nothing
763+
union!(ts_idxs, ts)
764+
end
760765
end
761766
end
762767
end

src/systems/diffeqs/odesystem.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,16 @@ function build_explicit_observed_function(sys, ts;
429429
ivs = independent_variables(sys)
430430
dep_vars = scalarize(setdiff(vars, ivs))
431431

432-
obs = param_only ? Equation[] : observed(sys)
432+
obs = observed(sys)
433+
if param_only
434+
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
435+
obs = filter(obs) do eq
436+
!(ContinuousTimeseries() in ic.observed_syms_to_timeseries[eq.lhs])
437+
end
438+
else
439+
obs = Equation[]
440+
end
441+
end
433442

434443
cs = collect_constants(obs)
435444
if !isempty(cs) > 0

src/systems/index_cache.jl

Lines changed: 47 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ const UnknownIndexMap = Dict{
3838
BasicSymbolic, Union{Int, UnitRange{Int}, AbstractArray{Int}}}
3939
const TunableIndexMap = Dict{BasicSymbolic,
4040
Union{Int, UnitRange{Int}, Base.ReshapedArray{Int, N, UnitRange{Int}} where {N}}}
41+
const TimeseriesSetType = Set{Union{ContinuousTimeseries, Int}}
4142

4243
struct IndexCache
4344
unknown_idx::UnknownIndexMap
@@ -48,8 +49,9 @@ struct IndexCache
4849
tunable_idx::TunableIndexMap
4950
constant_idx::ParamIndexMap
5051
nonnumeric_idx::NonnumericMap
51-
observed_syms::Set{BasicSymbolic}
52-
dependent_pars::Set{Union{BasicSymbolic, CallWithMetadata}}
52+
observed_syms_to_timeseries::Dict{BasicSymbolic, TimeseriesSetType}
53+
dependent_pars_to_timeseries::Dict{
54+
Union{BasicSymbolic, CallWithMetadata}, TimeseriesSetType}
5355
discrete_buffer_sizes::Vector{Vector{BufferTemplate}}
5456
tunable_buffer_size::BufferTemplate
5557
constant_buffer_sizes::Vector{BufferTemplate}
@@ -91,20 +93,6 @@ function IndexCache(sys::AbstractSystem)
9193
end
9294
end
9395

94-
observed_syms = Set{BasicSymbolic}()
95-
for eq in observed(sys)
96-
if symbolic_type(eq.lhs) != NotSymbolic()
97-
sym = eq.lhs
98-
ttsym = default_toterm(sym)
99-
rsym = renamespace(sys, sym)
100-
rttsym = renamespace(sys, ttsym)
101-
push!(observed_syms, sym)
102-
push!(observed_syms, ttsym)
103-
push!(observed_syms, rsym)
104-
push!(observed_syms, rttsym)
105-
end
106-
end
107-
10896
tunable_buffers = Dict{Any, Set{BasicSymbolic}}()
10997
constant_buffers = Dict{Any, Set{BasicSymbolic}}()
11098
nonnumeric_buffers = Dict{Any, Set{Union{BasicSymbolic, CallWithMetadata}}}()
@@ -267,38 +255,68 @@ function IndexCache(sys::AbstractSystem)
267255
end
268256
end
269257

270-
for sym in Iterators.flatten((keys(unk_idxs), keys(disc_idxs), keys(tunable_idxs),
271-
keys(const_idxs), keys(nonnumeric_idxs),
272-
observed_syms, independent_variable_symbols(sys)))
273-
if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex)
274-
symbol_to_variable[getname(sym)] = sym
275-
end
276-
end
277-
278-
dependent_pars = Set{Union{BasicSymbolic, CallWithMetadata}}()
258+
dependent_pars_to_timeseries = Dict{
259+
Union{BasicSymbolic, CallWithMetadata}, TimeseriesSetType}()
279260

280261
for eq in parameter_dependencies(sys)
281262
sym = eq.lhs
263+
vs = vars(eq.rhs)
264+
timeseries = TimeseriesSetType()
265+
for v in vs
266+
if (idx = get(disc_idxs, v, nothing)) !== nothing
267+
push!(timeseries, idx.clock_idx)
268+
end
269+
end
282270
ttsym = default_toterm(sym)
283271
rsym = renamespace(sys, sym)
284272
rttsym = renamespace(sys, ttsym)
285-
for s in [sym, ttsym, rsym, rttsym]
286-
push!(dependent_pars, s)
273+
for s in (sym, ttsym, rsym, rttsym)
274+
dependent_pars_to_timeseries[s] = timeseries
287275
if hasname(s) && (!iscall(s) || operation(s) != getindex)
288276
symbol_to_variable[getname(s)] = sym
289277
end
290278
end
291279
end
292280

281+
observed_syms_to_timeseries = Dict{BasicSymbolic, TimeseriesSetType}()
282+
for eq in observed(sys)
283+
if symbolic_type(eq.lhs) != NotSymbolic()
284+
sym = eq.lhs
285+
vs = vars(eq.rhs)
286+
timeseries = TimeseriesSetType()
287+
for v in vs
288+
if (idx = get(disc_idxs, v, nothing)) !== nothing
289+
push!(timeseries, idx.clock_idx)
290+
elseif haskey(unk_idxs, v) || haskey(observed_syms_to_timeseries, v)
291+
push!(timeseries, ContinuousTimeseries())
292+
end
293+
end
294+
ttsym = default_toterm(sym)
295+
rsym = renamespace(sys, sym)
296+
rttsym = renamespace(sys, ttsym)
297+
for s in (sym, ttsym, rsym, rttsym)
298+
observed_syms_to_timeseries[s] = timeseries
299+
end
300+
end
301+
end
302+
303+
for sym in Iterators.flatten((keys(unk_idxs), keys(disc_idxs), keys(tunable_idxs),
304+
keys(const_idxs), keys(nonnumeric_idxs),
305+
keys(observed_syms_to_timeseries), independent_variable_symbols(sys)))
306+
if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex)
307+
symbol_to_variable[getname(sym)] = sym
308+
end
309+
end
310+
293311
return IndexCache(
294312
unk_idxs,
295313
disc_idxs,
296314
callback_to_clocks,
297315
tunable_idxs,
298316
const_idxs,
299317
nonnumeric_idxs,
300-
observed_syms,
301-
dependent_pars,
318+
observed_syms_to_timeseries,
319+
dependent_pars_to_timeseries,
302320
disc_buffer_templates,
303321
BufferTemplate(Real, tunable_buffer_size),
304322
const_buffer_sizes,

test/odesystem.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1442,3 +1442,14 @@ end
14421442
end
14431443
end
14441444
end
1445+
1446+
# https://github.com/SciML/SciMLBase.jl/issues/786
1447+
@testset "Observed variables dependent on discrete parameters" begin
1448+
@variables x(t) obs(t)
1449+
@parameters c(t)
1450+
@mtkbuild sys = ODESystem(
1451+
[D(x) ~ c * cos(x), obs ~ c], t, [x], [c]; discrete_events = [1.0 => [c ~ c + 1]])
1452+
prob = ODEProblem(sys, [x => 0.0], (0.0, 2pi), [c => 1.0])
1453+
sol = solve(prob, Tsit5())
1454+
@test sol[obs] 1:7
1455+
end

0 commit comments

Comments
 (0)