Skip to content

Commit 6ee68bb

Browse files
Merge pull request #3237 from AayushSabharwal/as/callable-discrete
fix: support callable parameters provided to discretes list of callback
2 parents a62fe60 + 472921d commit 6ee68bb

File tree

2 files changed

+45
-10
lines changed

2 files changed

+45
-10
lines changed

src/systems/index_cache.jl

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,12 @@ const TunableIndexMap = Dict{BasicSymbolic,
4040
Union{Int, UnitRange{Int}, Base.ReshapedArray{Int, N, UnitRange{Int}} where {N}}}
4141
const TimeseriesSetType = Set{Union{ContinuousTimeseries, Int}}
4242

43+
const SymbolicParam = Union{BasicSymbolic, CallWithMetadata}
44+
4345
struct IndexCache
4446
unknown_idx::UnknownIndexMap
4547
# sym => (bufferidx, idx_in_buffer)
46-
discrete_idx::Dict{BasicSymbolic, DiscreteIndex}
48+
discrete_idx::Dict{SymbolicParam, DiscreteIndex}
4749
# sym => (clockidx, idx_in_clockbuffer)
4850
callback_to_clocks::Dict{Any, Vector{Int}}
4951
tunable_idx::TunableIndexMap
@@ -56,13 +58,13 @@ struct IndexCache
5658
tunable_buffer_size::BufferTemplate
5759
constant_buffer_sizes::Vector{BufferTemplate}
5860
nonnumeric_buffer_sizes::Vector{BufferTemplate}
59-
symbol_to_variable::Dict{Symbol, Union{BasicSymbolic, CallWithMetadata}}
61+
symbol_to_variable::Dict{Symbol, SymbolicParam}
6062
end
6163

6264
function IndexCache(sys::AbstractSystem)
6365
unks = solved_unknowns(sys)
6466
unk_idxs = UnknownIndexMap()
65-
symbol_to_variable = Dict{Symbol, Union{BasicSymbolic, CallWithMetadata}}()
67+
symbol_to_variable = Dict{Symbol, SymbolicParam}()
6668

6769
let idx = 1
6870
for sym in unks
@@ -95,18 +97,18 @@ function IndexCache(sys::AbstractSystem)
9597

9698
tunable_buffers = Dict{Any, Set{BasicSymbolic}}()
9799
constant_buffers = Dict{Any, Set{BasicSymbolic}}()
98-
nonnumeric_buffers = Dict{Any, Set{Union{BasicSymbolic, CallWithMetadata}}}()
100+
nonnumeric_buffers = Dict{Any, Set{SymbolicParam}}()
99101

100102
function insert_by_type!(buffers::Dict{Any, S}, sym, ctype) where {S}
101103
sym = unwrap(sym)
102104
buf = get!(buffers, ctype, S())
103105
push!(buf, sym)
104106
end
105107

106-
disc_param_callbacks = Dict{BasicSymbolic, Set{Int}}()
108+
disc_param_callbacks = Dict{SymbolicParam, Set{Int}}()
107109
events = vcat(continuous_events(sys), discrete_events(sys))
108110
for (i, event) in enumerate(events)
109-
discs = Set{BasicSymbolic}()
111+
discs = Set{SymbolicParam}()
110112
affs = affects(event)
111113
if !(affs isa AbstractArray)
112114
affs = [affs]
@@ -130,26 +132,32 @@ function IndexCache(sys::AbstractSystem)
130132
isequal(only(arguments(sym)), get_iv(sys))
131133
clocks = get!(() -> Set{Int}(), disc_param_callbacks, sym)
132134
push!(clocks, i)
133-
else
135+
elseif is_variable_floatingpoint(sym)
134136
insert_by_type!(constant_buffers, sym, symtype(sym))
137+
else
138+
stype = symtype(sym)
139+
if stype <: FnType
140+
stype = fntype_to_function_type(stype)
141+
end
142+
insert_by_type!(nonnumeric_buffers, sym, stype)
135143
end
136144
end
137145
end
138146
clock_partitions = unique(collect(values(disc_param_callbacks)))
139147
disc_symtypes = unique(symtype.(keys(disc_param_callbacks)))
140148
disc_symtype_idx = Dict(disc_symtypes .=> eachindex(disc_symtypes))
141-
disc_syms_by_symtype = [BasicSymbolic[] for _ in disc_symtypes]
149+
disc_syms_by_symtype = [SymbolicParam[] for _ in disc_symtypes]
142150
for sym in keys(disc_param_callbacks)
143151
push!(disc_syms_by_symtype[disc_symtype_idx[symtype(sym)]], sym)
144152
end
145-
disc_syms_by_symtype_by_partition = [Vector{BasicSymbolic}[] for _ in disc_symtypes]
153+
disc_syms_by_symtype_by_partition = [Vector{SymbolicParam}[] for _ in disc_symtypes]
146154
for (i, buffer) in enumerate(disc_syms_by_symtype)
147155
for partition in clock_partitions
148156
push!(disc_syms_by_symtype_by_partition[i],
149157
[sym for sym in buffer if disc_param_callbacks[sym] == partition])
150158
end
151159
end
152-
disc_idxs = Dict{BasicSymbolic, DiscreteIndex}()
160+
disc_idxs = Dict{SymbolicParam, DiscreteIndex}()
153161
callback_to_clocks = Dict{
154162
Union{SymbolicContinuousCallback, SymbolicDiscreteCallback}, Set{Int}}()
155163
for (typei, disc_syms_by_partition) in enumerate(disc_syms_by_symtype_by_partition)
@@ -191,6 +199,7 @@ function IndexCache(sys::AbstractSystem)
191199
end
192200
haskey(disc_idxs, p) && continue
193201
haskey(constant_buffers, ctype) && p in constant_buffers[ctype] && continue
202+
haskey(nonnumeric_buffers, ctype) && p in nonnumeric_buffers[ctype] && continue
194203
insert_by_type!(
195204
if ctype <: Real || ctype <: AbstractArray{<:Real}
196205
if istunable(p, true) && Symbolics.shape(p) != Symbolics.Unknown() &&

test/index_cache.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,29 @@ end
9292
reorder_dimension_by_tunables!(dst, sys, src, [r, q, p]; dim = 2)
9393
@test dst stack([vcat(4ones(4), 3ones(3), 1.0) for i in 1:5]; dims = 1)
9494
end
95+
96+
mutable struct ParamTest
97+
y::Any
98+
end
99+
(pt::ParamTest)(x) = pt.y - x
100+
@testset "Issue#3215: Callable discrete parameter" begin
101+
function update_affect!(integ, u, p, ctx)
102+
integ.p[p.p_1].y = integ.t
103+
end
104+
105+
tp1 = typeof(ParamTest(1))
106+
@parameters (p_1::tp1)(..) = ParamTest(1)
107+
@variables x(ModelingToolkit.t_nounits) = 0
108+
109+
event1 = [1.0, 2, 3] => (update_affect!, [], [p_1], [p_1], nothing)
110+
111+
@named sys = ODESystem([
112+
ModelingToolkit.D_nounits(x) ~ p_1(x)
113+
],
114+
ModelingToolkit.t_nounits;
115+
discrete_events = [event1]
116+
)
117+
ss = @test_nowarn complete(sys)
118+
@test length(parameters(ss)) == 1
119+
@test !is_timeseries_parameter(ss, p_1)
120+
end

0 commit comments

Comments
 (0)