diff --git a/src/solutions/save_idxs.jl b/src/solutions/save_idxs.jl index f6839a015..0b154d538 100644 --- a/src/solutions/save_idxs.jl +++ b/src/solutions/save_idxs.jl @@ -44,6 +44,13 @@ function as_diffeq_array(vt::Vector{VectorTemplate}, t) return DiffEqArray(typeof(TupleOfArraysWrapper(vt))[], t, (1, 1)) end +function is_empty_indp(indp) + isempty(variable_symbols(indp)) && isempty(parameter_symbols(indp)) && + isempty(independent_variable_symbols(indp)) +end + +# Everything from this point on is public API + """ $(TYPEDSIGNATURES) @@ -104,6 +111,12 @@ function SavedSubsystem(indp, pobj, saved_idxs) return nothing end + # this is required because problems with no system have an empty `SymbolCache` + # as their symbolic container. + if is_empty_indp(indp) + return nothing + end + # array state symbolics must be scalarized saved_idxs = collect(Iterators.flatten(map(saved_idxs) do sym if symbolic_type(sym) == NotSymbolic() @@ -226,6 +239,20 @@ function SavedSubsystem(indp, pobj, saved_idxs) timeseries_partition_templates, indexes_in_partition, ts_idx_to_count) end +""" + $(TYPEDSIGNATURES) + +Given a `SavedSubsystem`, return the subset of state indexes of the original system that are +saved, in the order they are saved. +""" +function get_saved_state_idxs(ss::SavedSubsystem) + idxs = Vector{valtype(ss.state_map)}(undef, length(ss.state_map)) + for (k, v) in ss.state_map + idxs[v] = k + end + return idxs +end + """ $(TYPEDEF) diff --git a/test/downstream/solution_interface.jl b/test/downstream/solution_interface.jl index 4cbeb3173..12e2c9402 100644 --- a/test/downstream/solution_interface.jl +++ b/test/downstream/solution_interface.jl @@ -193,6 +193,8 @@ end ode_sol = solve(prob, Tsit5(); save_idxs = xidx) subsys = SciMLBase.SavedSubsystem(sys, prob.p, [xidx]) + @test SciMLBase.get_saved_state_idxs(subsys) == [xidx] + # FIXME: hack for save_idxs SciMLBase.@reset ode_sol.saved_subsystem = subsys @@ -257,6 +259,7 @@ end sol = solve(prob; save_idxs = xidx) xvals = sol[x] subsys = SciMLBase.SavedSubsystem(sys, prob.p, [x, q, r]) + @test SciMLBase.get_saved_state_idxs(subsys) == [xidx] qvals = sol.ps[q] rvals = sol.ps[r] # FIXME: hack for save_idxs @@ -290,6 +293,7 @@ end prob = ODEProblem(sys, [x => 1.0, y => 1.0], (0.0, 5.0), [p => 0.5, q => 0.0, r => 1.0, s => 10.0, u => 4096.0]) ss = SciMLBase.SavedSubsystem(sys, prob.p, [x, q, s, r]) + @test SciMLBase.get_saved_state_idxs(ss) == [xidx] sswf = SciMLBase.SavedSubsystemWithFallback(ss, sys) xidx = variable_index(sys, x) qidx = parameter_index(sys, q)