Skip to content

Commit 894c135

Browse files
fix: detect observed variables and dependent parameters dependent on discrete parameters
1 parent ac38df6 commit 894c135

File tree

4 files changed

+77
-34
lines changed

4 files changed

+77
-34
lines changed

src/systems/abstractsystem.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -731,7 +731,7 @@ end
731731
function has_observed_with_lhs(sys, sym)
732732
has_observed(sys) || return false
733733
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
734-
return any(isequal(sym), ic.observed_syms)
734+
return haskey(ic.observed_syms_to_timeseries, sym)
735735
else
736736
return any(isequal(sym), [eq.lhs for eq in observed(sys)])
737737
end
@@ -740,7 +740,7 @@ end
740740
function has_parameter_dependency_with_lhs(sys, sym)
741741
has_parameter_dependencies(sys) || return false
742742
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
743-
return any(isequal(sym), ic.dependent_pars)
743+
return haskey(ic.dependent_pars_to_timeseries, unwrap(sym))
744744
else
745745
return any(isequal(sym), [eq.lhs for eq in parameter_dependencies(sys)])
746746
end
@@ -762,11 +762,16 @@ for traitT in [
762762
allsyms = vars(sym; op = Symbolics.Operator)
763763
for s in allsyms
764764
s = unwrap(s)
765-
if is_variable(sys, s) || is_independent_variable(sys, s) ||
766-
has_observed_with_lhs(sys, s)
765+
if is_variable(sys, s) || is_independent_variable(sys, s)
767766
push!(ts_idxs, ContinuousTimeseries())
768767
elseif is_timeseries_parameter(sys, s)
769768
push!(ts_idxs, timeseries_parameter_index(sys, s).timeseries_idx)
769+
elseif has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
770+
if (ts = get(ic.observed_syms_to_timeseries, s, nothing)) !== nothing
771+
union!(ts_idxs, ts)
772+
elseif (ts = get(ic.dependent_pars_to_timeseries, s, nothing)) !== nothing
773+
union!(ts_idxs, ts)
774+
end
770775
end
771776
end
772777
end

src/systems/diffeqs/odesystem.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,16 @@ function build_explicit_observed_function(sys, ts;
490490
ivs = independent_variables(sys)
491491
dep_vars = scalarize(setdiff(vars, ivs))
492492

493-
obs = param_only ? Equation[] : observed(sys)
493+
obs = observed(sys)
494+
if param_only
495+
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
496+
obs = filter(obs) do eq
497+
!(ContinuousTimeseries() in ic.observed_syms_to_timeseries[eq.lhs])
498+
end
499+
else
500+
obs = Equation[]
501+
end
502+
end
494503

495504
cs = collect_constants(obs)
496505
if !isempty(cs) > 0

src/systems/index_cache.jl

Lines changed: 47 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ const UnknownIndexMap = Dict{
3838
BasicSymbolic, Union{Int, UnitRange{Int}, AbstractArray{Int}}}
3939
const TunableIndexMap = Dict{BasicSymbolic,
4040
Union{Int, UnitRange{Int}, Base.ReshapedArray{Int, N, UnitRange{Int}} where {N}}}
41+
const TimeseriesSetType = Set{Union{ContinuousTimeseries, Int}}
4142

4243
struct IndexCache
4344
unknown_idx::UnknownIndexMap
@@ -48,8 +49,9 @@ struct IndexCache
4849
tunable_idx::TunableIndexMap
4950
constant_idx::ParamIndexMap
5051
nonnumeric_idx::NonnumericMap
51-
observed_syms::Set{BasicSymbolic}
52-
dependent_pars::Set{Union{BasicSymbolic, CallWithMetadata}}
52+
observed_syms_to_timeseries::Dict{BasicSymbolic, TimeseriesSetType}
53+
dependent_pars_to_timeseries::Dict{
54+
Union{BasicSymbolic, CallWithMetadata}, TimeseriesSetType}
5355
discrete_buffer_sizes::Vector{Vector{BufferTemplate}}
5456
tunable_buffer_size::BufferTemplate
5557
constant_buffer_sizes::Vector{BufferTemplate}
@@ -91,20 +93,6 @@ function IndexCache(sys::AbstractSystem)
9193
end
9294
end
9395

94-
observed_syms = Set{BasicSymbolic}()
95-
for eq in observed(sys)
96-
if symbolic_type(eq.lhs) != NotSymbolic()
97-
sym = eq.lhs
98-
ttsym = default_toterm(sym)
99-
rsym = renamespace(sys, sym)
100-
rttsym = renamespace(sys, ttsym)
101-
push!(observed_syms, sym)
102-
push!(observed_syms, ttsym)
103-
push!(observed_syms, rsym)
104-
push!(observed_syms, rttsym)
105-
end
106-
end
107-
10896
tunable_buffers = Dict{Any, Set{BasicSymbolic}}()
10997
constant_buffers = Dict{Any, Set{BasicSymbolic}}()
11098
nonnumeric_buffers = Dict{Any, Set{Union{BasicSymbolic, CallWithMetadata}}}()
@@ -267,38 +255,68 @@ function IndexCache(sys::AbstractSystem)
267255
end
268256
end
269257

270-
for sym in Iterators.flatten((keys(unk_idxs), keys(disc_idxs), keys(tunable_idxs),
271-
keys(const_idxs), keys(nonnumeric_idxs),
272-
observed_syms, independent_variable_symbols(sys)))
273-
if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex)
274-
symbol_to_variable[getname(sym)] = sym
275-
end
276-
end
277-
278-
dependent_pars = Set{Union{BasicSymbolic, CallWithMetadata}}()
258+
dependent_pars_to_timeseries = Dict{
259+
Union{BasicSymbolic, CallWithMetadata}, TimeseriesSetType}()
279260

280261
for eq in parameter_dependencies(sys)
281262
sym = eq.lhs
263+
vs = vars(eq.rhs)
264+
timeseries = TimeseriesSetType()
265+
for v in vs
266+
if (idx = get(disc_idxs, v, nothing)) !== nothing
267+
push!(timeseries, idx.clock_idx)
268+
end
269+
end
282270
ttsym = default_toterm(sym)
283271
rsym = renamespace(sys, sym)
284272
rttsym = renamespace(sys, ttsym)
285-
for s in [sym, ttsym, rsym, rttsym]
286-
push!(dependent_pars, s)
273+
for s in (sym, ttsym, rsym, rttsym)
274+
dependent_pars_to_timeseries[s] = timeseries
287275
if hasname(s) && (!iscall(s) || operation(s) != getindex)
288276
symbol_to_variable[getname(s)] = sym
289277
end
290278
end
291279
end
292280

281+
observed_syms_to_timeseries = Dict{BasicSymbolic, TimeseriesSetType}()
282+
for eq in observed(sys)
283+
if symbolic_type(eq.lhs) != NotSymbolic()
284+
sym = eq.lhs
285+
vs = vars(eq.rhs)
286+
timeseries = TimeseriesSetType()
287+
for v in vs
288+
if (idx = get(disc_idxs, v, nothing)) !== nothing
289+
push!(timeseries, idx.clock_idx)
290+
elseif haskey(unk_idxs, v) || haskey(observed_syms_to_timeseries, v)
291+
push!(timeseries, ContinuousTimeseries())
292+
end
293+
end
294+
ttsym = default_toterm(sym)
295+
rsym = renamespace(sys, sym)
296+
rttsym = renamespace(sys, ttsym)
297+
for s in (sym, ttsym, rsym, rttsym)
298+
observed_syms_to_timeseries[s] = timeseries
299+
end
300+
end
301+
end
302+
303+
for sym in Iterators.flatten((keys(unk_idxs), keys(disc_idxs), keys(tunable_idxs),
304+
keys(const_idxs), keys(nonnumeric_idxs),
305+
keys(observed_syms_to_timeseries), independent_variable_symbols(sys)))
306+
if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex)
307+
symbol_to_variable[getname(sym)] = sym
308+
end
309+
end
310+
293311
return IndexCache(
294312
unk_idxs,
295313
disc_idxs,
296314
callback_to_clocks,
297315
tunable_idxs,
298316
const_idxs,
299317
nonnumeric_idxs,
300-
observed_syms,
301-
dependent_pars,
318+
observed_syms_to_timeseries,
319+
dependent_pars_to_timeseries,
302320
disc_buffer_templates,
303321
BufferTemplate(Real, tunable_buffer_size),
304322
const_buffer_sizes,

test/odesystem.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1504,3 +1504,14 @@ end
15041504
sys2 = complete(sys; split = false)
15051505
@test ModelingToolkit.get_index_cache(sys2) === nothing
15061506
end
1507+
1508+
# https://github.com/SciML/SciMLBase.jl/issues/786
1509+
@testset "Observed variables dependent on discrete parameters" begin
1510+
@variables x(t) obs(t)
1511+
@parameters c(t)
1512+
@mtkbuild sys = ODESystem(
1513+
[D(x) ~ c * cos(x), obs ~ c], t, [x], [c]; discrete_events = [1.0 => [c ~ c + 1]])
1514+
prob = ODEProblem(sys, [x => 0.0], (0.0, 2pi), [c => 1.0])
1515+
sol = solve(prob, Tsit5())
1516+
@test sol[obs] 1:7
1517+
end

0 commit comments

Comments
 (0)