@@ -32,6 +32,8 @@ struct DiscreteIndex
3232end 
3333
3434const  ParamIndexMap =  Dict{BasicSymbolic, Tuple{Int, Int}}
35+ const  NonnumericMap =  Dict{
36+     Union{BasicSymbolic, Symbolics. CallWithMetadata}, Tuple{Int, Int}}
3537const  UnknownIndexMap =  Dict{
3638    BasicSymbolic, Union{Int, UnitRange{Int}, AbstractArray{Int}}}
3739const  TunableIndexMap =  Dict{BasicSymbolic,
@@ -45,20 +47,20 @@ struct IndexCache
4547    callback_to_clocks:: Dict{Any, Vector{Int}} 
4648    tunable_idx:: TunableIndexMap 
4749    constant_idx:: ParamIndexMap 
48-     nonnumeric_idx:: ParamIndexMap 
50+     nonnumeric_idx:: NonnumericMap 
4951    observed_syms:: Set{BasicSymbolic} 
5052    dependent_pars:: Set{BasicSymbolic} 
5153    discrete_buffer_sizes:: Vector{Vector{BufferTemplate}} 
5254    tunable_buffer_size:: BufferTemplate 
5355    constant_buffer_sizes:: Vector{BufferTemplate} 
5456    nonnumeric_buffer_sizes:: Vector{BufferTemplate} 
55-     symbol_to_variable:: Dict{Symbol, BasicSymbolic} 
57+     symbol_to_variable:: Dict{Symbol, Union{ BasicSymbolic, CallWithMetadata} } 
5658end 
5759
5860function  IndexCache (sys:: AbstractSystem )
5961    unks =  solved_unknowns (sys)
6062    unk_idxs =  UnknownIndexMap ()
61-     symbol_to_variable =  Dict {Symbol, BasicSymbolic} ()
63+     symbol_to_variable =  Dict {Symbol, Union{ BasicSymbolic, CallWithMetadata} } ()
6264
6365    let  idx =  1 
6466        for  sym in  unks
@@ -105,12 +107,11 @@ function IndexCache(sys::AbstractSystem)
105107
106108    tunable_buffers =  Dict {Any, Set{BasicSymbolic}} ()
107109    constant_buffers =  Dict {Any, Set{BasicSymbolic}} ()
108-     nonnumeric_buffers =  Dict {Any, Set{BasicSymbolic}} ()
110+     nonnumeric_buffers =  Dict {Any, Set{Union{ BasicSymbolic, CallWithMetadata} }} ()
109111
110-     function  insert_by_type! (buffers:: Dict{Any, Set{BasicSymbolic}}  , sym) 
112+     function  insert_by_type! (buffers:: Dict{Any, S}  , sym, ctype)  where  {S} 
111113        sym =  unwrap (sym)
112-         ctype =  symtype (sym)
113-         buf =  get! (buffers, ctype, Set {BasicSymbolic} ())
114+         buf =  get! (buffers, ctype, S ())
114115        push! (buf, sym)
115116    end 
116117
@@ -142,7 +143,7 @@ function IndexCache(sys::AbstractSystem)
142143                clocks =  get! (() ->  Set {Int} (), disc_param_callbacks, sym)
143144                push! (clocks, i)
144145            else 
145-                 insert_by_type! (constant_buffers, sym)
146+                 insert_by_type! (constant_buffers, sym,  symtype (sym) )
146147            end 
147148        end 
148149    end 
@@ -197,6 +198,9 @@ function IndexCache(sys::AbstractSystem)
197198    for  p in  parameters (sys)
198199        p =  unwrap (p)
199200        ctype =  symtype (p)
201+         if  ctype <:  FnType 
202+             ctype =  fntype_to_function_type (ctype)
203+         end 
200204        haskey (disc_idxs, p) &&  continue 
201205        haskey (constant_buffers, ctype) &&  p in  constant_buffers[ctype] &&  continue 
202206        insert_by_type! (
@@ -212,12 +216,13 @@ function IndexCache(sys::AbstractSystem)
212216            else 
213217                nonnumeric_buffers
214218            end ,
215-             p
219+             p,
220+             ctype
216221        )
217222    end 
218223
219-     function  get_buffer_sizes_and_idxs (buffers:: Dict{Any, Set{BasicSymbolic}}  )
220-         idxs =  ParamIndexMap ()
224+     function  get_buffer_sizes_and_idxs (T,  buffers:: Dict )
225+         idxs =  T ()
221226        buffer_sizes =  BufferTemplate[]
222227        for  (i, (T, buf)) in  enumerate (buffers)
223228            for  (j, p) in  enumerate (buf)
@@ -229,13 +234,18 @@ function IndexCache(sys::AbstractSystem)
229234                idxs[rp] =  (i, j)
230235                idxs[rttp] =  (i, j)
231236            end 
237+             if  T <:  Symbolics.FnType 
238+                 T =  Any
239+             end 
232240            push! (buffer_sizes, BufferTemplate (T, length (buf)))
233241        end 
234242        return  idxs, buffer_sizes
235243    end 
236244
237-     const_idxs, const_buffer_sizes =  get_buffer_sizes_and_idxs (constant_buffers)
238-     nonnumeric_idxs, nonnumeric_buffer_sizes =  get_buffer_sizes_and_idxs (nonnumeric_buffers)
245+     const_idxs, const_buffer_sizes =  get_buffer_sizes_and_idxs (
246+         ParamIndexMap, constant_buffers)
247+     nonnumeric_idxs, nonnumeric_buffer_sizes =  get_buffer_sizes_and_idxs (
248+         NonnumericMap, nonnumeric_buffers)
239249
240250    tunable_idxs =  TunableIndexMap ()
241251    tunable_buffer_size =  0 
@@ -401,7 +411,8 @@ function reorder_parameters(ic::IndexCache, ps; drop_missing = false)
401411    for  temp in  ic. discrete_buffer_sizes)
402412    const_buf =  Tuple (BasicSymbolic[unwrap (variable (:DEF )) for  _ in  1 : (temp. length)]
403413    for  temp in  ic. constant_buffer_sizes)
404-     nonnumeric_buf =  Tuple (BasicSymbolic[unwrap (variable (:DEF )) for  _ in  1 : (temp. length)]
414+     nonnumeric_buf =  Tuple (Union{BasicSymbolic, CallWithMetadata}[unwrap (variable (:DEF ))
415+                                                                   for  _ in  1 : (temp. length)]
405416    for  temp in  ic. nonnumeric_buffer_sizes)
406417    for  p in  ps
407418        p =  unwrap (p)
@@ -481,3 +492,7 @@ function get_buffer_template(ic::IndexCache, pidx::ParameterIndex)
481492        error (" Unhandled portion $portion " 
482493    end 
483494end 
495+ 
496+ fntype_to_function_type (:: Type{FnType{A, R, T}} ) where  {A, R, T} =  T
497+ fntype_to_function_type (:: Type{FnType{A, R, Nothing}} ) where  {A, R} =  FunctionWrapper{R, A}
498+ fntype_to_function_type (:: Type{FnType{A, R}} ) where  {A, R} =  FunctionWrapper{R, A}
0 commit comments