Skip to content

Commit 5567ab2

Browse files
feat: separate initial parameters for non-initialization-systems
1 parent ccb04d8 commit 5567ab2

File tree

2 files changed

+53
-1
lines changed

2 files changed

+53
-1
lines changed

src/systems/index_cache.jl

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,15 @@ struct IndexCache
4949
# sym => (clockidx, idx_in_clockbuffer)
5050
callback_to_clocks::Dict{Any, Vector{Int}}
5151
tunable_idx::TunableIndexMap
52+
initials_idx::TunableIndexMap
5253
constant_idx::ParamIndexMap
5354
nonnumeric_idx::NonnumericMap
5455
observed_syms_to_timeseries::Dict{BasicSymbolic, TimeseriesSetType}
5556
dependent_pars_to_timeseries::Dict{
5657
Union{BasicSymbolic, CallWithMetadata}, TimeseriesSetType}
5758
discrete_buffer_sizes::Vector{Vector{BufferTemplate}}
5859
tunable_buffer_size::BufferTemplate
60+
initials_buffer_size::BufferTemplate
5961
constant_buffer_sizes::Vector{BufferTemplate}
6062
nonnumeric_buffer_sizes::Vector{BufferTemplate}
6163
symbol_to_variable::Dict{Symbol, SymbolicParam}
@@ -251,7 +253,9 @@ function IndexCache(sys::AbstractSystem)
251253

252254
tunable_idxs = TunableIndexMap()
253255
tunable_buffer_size = 0
254-
for buffers in (tunable_buffers, initial_param_buffers)
256+
bufferlist = is_initializesystem(sys) ? (tunable_buffers, initial_param_buffers) :
257+
(tunable_buffers,)
258+
for buffers in bufferlist
255259
for (i, (_, buf)) in enumerate(buffers)
256260
for (j, p) in enumerate(buf)
257261
idx = if size(p) == ()
@@ -271,6 +275,43 @@ function IndexCache(sys::AbstractSystem)
271275
end
272276
end
273277

278+
initials_idxs = TunableIndexMap()
279+
initials_buffer_size = 0
280+
if !is_initializesystem(sys)
281+
for (i, (_, buf)) in enumerate(initial_param_buffers)
282+
for (j, p) in enumerate(buf)
283+
idx = if size(p) == ()
284+
initials_buffer_size + 1
285+
else
286+
reshape(
287+
(initials_buffer_size + 1):(initials_buffer_size + length(p)), size(p))
288+
end
289+
initials_buffer_size += length(p)
290+
initials_idxs[p] = idx
291+
initials_idxs[default_toterm(p)] = idx
292+
if hasname(p) && (!iscall(p) || operation(p) !== getindex)
293+
symbol_to_variable[getname(p)] = p
294+
symbol_to_variable[getname(default_toterm(p))] = p
295+
end
296+
end
297+
end
298+
end
299+
300+
for k in collect(keys(tunable_idxs))
301+
v = tunable_idxs[k]
302+
v isa AbstractArray || continue
303+
for (kk, vv) in zip(collect(k), v)
304+
tunable_idxs[kk] = vv
305+
end
306+
end
307+
for k in collect(keys(initials_idxs))
308+
v = initials_idxs[k]
309+
v isa AbstractArray || continue
310+
for (kk, vv) in zip(collect(k), v)
311+
initials_idxs[kk] = vv
312+
end
313+
end
314+
274315
dependent_pars_to_timeseries = Dict{
275316
Union{BasicSymbolic, CallWithMetadata}, TimeseriesSetType}()
276317

@@ -341,12 +382,14 @@ function IndexCache(sys::AbstractSystem)
341382
disc_idxs,
342383
callback_to_clocks,
343384
tunable_idxs,
385+
initials_idxs,
344386
const_idxs,
345387
nonnumeric_idxs,
346388
observed_syms_to_timeseries,
347389
dependent_pars_to_timeseries,
348390
disc_buffer_templates,
349391
BufferTemplate(Real, tunable_buffer_size),
392+
BufferTemplate(Real, initials_buffer_size),
350393
const_buffer_sizes,
351394
nonnumeric_buffer_sizes,
352395
symbol_to_variable

src/systems/nonlinear/initializesystem.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,15 @@ function SciMLBase.late_binding_update_u0_p(
538538
return newu0, newp
539539
end
540540

541+
"""
542+
$(TYPEDSIGNATURES)
543+
544+
Check if the given system is an initialization system.
545+
"""
546+
function is_initializesystem(sys::AbstractSystem)
547+
sys isa NonlinearSystem && get_metadata(sys) isa InitializationSystemMetadata
548+
end
549+
541550
"""
542551
Counteracts the CSE/array variable hacks in `symbolics_tearing.jl` so it works with
543552
initialization.

0 commit comments

Comments
 (0)