Skip to content

Commit 81e1d28

Browse files
Merge pull request #3106 from AayushSabharwal/as/observed-discrete
fix: detect observed variables and dependent parameters dependent on discrete parameters
2 parents ff72509 + 55c5217 commit 81e1d28

File tree

5 files changed

+91
-35
lines changed

5 files changed

+91
-35
lines changed

.github/workflows/Downstream.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,4 +72,4 @@ jobs:
7272
with:
7373
file: lcov.info
7474
token: ${{ secrets.CODECOV_TOKEN }}
75-
fail_ci_if_error: true
75+
fail_ci_if_error: false

src/systems/abstractsystem.jl

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -731,7 +731,7 @@ end
731731
function has_observed_with_lhs(sys, sym)
732732
has_observed(sys) || return false
733733
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
734-
return any(isequal(sym), ic.observed_syms)
734+
return haskey(ic.observed_syms_to_timeseries, sym)
735735
else
736736
return any(isequal(sym), [eq.lhs for eq in observed(sys)])
737737
end
@@ -740,7 +740,7 @@ end
740740
function has_parameter_dependency_with_lhs(sys, sym)
741741
has_parameter_dependencies(sys) || return false
742742
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
743-
return any(isequal(sym), ic.dependent_pars)
743+
return haskey(ic.dependent_pars_to_timeseries, unwrap(sym))
744744
else
745745
return any(isequal(sym), [eq.lhs for eq in parameter_dependencies(sys)])
746746
end
@@ -762,11 +762,20 @@ for traitT in [
762762
allsyms = vars(sym; op = Symbolics.Operator)
763763
for s in allsyms
764764
s = unwrap(s)
765-
if is_variable(sys, s) || is_independent_variable(sys, s) ||
766-
has_observed_with_lhs(sys, s)
765+
if is_variable(sys, s) || is_independent_variable(sys, s)
767766
push!(ts_idxs, ContinuousTimeseries())
768767
elseif is_timeseries_parameter(sys, s)
769768
push!(ts_idxs, timeseries_parameter_index(sys, s).timeseries_idx)
769+
elseif is_time_dependent(sys) && iscall(s) && issym(operation(s)) &&
770+
is_variable(sys, operation(s)(get_iv(sys)))
771+
# DDEs case, to detect x(t - k)
772+
push!(ts_idxs, ContinuousTimeseries())
773+
elseif has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
774+
if (ts = get(ic.observed_syms_to_timeseries, s, nothing)) !== nothing
775+
union!(ts_idxs, ts)
776+
elseif (ts = get(ic.dependent_pars_to_timeseries, s, nothing)) !== nothing
777+
union!(ts_idxs, ts)
778+
end
770779
end
771780
end
772781
end

src/systems/diffeqs/odesystem.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,16 @@ function build_explicit_observed_function(sys, ts;
490490
ivs = independent_variables(sys)
491491
dep_vars = scalarize(setdiff(vars, ivs))
492492

493-
obs = param_only ? Equation[] : observed(sys)
493+
obs = observed(sys)
494+
if param_only
495+
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
496+
obs = filter(obs) do eq
497+
!(ContinuousTimeseries() in ic.observed_syms_to_timeseries[eq.lhs])
498+
end
499+
else
500+
obs = Equation[]
501+
end
502+
end
494503

495504
cs = collect_constants(obs)
496505
if !isempty(cs) > 0

src/systems/index_cache.jl

Lines changed: 56 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ const UnknownIndexMap = Dict{
3838
BasicSymbolic, Union{Int, UnitRange{Int}, AbstractArray{Int}}}
3939
const TunableIndexMap = Dict{BasicSymbolic,
4040
Union{Int, UnitRange{Int}, Base.ReshapedArray{Int, N, UnitRange{Int}} where {N}}}
41+
const TimeseriesSetType = Set{Union{ContinuousTimeseries, Int}}
4142

4243
struct IndexCache
4344
unknown_idx::UnknownIndexMap
@@ -48,8 +49,9 @@ struct IndexCache
4849
tunable_idx::TunableIndexMap
4950
constant_idx::ParamIndexMap
5051
nonnumeric_idx::NonnumericMap
51-
observed_syms::Set{BasicSymbolic}
52-
dependent_pars::Set{Union{BasicSymbolic, CallWithMetadata}}
52+
observed_syms_to_timeseries::Dict{BasicSymbolic, TimeseriesSetType}
53+
dependent_pars_to_timeseries::Dict{
54+
Union{BasicSymbolic, CallWithMetadata}, TimeseriesSetType}
5355
discrete_buffer_sizes::Vector{Vector{BufferTemplate}}
5456
tunable_buffer_size::BufferTemplate
5557
constant_buffer_sizes::Vector{BufferTemplate}
@@ -91,20 +93,6 @@ function IndexCache(sys::AbstractSystem)
9193
end
9294
end
9395

94-
observed_syms = Set{BasicSymbolic}()
95-
for eq in observed(sys)
96-
if symbolic_type(eq.lhs) != NotSymbolic()
97-
sym = eq.lhs
98-
ttsym = default_toterm(sym)
99-
rsym = renamespace(sys, sym)
100-
rttsym = renamespace(sys, ttsym)
101-
push!(observed_syms, sym)
102-
push!(observed_syms, ttsym)
103-
push!(observed_syms, rsym)
104-
push!(observed_syms, rttsym)
105-
end
106-
end
107-
10896
tunable_buffers = Dict{Any, Set{BasicSymbolic}}()
10997
constant_buffers = Dict{Any, Set{BasicSymbolic}}()
11098
nonnumeric_buffers = Dict{Any, Set{Union{BasicSymbolic, CallWithMetadata}}}()
@@ -267,38 +255,77 @@ function IndexCache(sys::AbstractSystem)
267255
end
268256
end
269257

270-
for sym in Iterators.flatten((keys(unk_idxs), keys(disc_idxs), keys(tunable_idxs),
271-
keys(const_idxs), keys(nonnumeric_idxs),
272-
observed_syms, independent_variable_symbols(sys)))
273-
if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex)
274-
symbol_to_variable[getname(sym)] = sym
275-
end
276-
end
277-
278-
dependent_pars = Set{Union{BasicSymbolic, CallWithMetadata}}()
258+
dependent_pars_to_timeseries = Dict{
259+
Union{BasicSymbolic, CallWithMetadata}, TimeseriesSetType}()
279260

280261
for eq in parameter_dependencies(sys)
281262
sym = eq.lhs
263+
vs = vars(eq.rhs)
264+
timeseries = TimeseriesSetType()
265+
if is_time_dependent(sys)
266+
for v in vs
267+
if (idx = get(disc_idxs, v, nothing)) !== nothing
268+
push!(timeseries, idx.clock_idx)
269+
end
270+
end
271+
end
282272
ttsym = default_toterm(sym)
283273
rsym = renamespace(sys, sym)
284274
rttsym = renamespace(sys, ttsym)
285-
for s in [sym, ttsym, rsym, rttsym]
286-
push!(dependent_pars, s)
275+
for s in (sym, ttsym, rsym, rttsym)
276+
dependent_pars_to_timeseries[s] = timeseries
287277
if hasname(s) && (!iscall(s) || operation(s) != getindex)
288278
symbol_to_variable[getname(s)] = sym
289279
end
290280
end
291281
end
292282

283+
observed_syms_to_timeseries = Dict{BasicSymbolic, TimeseriesSetType}()
284+
for eq in observed(sys)
285+
if symbolic_type(eq.lhs) != NotSymbolic()
286+
sym = eq.lhs
287+
vs = vars(eq.rhs; op = Nothing)
288+
timeseries = TimeseriesSetType()
289+
if is_time_dependent(sys)
290+
for v in vs
291+
if (idx = get(disc_idxs, v, nothing)) !== nothing
292+
push!(timeseries, idx.clock_idx)
293+
elseif haskey(observed_syms_to_timeseries, v)
294+
union!(timeseries, observed_syms_to_timeseries[v])
295+
elseif haskey(dependent_pars_to_timeseries, v)
296+
union!(timeseries, dependent_pars_to_timeseries[v])
297+
end
298+
end
299+
if isempty(timeseries)
300+
push!(timeseries, ContinuousTimeseries())
301+
end
302+
end
303+
ttsym = default_toterm(sym)
304+
rsym = renamespace(sys, sym)
305+
rttsym = renamespace(sys, ttsym)
306+
for s in (sym, ttsym, rsym, rttsym)
307+
observed_syms_to_timeseries[s] = timeseries
308+
end
309+
end
310+
end
311+
312+
for sym in Iterators.flatten((keys(unk_idxs), keys(disc_idxs), keys(tunable_idxs),
313+
keys(const_idxs), keys(nonnumeric_idxs),
314+
keys(observed_syms_to_timeseries), independent_variable_symbols(sys)))
315+
if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex)
316+
symbol_to_variable[getname(sym)] = sym
317+
end
318+
end
319+
293320
return IndexCache(
294321
unk_idxs,
295322
disc_idxs,
296323
callback_to_clocks,
297324
tunable_idxs,
298325
const_idxs,
299326
nonnumeric_idxs,
300-
observed_syms,
301-
dependent_pars,
327+
observed_syms_to_timeseries,
328+
dependent_pars_to_timeseries,
302329
disc_buffer_templates,
303330
BufferTemplate(Real, tunable_buffer_size),
304331
const_buffer_sizes,

test/odesystem.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1504,3 +1504,14 @@ end
15041504
sys2 = complete(sys; split = false)
15051505
@test ModelingToolkit.get_index_cache(sys2) === nothing
15061506
end
1507+
1508+
# https://github.com/SciML/SciMLBase.jl/issues/786
1509+
@testset "Observed variables dependent on discrete parameters" begin
1510+
@variables x(t) obs(t)
1511+
@parameters c(t)
1512+
@mtkbuild sys = ODESystem(
1513+
[D(x) ~ c * cos(x), obs ~ c], t, [x], [c]; discrete_events = [1.0 => [c ~ c + 1]])
1514+
prob = ODEProblem(sys, [x => 0.0], (0.0, 2pi), [c => 1.0])
1515+
sol = solve(prob, Tsit5())
1516+
@test sol[obs] 1:7
1517+
end

0 commit comments

Comments
 (0)