@@ -44,9 +44,11 @@ function as_diffeq_array(vt::Vector{VectorTemplate}, t)
4444 return DiffEqArray (typeof (TupleOfArraysWrapper (vt))[], t, (1 , 1 ))
4545end
4646
47- function is_empty_indp (indp)
48- isempty (variable_symbols (indp)) && isempty (parameter_symbols (indp)) &&
49- isempty (independent_variable_symbols (indp))
47+ function get_root_indp (indp)
48+ if hasmethod (symbolic_container, Tuple{typeof (indp)}) && (sc = symbolic_container (indp)) != = indp
49+ return get_root_indp (sc)
50+ end
51+ return indp
5052end
5153
5254# Everything from this point on is public API
@@ -105,17 +107,26 @@ struct SavedSubsystem{V, T, M, I, P, Q, C}
105107 partition_count:: C
106108end
107109
108- function SavedSubsystem (indp, pobj, saved_idxs)
109- # nothing saved
110- if saved_idxs === nothing || isempty (saved_idxs)
110+ SavedSubsystem (indp, pobj, :: Nothing ) = nothing
111+
112+ function SavedSubsystem (indp, pobj, idx:: Int )
113+ _indp = get_root_indp (indp)
114+ if _indp === EMPTY_SYMBOLCACHE || _indp === nothing
111115 return nothing
112116 end
117+ state_map = Dict (1 => idx)
118+ return SavedSubsystem (state_map, nothing , nothing , nothing , nothing , nothing , nothing )
119+ end
113120
114- # this is required because problems with no system have an empty `SymbolCache`
115- # as their symbolic container.
116- if is_empty_indp (indp)
121+ function SavedSubsystem (indp, pobj, saved_idxs :: Union{AbstractArray, Tuple} )
122+ _indp = get_root_indp (indp)
123+ if _indp === EMPTY_SYMBOLCACHE || _indp === nothing
117124 return nothing
118125 end
126+ if eltype (saved_idxs) == Int
127+ state_map = Dict {Int, Int} (v => k for (k, v) in enumerate (saved_idxs))
128+ return SavedSubsystem (state_map, nothing , nothing , nothing , nothing , nothing , nothing )
129+ end
119130
120131 # array state symbolics must be scalarized
121132 saved_idxs = collect (Iterators. flatten (map (saved_idxs) do sym
@@ -357,29 +368,32 @@ corresponding to the state variables and a `SavedSubsystem` to pass to `build_so
357368The second return value (corresponding to the `SavedSubsystem`) may be `nothing` in case
358369one is not required. `save_idxs` may be a scalar or `nothing`.
359370"""
371+ get_save_idxs_and_saved_subsystem (prob, :: Nothing ) = nothing , nothing
372+ function get_save_idxs_and_saved_subsystem (prob, save_idxs:: Vector{Int} )
373+ save_idxs, SavedSubsystem (prob, parameter_values (prob), save_idxs)
374+ end
375+ function get_save_idxs_and_saved_subsystem (prob, save_idx:: Int )
376+ save_idx, SavedSubsystem (prob, parameter_values (prob), save_idx)
377+ end
360378function get_save_idxs_and_saved_subsystem (prob, save_idxs)
361- if save_idxs === nothing
362- saved_subsystem = nothing
379+ if ! ( save_idxs isa AbstractArray) || symbolic_type (save_idxs) != NotSymbolic ()
380+ _save_idxs = (save_idxs,)
363381 else
364- if ! (save_idxs isa AbstractArray) || symbolic_type (save_idxs) != NotSymbolic ()
365- _save_idxs = [save_idxs]
382+ _save_idxs = save_idxs
383+ end
384+ saved_subsystem = SavedSubsystem (prob, parameter_values (prob), _save_idxs)
385+ if saved_subsystem != = nothing
386+ _save_idxs = get_saved_state_idxs (saved_subsystem)
387+ if isempty (_save_idxs)
388+ # no states to save
389+ save_idxs = Int[]
390+ elseif ! (save_idxs isa AbstractArray) ||
391+ symbolic_type (save_idxs) != NotSymbolic ()
392+ # only a single state to save, and save it as a scalar timeseries instead of
393+ # single-element array
394+ save_idxs = only (_save_idxs)
366395 else
367- _save_idxs = save_idxs
368- end
369- saved_subsystem = SavedSubsystem (prob, parameter_values (prob), _save_idxs)
370- if saved_subsystem != = nothing
371- _save_idxs = get_saved_state_idxs (saved_subsystem)
372- if isempty (_save_idxs)
373- # no states to save
374- save_idxs = Int[]
375- elseif ! (save_idxs isa AbstractArray) ||
376- symbolic_type (save_idxs) != NotSymbolic ()
377- # only a single state to save, and save it as a scalar timeseries instead of
378- # single-element array
379- save_idxs = only (_save_idxs)
380- else
381- save_idxs = _save_idxs
382- end
396+ save_idxs = _save_idxs
383397 end
384398 end
385399
0 commit comments