Skip to content

Commit 699d833

Browse files
Merge pull request #834 from AayushSabharwal/as/save-idxs-fix
feat: add `get_saved_state_idxs`, handle problems with no system in SavedSubsystem constructor
2 parents bf913e7 + b0fd60a commit 699d833

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

src/solutions/save_idxs.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,13 @@ function as_diffeq_array(vt::Vector{VectorTemplate}, t)
4444
return DiffEqArray(typeof(TupleOfArraysWrapper(vt))[], t, (1, 1))
4545
end
4646

47+
function is_empty_indp(indp)
48+
isempty(variable_symbols(indp)) && isempty(parameter_symbols(indp)) &&
49+
isempty(independent_variable_symbols(indp))
50+
end
51+
52+
# Everything from this point on is public API
53+
4754
"""
4855
$(TYPEDSIGNATURES)
4956
@@ -104,6 +111,12 @@ function SavedSubsystem(indp, pobj, saved_idxs)
104111
return nothing
105112
end
106113

114+
# this is required because problems with no system have an empty `SymbolCache`
115+
# as their symbolic container.
116+
if is_empty_indp(indp)
117+
return nothing
118+
end
119+
107120
# array state symbolics must be scalarized
108121
saved_idxs = collect(Iterators.flatten(map(saved_idxs) do sym
109122
if symbolic_type(sym) == NotSymbolic()
@@ -226,6 +239,20 @@ function SavedSubsystem(indp, pobj, saved_idxs)
226239
timeseries_partition_templates, indexes_in_partition, ts_idx_to_count)
227240
end
228241

242+
"""
243+
$(TYPEDSIGNATURES)
244+
245+
Given a `SavedSubsystem`, return the subset of state indexes of the original system that are
246+
saved, in the order they are saved.
247+
"""
248+
function get_saved_state_idxs(ss::SavedSubsystem)
249+
idxs = Vector{valtype(ss.state_map)}(undef, length(ss.state_map))
250+
for (k, v) in ss.state_map
251+
idxs[v] = k
252+
end
253+
return idxs
254+
end
255+
229256
"""
230257
$(TYPEDEF)
231258

test/downstream/solution_interface.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,8 @@ end
193193

194194
ode_sol = solve(prob, Tsit5(); save_idxs = xidx)
195195
subsys = SciMLBase.SavedSubsystem(sys, prob.p, [xidx])
196+
@test SciMLBase.get_saved_state_idxs(subsys) == [xidx]
197+
196198
# FIXME: hack for save_idxs
197199
SciMLBase.@reset ode_sol.saved_subsystem = subsys
198200

@@ -257,6 +259,7 @@ end
257259
sol = solve(prob; save_idxs = xidx)
258260
xvals = sol[x]
259261
subsys = SciMLBase.SavedSubsystem(sys, prob.p, [x, q, r])
262+
@test SciMLBase.get_saved_state_idxs(subsys) == [xidx]
260263
qvals = sol.ps[q]
261264
rvals = sol.ps[r]
262265
# FIXME: hack for save_idxs
@@ -290,6 +293,7 @@ end
290293
prob = ODEProblem(sys, [x => 1.0, y => 1.0], (0.0, 5.0),
291294
[p => 0.5, q => 0.0, r => 1.0, s => 10.0, u => 4096.0])
292295
ss = SciMLBase.SavedSubsystem(sys, prob.p, [x, q, s, r])
296+
@test SciMLBase.get_saved_state_idxs(ss) == [xidx]
293297
sswf = SciMLBase.SavedSubsystemWithFallback(ss, sys)
294298
xidx = variable_index(sys, x)
295299
qidx = parameter_index(sys, q)

0 commit comments

Comments
 (0)