Skip to content

Commit ce6ca68

Browse files
fix: improve inference for non-symbolic save_idxs
1 parent 7a54aab commit ce6ca68

File tree

1 file changed

+37
-20
lines changed

1 file changed

+37
-20
lines changed

src/solutions/save_idxs.jl

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,20 @@ struct SavedSubsystem{V, T, M, I, P, Q, C}
105105
partition_count::C
106106
end
107107

108+
SavedSubsystem(indp, pobj, ::Nothing) = nothing
109+
110+
function SavedSubsystem(indp, pobj, saved_idxs::Vector{Int})
111+
isempty(saved_idxs) && return nothing
112+
isempty(variable_symbols(indp)) && return nothing
113+
state_map = Dict{Int, Int}(k => v for (k, v) in enumerate(saved_idxs))
114+
return SavedSubsystem(state_map, nothing, nothing, nothing, nothing, nothing, nothing)
115+
end
116+
117+
function SavedSubsystem(indp, pobj, idx::Int)
118+
state_map = Dict(1 => idx)
119+
return SavedSubsystem(state_map, nothing, nothing, nothing, nothing, nothing, nothing)
120+
end
121+
108122
function SavedSubsystem(indp, pobj, saved_idxs)
109123
# nothing saved
110124
if saved_idxs === nothing || isempty(saved_idxs)
@@ -357,29 +371,32 @@ corresponding to the state variables and a `SavedSubsystem` to pass to `build_so
357371
The second return value (corresponding to the `SavedSubsystem`) may be `nothing` in case
358372
one is not required. `save_idxs` may be a scalar or `nothing`.
359373
"""
374+
get_save_idxs_and_saved_subsystem(prob, ::Nothing) = nothing, nothing
375+
function get_save_idxs_and_saved_subsystem(prob, save_idxs::Vector{Int})
376+
save_idxs, SavedSubsystem(prob, parameter_values(prob), save_idxs)
377+
end
378+
function get_save_idxs_and_saved_subsystem(prob, save_idx::Int)
379+
save_idx, SavedSubsystem(prob, parameter_values(prob), save_idx)
380+
end
360381
function get_save_idxs_and_saved_subsystem(prob, save_idxs)
361-
if save_idxs === nothing
362-
saved_subsystem = nothing
382+
if !(save_idxs isa AbstractArray) || symbolic_type(save_idxs) != NotSymbolic()
383+
_save_idxs = [save_idxs]
363384
else
364-
if !(save_idxs isa AbstractArray) || symbolic_type(save_idxs) != NotSymbolic()
365-
_save_idxs = [save_idxs]
385+
_save_idxs = save_idxs
386+
end
387+
saved_subsystem = SavedSubsystem(prob, parameter_values(prob), _save_idxs)
388+
if saved_subsystem !== nothing
389+
_save_idxs = get_saved_state_idxs(saved_subsystem)
390+
if isempty(_save_idxs)
391+
# no states to save
392+
save_idxs = Int[]
393+
elseif !(save_idxs isa AbstractArray) ||
394+
symbolic_type(save_idxs) != NotSymbolic()
395+
# only a single state to save, and save it as a scalar timeseries instead of
396+
# single-element array
397+
save_idxs = only(_save_idxs)
366398
else
367-
_save_idxs = save_idxs
368-
end
369-
saved_subsystem = SavedSubsystem(prob, parameter_values(prob), _save_idxs)
370-
if saved_subsystem !== nothing
371-
_save_idxs = get_saved_state_idxs(saved_subsystem)
372-
if isempty(_save_idxs)
373-
# no states to save
374-
save_idxs = Int[]
375-
elseif !(save_idxs isa AbstractArray) ||
376-
symbolic_type(save_idxs) != NotSymbolic()
377-
# only a single state to save, and save it as a scalar timeseries instead of
378-
# single-element array
379-
save_idxs = only(_save_idxs)
380-
else
381-
save_idxs = _save_idxs
382-
end
399+
save_idxs = _save_idxs
383400
end
384401
end
385402

0 commit comments

Comments
 (0)