Skip to content

Commit 0d263e5

Browse files
Merge pull request #847 from AayushSabharwal/as/get-save-idxs
feat: add `get_save_idxs_and_saved_subsystem`
2 parents 7b6c7ef + b592989 commit 0d263e5

File tree

2 files changed

+68
-0
lines changed

2 files changed

+68
-0
lines changed

src/solutions/save_idxs.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,3 +347,40 @@ function SymbolicIndexingInterface.with_updated_parameter_timeseries_values(
347347

348348
return ps
349349
end
350+
351+
"""
352+
$(TYPEDSIGNATURES)
353+
354+
Given a SciMLProblem `prob` and (possibly symbolic) `save_idxs`, return the `save_idxs`
355+
corresponding to the state variables and a `SavedSubsystem` to pass to `build_solution`.
356+
357+
The second return value (corresponding to the `SavedSubsystem`) may be `nothing` in case
358+
one is not required. `save_idxs` may be a scalar or `nothing`.
359+
"""
360+
function get_save_idxs_and_saved_subsystem(prob, save_idxs)
361+
if save_idxs === nothing
362+
saved_subsystem = nothing
363+
else
364+
if !(save_idxs isa AbstractArray) || symbolic_type(save_idxs) != NotSymbolic()
365+
_save_idxs = [save_idxs]
366+
else
367+
_save_idxs = save_idxs
368+
end
369+
saved_subsystem = SavedSubsystem(prob, parameter_values(prob), _save_idxs)
370+
if saved_subsystem !== nothing
371+
_save_idxs = get_saved_state_idxs(saved_subsystem)
372+
if isempty(_save_idxs)
373+
# no states to save
374+
save_idxs = Int[]
375+
elseif !(save_idxs isa AbstractArray) || symbolic_type(save_idxs) != NotSymbolic()
376+
# only a single state to save, and save it as a scalar timeseries instead of
377+
# single-element array
378+
save_idxs = only(_save_idxs)
379+
else
380+
save_idxs = _save_idxs
381+
end
382+
end
383+
end
384+
385+
return save_idxs, saved_subsystem
386+
end

test/downstream/solution_interface.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,4 +327,35 @@ end
327327
parameter_values(ptc), rpidx.timeseries_idx => vals)
328328
@test newp[ridx] == 2prob.ps[r]
329329
end
330+
331+
@testset "get_save_idxs_and_saved_subsystem" begin
332+
@variables x(t) y(t)
333+
@parameters p q(t) r(t) s(t) u(t)
334+
evs = [0.1 => [q ~ q + 1, s ~ s - 1], 0.3 => [r ~ 2r, u ~ u / 2]]
335+
@mtkbuild sys = ODESystem([D(x) ~ x + p * y, D(y) ~ 2p + x^2], t, [x, y],
336+
[p, q, r, s, u], discrete_events = evs)
337+
prob = ODEProblem(sys, [x => 1.0, y => 1.0], (0.0, 5.0),
338+
[p => 0.5, q => 0.0, r => 1.0, s => 10.0, u => 4096.0])
339+
340+
_idxs, _ss = SciMLBase.get_save_idxs_and_saved_subsystem(prob, nothing)
341+
@test _idxs === _ss === nothing
342+
_idxs, _ss = SciMLBase.get_save_idxs_and_saved_subsystem(prob, 1)
343+
@test _idxs == 1
344+
@test _ss isa SciMLBase.SavedSubsystem
345+
_idxs, _ss = SciMLBase.get_save_idxs_and_saved_subsystem(prob, [1])
346+
@test _idxs == [1]
347+
@test _ss isa SciMLBase.SavedSubsystem
348+
_idxs, _ss = SciMLBase.get_save_idxs_and_saved_subsystem(prob, x)
349+
@test _idxs == 1
350+
@test _ss isa SciMLBase.SavedSubsystem
351+
_idxs, _ss = SciMLBase.get_save_idxs_and_saved_subsystem(prob, [x])
352+
@test _idxs == [1]
353+
@test _ss isa SciMLBase.SavedSubsystem
354+
_idxs, _ss = SciMLBase.get_save_idxs_and_saved_subsystem(prob, [x, q])
355+
@test _idxs == [1]
356+
@test _ss isa SciMLBase.SavedSubsystem
357+
_idxs, _ss = SciMLBase.get_save_idxs_and_saved_subsystem(prob, [q])
358+
@test _idxs == Int[]
359+
@test _ss isa SciMLBase.SavedSubsystem
360+
end
330361
end

0 commit comments

Comments
 (0)