Skip to content

Commit fb2bc23

Browse files
fix: improve inference of get_save_idxs_and_saved_subsystem
1 parent 73cb4d0 commit fb2bc23

File tree

3 files changed

+19
-25
lines changed

3 files changed

+19
-25
lines changed

src/scimlfunctions.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4818,8 +4818,10 @@ for S in [:ODEFunction
48184818
end
48194819
end
48204820

4821+
const EMPTY_SYMBOLCACHE = SymbolCache()
4822+
48214823
function SymbolicIndexingInterface.symbolic_container(fn::AbstractSciMLFunction)
4822-
has_sys(fn) ? fn.sys : SymbolCache()
4824+
has_sys(fn) ? fn.sys : EMPTY_SYMBOLCACHE
48234825
end
48244826

48254827
function SymbolicIndexingInterface.is_observed(fn::AbstractSciMLFunction, sym)

src/solutions/save_idxs.jl

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,11 @@ function as_diffeq_array(vt::Vector{VectorTemplate}, t)
4444
return DiffEqArray(typeof(TupleOfArraysWrapper(vt))[], t, (1, 1))
4545
end
4646

47-
function is_empty_indp(indp)
48-
isempty(variable_symbols(indp)) && isempty(parameter_symbols(indp)) &&
49-
isempty(independent_variable_symbols(indp))
47+
function get_root_indp(indp)
48+
if hasmethod(symbolic_container, Tuple{typeof(indp)}) && (sc = symbolic_container(indp)) !== indp
49+
return get_root_indp(sc)
50+
end
51+
return indp
5052
end
5153

5254
# Everything from this point on is public API
@@ -107,28 +109,19 @@ end
107109

108110
SavedSubsystem(indp, pobj, ::Nothing) = nothing
109111

110-
function SavedSubsystem(indp, pobj, saved_idxs::Vector{Int})
111-
isempty(saved_idxs) && return nothing
112-
isempty(variable_symbols(indp)) && return nothing
113-
state_map = Dict{Int, Int}(k => v for (k, v) in enumerate(saved_idxs))
114-
return SavedSubsystem(state_map, nothing, nothing, nothing, nothing, nothing, nothing)
115-
end
116-
117112
function SavedSubsystem(indp, pobj, idx::Int)
118113
state_map = Dict(1 => idx)
119114
return SavedSubsystem(state_map, nothing, nothing, nothing, nothing, nothing, nothing)
120115
end
121116

122-
function SavedSubsystem(indp, pobj, saved_idxs)
123-
# nothing saved
124-
if saved_idxs === nothing || isempty(saved_idxs)
117+
function SavedSubsystem(indp, pobj, saved_idxs::Union{Array, Tuple})
118+
_indp = get_root_indp(indp)
119+
if indp === EMPTY_SYMBOLCACHE || indp === nothing
125120
return nothing
126121
end
127-
128-
# this is required because problems with no system have an empty `SymbolCache`
129-
# as their symbolic container.
130-
if is_empty_indp(indp)
131-
return nothing
122+
if eltype(saved_idxs) == Int
123+
state_map = Dict{Int, Int}(k => v for (k, v) in enumerate(saved_idxs))
124+
return SavedSubsystem(state_map, nothing, nothing, nothing, nothing, nothing, nothing)
132125
end
133126

134127
# array state symbolics must be scalarized
@@ -380,7 +373,7 @@ function get_save_idxs_and_saved_subsystem(prob, save_idx::Int)
380373
end
381374
function get_save_idxs_and_saved_subsystem(prob, save_idxs)
382375
if !(save_idxs isa AbstractArray) || symbolic_type(save_idxs) != NotSymbolic()
383-
_save_idxs = [save_idxs]
376+
_save_idxs = (save_idxs,)
384377
else
385378
_save_idxs = save_idxs
386379
end

test/downstream/solution_interface.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,7 @@ end
182182
xidx = variable_index(sys, x)
183183
prob = ODEProblem(sys, [x => 1.0, y => 1.0], (0.0, 5.0), [p => 0.5])
184184

185-
@test SciMLBase.SavedSubsystem(sys, prob.p, []) ===
186-
SciMLBase.SavedSubsystem(sys, prob.p, nothing) === nothing
185+
@test SciMLBase.SavedSubsystem(sys, prob.p, nothing) === nothing
187186
@test SciMLBase.SavedSubsystem(sys, prob.p, [x, y]) === nothing
188187
@test begin
189188
ss1 = SciMLBase.SavedSubsystem(sys, prob.p, [x])
@@ -319,12 +318,12 @@ end
319318
prob = ODEProblem(sys, [x => 1.0, y => 1.0], (0.0, 5.0),
320319
[p => 0.5, q => 0.0, r => 1.0, s => 10.0, u => 4096.0])
321320

322-
_idxs, _ss = SciMLBase.get_save_idxs_and_saved_subsystem(prob, nothing)
321+
_idxs, _ss = @inferred SciMLBase.get_save_idxs_and_saved_subsystem(prob, nothing)
323322
@test _idxs === _ss === nothing
324-
_idxs, _ss = SciMLBase.get_save_idxs_and_saved_subsystem(prob, 1)
323+
_idxs, _ss = @inferred SciMLBase.get_save_idxs_and_saved_subsystem(prob, 1)
325324
@test _idxs == 1
326325
@test _ss isa SciMLBase.SavedSubsystem
327-
_idxs, _ss = SciMLBase.get_save_idxs_and_saved_subsystem(prob, [1])
326+
_idxs, _ss = @inferred SciMLBase.get_save_idxs_and_saved_subsystem(prob, [1])
328327
@test _idxs == [1]
329328
@test _ss isa SciMLBase.SavedSubsystem
330329
_idxs, _ss = SciMLBase.get_save_idxs_and_saved_subsystem(prob, x)

0 commit comments

Comments
 (0)