Skip to content

Commit b0fd60a

Browse files
feat: add get_saved_state_idxs
1 parent 7c62e3a commit b0fd60a

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

src/solutions/save_idxs.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ function is_empty_indp(indp)
4949
isempty(independent_variable_symbols(indp))
5050
end
5151

52+
# Everything from this point on is public API
53+
5254
"""
5355
$(TYPEDSIGNATURES)
5456
@@ -237,6 +239,20 @@ function SavedSubsystem(indp, pobj, saved_idxs)
237239
timeseries_partition_templates, indexes_in_partition, ts_idx_to_count)
238240
end
239241

242+
"""
243+
$(TYPEDSIGNATURES)
244+
245+
Given a `SavedSubsystem`, return the subset of state indexes of the original system that are
246+
saved, in the order they are saved.
247+
"""
248+
function get_saved_state_idxs(ss::SavedSubsystem)
249+
idxs = Vector{valtype(ss.state_map)}(undef, length(ss.state_map))
250+
for (k, v) in ss.state_map
251+
idxs[v] = k
252+
end
253+
return idxs
254+
end
255+
240256
"""
241257
$(TYPEDEF)
242258

test/downstream/solution_interface.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,8 @@ end
193193

194194
ode_sol = solve(prob, Tsit5(); save_idxs = xidx)
195195
subsys = SciMLBase.SavedSubsystem(sys, prob.p, [xidx])
196+
@test SciMLBase.get_saved_state_idxs(subsys) == [xidx]
197+
196198
# FIXME: hack for save_idxs
197199
SciMLBase.@reset ode_sol.saved_subsystem = subsys
198200

@@ -257,6 +259,7 @@ end
257259
sol = solve(prob; save_idxs = xidx)
258260
xvals = sol[x]
259261
subsys = SciMLBase.SavedSubsystem(sys, prob.p, [x, q, r])
262+
@test SciMLBase.get_saved_state_idxs(subsys) == [xidx]
260263
qvals = sol.ps[q]
261264
rvals = sol.ps[r]
262265
# FIXME: hack for save_idxs
@@ -290,6 +293,7 @@ end
290293
prob = ODEProblem(sys, [x => 1.0, y => 1.0], (0.0, 5.0),
291294
[p => 0.5, q => 0.0, r => 1.0, s => 10.0, u => 4096.0])
292295
ss = SciMLBase.SavedSubsystem(sys, prob.p, [x, q, s, r])
296+
@test SciMLBase.get_saved_state_idxs(ss) == [xidx]
293297
sswf = SciMLBase.SavedSubsystemWithFallback(ss, sys)
294298
xidx = variable_index(sys, x)
295299
qidx = parameter_index(sys, q)

0 commit comments

Comments
 (0)