@@ -40,10 +40,12 @@ const TunableIndexMap = Dict{BasicSymbolic,
4040 Union{Int, UnitRange{Int}, Base. ReshapedArray{Int, N, UnitRange{Int}} where {N}}}
4141const TimeseriesSetType = Set{Union{ContinuousTimeseries, Int}}
4242
43+ const SymbolicParam = Union{BasicSymbolic, CallWithMetadata}
44+
4345struct IndexCache
4446 unknown_idx:: UnknownIndexMap
4547 # sym => (bufferidx, idx_in_buffer)
46- discrete_idx:: Dict{BasicSymbolic , DiscreteIndex}
48+ discrete_idx:: Dict{SymbolicParam , DiscreteIndex}
4749 # sym => (clockidx, idx_in_clockbuffer)
4850 callback_to_clocks:: Dict{Any, Vector{Int}}
4951 tunable_idx:: TunableIndexMap
@@ -56,13 +58,13 @@ struct IndexCache
5658 tunable_buffer_size:: BufferTemplate
5759 constant_buffer_sizes:: Vector{BufferTemplate}
5860 nonnumeric_buffer_sizes:: Vector{BufferTemplate}
59- symbol_to_variable:: Dict{Symbol, Union{BasicSymbolic, CallWithMetadata} }
61+ symbol_to_variable:: Dict{Symbol, SymbolicParam }
6062end
6163
6264function IndexCache (sys:: AbstractSystem )
6365 unks = solved_unknowns (sys)
6466 unk_idxs = UnknownIndexMap ()
65- symbol_to_variable = Dict {Symbol, Union{BasicSymbolic, CallWithMetadata} } ()
67+ symbol_to_variable = Dict {Symbol, SymbolicParam } ()
6668
6769 let idx = 1
6870 for sym in unks
@@ -95,18 +97,18 @@ function IndexCache(sys::AbstractSystem)
9597
9698 tunable_buffers = Dict {Any, Set{BasicSymbolic}} ()
9799 constant_buffers = Dict {Any, Set{BasicSymbolic}} ()
98- nonnumeric_buffers = Dict {Any, Set{Union{BasicSymbolic, CallWithMetadata} }} ()
100+ nonnumeric_buffers = Dict {Any, Set{SymbolicParam }} ()
99101
100102 function insert_by_type! (buffers:: Dict{Any, S} , sym, ctype) where {S}
101103 sym = unwrap (sym)
102104 buf = get! (buffers, ctype, S ())
103105 push! (buf, sym)
104106 end
105107
106- disc_param_callbacks = Dict {BasicSymbolic , Set{Int}} ()
108+ disc_param_callbacks = Dict {SymbolicParam , Set{Int}} ()
107109 events = vcat (continuous_events (sys), discrete_events (sys))
108110 for (i, event) in enumerate (events)
109- discs = Set {BasicSymbolic } ()
111+ discs = Set {SymbolicParam } ()
110112 affs = affects (event)
111113 if ! (affs isa AbstractArray)
112114 affs = [affs]
@@ -130,26 +132,32 @@ function IndexCache(sys::AbstractSystem)
130132 isequal (only (arguments (sym)), get_iv (sys))
131133 clocks = get! (() -> Set {Int} (), disc_param_callbacks, sym)
132134 push! (clocks, i)
133- else
135+ elseif is_variable_floatingpoint (sym)
134136 insert_by_type! (constant_buffers, sym, symtype (sym))
137+ else
138+ stype = symtype (sym)
139+ if stype <: FnType
140+ stype = fntype_to_function_type (stype)
141+ end
142+ insert_by_type! (nonnumeric_buffers, sym, stype)
135143 end
136144 end
137145 end
138146 clock_partitions = unique (collect (values (disc_param_callbacks)))
139147 disc_symtypes = unique (symtype .(keys (disc_param_callbacks)))
140148 disc_symtype_idx = Dict (disc_symtypes .=> eachindex (disc_symtypes))
141- disc_syms_by_symtype = [BasicSymbolic [] for _ in disc_symtypes]
149+ disc_syms_by_symtype = [SymbolicParam [] for _ in disc_symtypes]
142150 for sym in keys (disc_param_callbacks)
143151 push! (disc_syms_by_symtype[disc_symtype_idx[symtype (sym)]], sym)
144152 end
145- disc_syms_by_symtype_by_partition = [Vector{BasicSymbolic }[] for _ in disc_symtypes]
153+ disc_syms_by_symtype_by_partition = [Vector{SymbolicParam }[] for _ in disc_symtypes]
146154 for (i, buffer) in enumerate (disc_syms_by_symtype)
147155 for partition in clock_partitions
148156 push! (disc_syms_by_symtype_by_partition[i],
149157 [sym for sym in buffer if disc_param_callbacks[sym] == partition])
150158 end
151159 end
152- disc_idxs = Dict {BasicSymbolic , DiscreteIndex} ()
160+ disc_idxs = Dict {SymbolicParam , DiscreteIndex} ()
153161 callback_to_clocks = Dict{
154162 Union{SymbolicContinuousCallback, SymbolicDiscreteCallback}, Set{Int}}()
155163 for (typei, disc_syms_by_partition) in enumerate (disc_syms_by_symtype_by_partition)
@@ -191,6 +199,7 @@ function IndexCache(sys::AbstractSystem)
191199 end
192200 haskey (disc_idxs, p) && continue
193201 haskey (constant_buffers, ctype) && p in constant_buffers[ctype] && continue
202+ haskey (nonnumeric_buffers, ctype) && p in nonnumeric_buffers[ctype] && continue
194203 insert_by_type! (
195204 if ctype <: Real || ctype <: AbstractArray{<:Real}
196205 if istunable (p, true ) && Symbolics. shape (p) != Symbolics. Unknown () &&
0 commit comments