Skip to content

Commit 37b1cfd

Browse files
feat: put Initial parameters at the end of tunables
1 parent eb6872b commit 37b1cfd

File tree

1 file changed

+23
-16
lines changed

1 file changed

+23
-16
lines changed

src/systems/index_cache.jl

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ function IndexCache(sys::AbstractSystem)
9696
end
9797

9898
tunable_buffers = Dict{Any, Set{BasicSymbolic}}()
99+
initial_param_buffers = Dict{Any, Set{BasicSymbolic}}()
99100
constant_buffers = Dict{Any, Set{BasicSymbolic}}()
100101
nonnumeric_buffers = Dict{Any, Set{SymbolicParam}}()
101102

@@ -191,7 +192,7 @@ function IndexCache(sys::AbstractSystem)
191192
[BufferTemplate(symtype, length(buf)) for buf in disc_syms_by_partition])
192193
end
193194

194-
for p in parameters(sys)
195+
for p in parameters(sys; initial_parameters = true)
195196
p = unwrap(p)
196197
ctype = symtype(p)
197198
if ctype <: FnType
@@ -206,7 +207,11 @@ function IndexCache(sys::AbstractSystem)
206207
(ctype == Real || ctype <: AbstractFloat ||
207208
ctype <: AbstractArray{Real} ||
208209
ctype <: AbstractArray{<:AbstractFloat})
209-
tunable_buffers
210+
if iscall(p) && operation(p) isa Initial
211+
initial_param_buffers
212+
else
213+
tunable_buffers
214+
end
210215
else
211216
constant_buffers
212217
end
@@ -246,20 +251,22 @@ function IndexCache(sys::AbstractSystem)
246251

247252
tunable_idxs = TunableIndexMap()
248253
tunable_buffer_size = 0
249-
for (i, (_, buf)) in enumerate(tunable_buffers)
250-
for (j, p) in enumerate(buf)
251-
idx = if size(p) == ()
252-
tunable_buffer_size + 1
253-
else
254-
reshape(
255-
(tunable_buffer_size + 1):(tunable_buffer_size + length(p)), size(p))
256-
end
257-
tunable_buffer_size += length(p)
258-
tunable_idxs[p] = idx
259-
tunable_idxs[default_toterm(p)] = idx
260-
if hasname(p) && (!iscall(p) || operation(p) !== getindex)
261-
symbol_to_variable[getname(p)] = p
262-
symbol_to_variable[getname(default_toterm(p))] = p
254+
for buffers in (tunable_buffers, initial_param_buffers)
255+
for (i, (_, buf)) in enumerate(buffers)
256+
for (j, p) in enumerate(buf)
257+
idx = if size(p) == ()
258+
tunable_buffer_size + 1
259+
else
260+
reshape(
261+
(tunable_buffer_size + 1):(tunable_buffer_size + length(p)), size(p))
262+
end
263+
tunable_buffer_size += length(p)
264+
tunable_idxs[p] = idx
265+
tunable_idxs[default_toterm(p)] = idx
266+
if hasname(p) && (!iscall(p) || operation(p) !== getindex)
267+
symbol_to_variable[getname(p)] = p
268+
symbol_to_variable[getname(default_toterm(p))] = p
269+
end
263270
end
264271
end
265272
end

0 commit comments

Comments
 (0)