@@ -40,10 +40,12 @@ const TunableIndexMap = Dict{BasicSymbolic,
40
40
Union{Int, UnitRange{Int}, Base. ReshapedArray{Int, N, UnitRange{Int}} where {N}}}
41
41
const TimeseriesSetType = Set{Union{ContinuousTimeseries, Int}}
42
42
43
+ const SymbolicParam = Union{BasicSymbolic, CallWithMetadata}
44
+
43
45
struct IndexCache
44
46
unknown_idx:: UnknownIndexMap
45
47
# sym => (bufferidx, idx_in_buffer)
46
- discrete_idx:: Dict{BasicSymbolic , DiscreteIndex}
48
+ discrete_idx:: Dict{SymbolicParam , DiscreteIndex}
47
49
# sym => (clockidx, idx_in_clockbuffer)
48
50
callback_to_clocks:: Dict{Any, Vector{Int}}
49
51
tunable_idx:: TunableIndexMap
@@ -56,13 +58,13 @@ struct IndexCache
56
58
tunable_buffer_size:: BufferTemplate
57
59
constant_buffer_sizes:: Vector{BufferTemplate}
58
60
nonnumeric_buffer_sizes:: Vector{BufferTemplate}
59
- symbol_to_variable:: Dict{Symbol, Union{BasicSymbolic, CallWithMetadata} }
61
+ symbol_to_variable:: Dict{Symbol, SymbolicParam }
60
62
end
61
63
62
64
function IndexCache (sys:: AbstractSystem )
63
65
unks = solved_unknowns (sys)
64
66
unk_idxs = UnknownIndexMap ()
65
- symbol_to_variable = Dict {Symbol, Union{BasicSymbolic, CallWithMetadata} } ()
67
+ symbol_to_variable = Dict {Symbol, SymbolicParam } ()
66
68
67
69
let idx = 1
68
70
for sym in unks
@@ -95,18 +97,18 @@ function IndexCache(sys::AbstractSystem)
95
97
96
98
tunable_buffers = Dict {Any, Set{BasicSymbolic}} ()
97
99
constant_buffers = Dict {Any, Set{BasicSymbolic}} ()
98
- nonnumeric_buffers = Dict {Any, Set{Union{BasicSymbolic, CallWithMetadata} }} ()
100
+ nonnumeric_buffers = Dict {Any, Set{SymbolicParam }} ()
99
101
100
102
function insert_by_type! (buffers:: Dict{Any, S} , sym, ctype) where {S}
101
103
sym = unwrap (sym)
102
104
buf = get! (buffers, ctype, S ())
103
105
push! (buf, sym)
104
106
end
105
107
106
- disc_param_callbacks = Dict {BasicSymbolic , Set{Int}} ()
108
+ disc_param_callbacks = Dict {SymbolicParam , Set{Int}} ()
107
109
events = vcat (continuous_events (sys), discrete_events (sys))
108
110
for (i, event) in enumerate (events)
109
- discs = Set {BasicSymbolic } ()
111
+ discs = Set {SymbolicParam } ()
110
112
affs = affects (event)
111
113
if ! (affs isa AbstractArray)
112
114
affs = [affs]
@@ -130,26 +132,32 @@ function IndexCache(sys::AbstractSystem)
130
132
isequal (only (arguments (sym)), get_iv (sys))
131
133
clocks = get! (() -> Set {Int} (), disc_param_callbacks, sym)
132
134
push! (clocks, i)
133
- else
135
+ elseif is_variable_floatingpoint (sym)
134
136
insert_by_type! (constant_buffers, sym, symtype (sym))
137
+ else
138
+ stype = symtype (sym)
139
+ if stype <: FnType
140
+ stype = fntype_to_function_type (stype)
141
+ end
142
+ insert_by_type! (nonnumeric_buffers, sym, stype)
135
143
end
136
144
end
137
145
end
138
146
clock_partitions = unique (collect (values (disc_param_callbacks)))
139
147
disc_symtypes = unique (symtype .(keys (disc_param_callbacks)))
140
148
disc_symtype_idx = Dict (disc_symtypes .=> eachindex (disc_symtypes))
141
- disc_syms_by_symtype = [BasicSymbolic [] for _ in disc_symtypes]
149
+ disc_syms_by_symtype = [SymbolicParam [] for _ in disc_symtypes]
142
150
for sym in keys (disc_param_callbacks)
143
151
push! (disc_syms_by_symtype[disc_symtype_idx[symtype (sym)]], sym)
144
152
end
145
- disc_syms_by_symtype_by_partition = [Vector{BasicSymbolic }[] for _ in disc_symtypes]
153
+ disc_syms_by_symtype_by_partition = [Vector{SymbolicParam }[] for _ in disc_symtypes]
146
154
for (i, buffer) in enumerate (disc_syms_by_symtype)
147
155
for partition in clock_partitions
148
156
push! (disc_syms_by_symtype_by_partition[i],
149
157
[sym for sym in buffer if disc_param_callbacks[sym] == partition])
150
158
end
151
159
end
152
- disc_idxs = Dict {BasicSymbolic , DiscreteIndex} ()
160
+ disc_idxs = Dict {SymbolicParam , DiscreteIndex} ()
153
161
callback_to_clocks = Dict{
154
162
Union{SymbolicContinuousCallback, SymbolicDiscreteCallback}, Set{Int}}()
155
163
for (typei, disc_syms_by_partition) in enumerate (disc_syms_by_symtype_by_partition)
@@ -191,6 +199,7 @@ function IndexCache(sys::AbstractSystem)
191
199
end
192
200
haskey (disc_idxs, p) && continue
193
201
haskey (constant_buffers, ctype) && p in constant_buffers[ctype] && continue
202
+ haskey (nonnumeric_buffers, ctype) && p in nonnumeric_buffers[ctype] && continue
194
203
insert_by_type! (
195
204
if ctype <: Real || ctype <: AbstractArray{<:Real}
196
205
if istunable (p, true ) && Symbolics. shape (p) != Symbolics. Unknown () &&
0 commit comments