@@ -32,6 +32,8 @@ struct DiscreteIndex
3232end
3333
3434const ParamIndexMap = Dict{BasicSymbolic, Tuple{Int, Int}}
35+ const NonnumericMap = Dict{
36+ Union{BasicSymbolic, Symbolics. CallWithMetadata}, Tuple{Int, Int}}
3537const UnknownIndexMap = Dict{
3638 BasicSymbolic, Union{Int, UnitRange{Int}, AbstractArray{Int}}}
3739const TunableIndexMap = Dict{BasicSymbolic,
@@ -45,20 +47,20 @@ struct IndexCache
4547 callback_to_clocks:: Dict{Any, Vector{Int}}
4648 tunable_idx:: TunableIndexMap
4749 constant_idx:: ParamIndexMap
48- nonnumeric_idx:: ParamIndexMap
50+ nonnumeric_idx:: NonnumericMap
4951 observed_syms:: Set{BasicSymbolic}
5052 dependent_pars:: Set{BasicSymbolic}
5153 discrete_buffer_sizes:: Vector{Vector{BufferTemplate}}
5254 tunable_buffer_size:: BufferTemplate
5355 constant_buffer_sizes:: Vector{BufferTemplate}
5456 nonnumeric_buffer_sizes:: Vector{BufferTemplate}
55- symbol_to_variable:: Dict{Symbol, BasicSymbolic}
57+ symbol_to_variable:: Dict{Symbol, Union{ BasicSymbolic, CallWithMetadata} }
5658end
5759
5860function IndexCache (sys:: AbstractSystem )
5961 unks = solved_unknowns (sys)
6062 unk_idxs = UnknownIndexMap ()
61- symbol_to_variable = Dict {Symbol, BasicSymbolic} ()
63+ symbol_to_variable = Dict {Symbol, Union{ BasicSymbolic, CallWithMetadata} } ()
6264
6365 let idx = 1
6466 for sym in unks
@@ -105,12 +107,11 @@ function IndexCache(sys::AbstractSystem)
105107
106108 tunable_buffers = Dict {Any, Set{BasicSymbolic}} ()
107109 constant_buffers = Dict {Any, Set{BasicSymbolic}} ()
108- nonnumeric_buffers = Dict {Any, Set{BasicSymbolic}} ()
110+ nonnumeric_buffers = Dict {Any, Set{Union{ BasicSymbolic, CallWithMetadata} }} ()
109111
110- function insert_by_type! (buffers:: Dict{Any, Set{BasicSymbolic}} , sym)
112+ function insert_by_type! (buffers:: Dict{Any, S} , sym, ctype) where {S}
111113 sym = unwrap (sym)
112- ctype = symtype (sym)
113- buf = get! (buffers, ctype, Set {BasicSymbolic} ())
114+ buf = get! (buffers, ctype, S ())
114115 push! (buf, sym)
115116 end
116117
@@ -142,7 +143,7 @@ function IndexCache(sys::AbstractSystem)
142143 clocks = get! (() -> Set {Int} (), disc_param_callbacks, sym)
143144 push! (clocks, i)
144145 else
145- insert_by_type! (constant_buffers, sym)
146+ insert_by_type! (constant_buffers, sym, symtype (sym) )
146147 end
147148 end
148149 end
@@ -197,6 +198,9 @@ function IndexCache(sys::AbstractSystem)
197198 for p in parameters (sys)
198199 p = unwrap (p)
199200 ctype = symtype (p)
201+ if ctype <: FnType
202+ ctype = fntype_to_function_type (ctype)
203+ end
200204 haskey (disc_idxs, p) && continue
201205 haskey (constant_buffers, ctype) && p in constant_buffers[ctype] && continue
202206 insert_by_type! (
@@ -212,12 +216,13 @@ function IndexCache(sys::AbstractSystem)
212216 else
213217 nonnumeric_buffers
214218 end ,
215- p
219+ p,
220+ ctype
216221 )
217222 end
218223
219- function get_buffer_sizes_and_idxs (buffers:: Dict{Any, Set{BasicSymbolic}} )
220- idxs = ParamIndexMap ()
224+ function get_buffer_sizes_and_idxs (T, buffers:: Dict )
225+ idxs = T ()
221226 buffer_sizes = BufferTemplate[]
222227 for (i, (T, buf)) in enumerate (buffers)
223228 for (j, p) in enumerate (buf)
@@ -229,13 +234,18 @@ function IndexCache(sys::AbstractSystem)
229234 idxs[rp] = (i, j)
230235 idxs[rttp] = (i, j)
231236 end
237+ if T <: Symbolics.FnType
238+ T = Any
239+ end
232240 push! (buffer_sizes, BufferTemplate (T, length (buf)))
233241 end
234242 return idxs, buffer_sizes
235243 end
236244
237- const_idxs, const_buffer_sizes = get_buffer_sizes_and_idxs (constant_buffers)
238- nonnumeric_idxs, nonnumeric_buffer_sizes = get_buffer_sizes_and_idxs (nonnumeric_buffers)
245+ const_idxs, const_buffer_sizes = get_buffer_sizes_and_idxs (
246+ ParamIndexMap, constant_buffers)
247+ nonnumeric_idxs, nonnumeric_buffer_sizes = get_buffer_sizes_and_idxs (
248+ NonnumericMap, nonnumeric_buffers)
239249
240250 tunable_idxs = TunableIndexMap ()
241251 tunable_buffer_size = 0
@@ -401,7 +411,8 @@ function reorder_parameters(ic::IndexCache, ps; drop_missing = false)
401411 for temp in ic. discrete_buffer_sizes)
402412 const_buf = Tuple (BasicSymbolic[unwrap (variable (:DEF )) for _ in 1 : (temp. length)]
403413 for temp in ic. constant_buffer_sizes)
404- nonnumeric_buf = Tuple (BasicSymbolic[unwrap (variable (:DEF )) for _ in 1 : (temp. length)]
414+ nonnumeric_buf = Tuple (Union{BasicSymbolic, CallWithMetadata}[unwrap (variable (:DEF ))
415+ for _ in 1 : (temp. length)]
405416 for temp in ic. nonnumeric_buffer_sizes)
406417 for p in ps
407418 p = unwrap (p)
@@ -481,3 +492,6 @@ function get_buffer_template(ic::IndexCache, pidx::ParameterIndex)
481492 error (" Unhandled portion $portion " )
482493 end
483494end
495+
496+ fntype_to_function_type (:: Type{FnType{A, R, T}} ) where {A, R, T} = T
497+ fntype_to_function_type (:: Type{FnType{A, R, Nothing}} ) where {A, R} = FunctionWrapper{R, A}
0 commit comments