Skip to content

Commit 31edc8f

Browse files
feat: better inbounds handling and propagation for ODEProblem and observed functions
1 parent 31f7a54 commit 31edc8f

File tree

3 files changed

+29
-8
lines changed

3 files changed

+29
-8
lines changed

src/systems/abstractsystem.jl

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,13 @@ function wrap_assignments(isscalar, assignments; let_block = false)
226226
end
227227
end
228228

229+
function wrap_inbounds(isscalar)
230+
function wrapper(expr)
231+
Func(expr.args, [], :(@inbounds begin; $(toexpr(expr.body)); end))
232+
end
233+
return isscalar ? wrapper : (wrapper, wrapper)
234+
end
235+
229236
function wrap_parameter_dependencies(sys::AbstractSystem, isscalar)
230237
wrap_assignments(isscalar, [eq.lhs eq.rhs for eq in parameter_dependencies(sys)])
231238
end
@@ -785,7 +792,7 @@ end
785792
SymbolicIndexingInterface.supports_tuple_observed(::AbstractSystem) = true
786793

787794
function SymbolicIndexingInterface.observed(
788-
sys::AbstractSystem, sym; eval_expression = false, eval_module = @__MODULE__)
795+
sys::AbstractSystem, sym; eval_expression = false, eval_module = @__MODULE__, checkbounds = true)
789796
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
790797
if sym isa Symbol
791798
_sym = get(ic.symbol_to_variable, sym, nothing)
@@ -808,7 +815,7 @@ function SymbolicIndexingInterface.observed(
808815
end
809816
end
810817
end
811-
_fn = build_explicit_observed_function(sys, sym; eval_expression, eval_module)
818+
_fn = build_explicit_observed_function(sys, sym; eval_expression, eval_module, checkbounds)
812819

813820
if is_time_dependent(sys)
814821
return _fn
@@ -1671,11 +1678,12 @@ struct ObservedFunctionCache{S}
16711678
steady_state::Bool
16721679
eval_expression::Bool
16731680
eval_module::Module
1681+
checkbounds::Bool
16741682
end
16751683

16761684
function ObservedFunctionCache(
1677-
sys; steady_state = false, eval_expression = false, eval_module = @__MODULE__)
1678-
return ObservedFunctionCache(sys, Dict(), steady_state, eval_expression, eval_module)
1685+
sys; steady_state = false, eval_expression = false, eval_module = @__MODULE__, checkbounds = true)
1686+
return ObservedFunctionCache(sys, Dict(), steady_state, eval_expression, eval_module, checkbounds)
16791687
end
16801688

16811689
# This is hit because ensemble problems do a deepcopy
@@ -1694,7 +1702,7 @@ function (ofc::ObservedFunctionCache)(obsvar, args...)
16941702
obs = get!(ofc.dict, value(obsvar)) do
16951703
SymbolicIndexingInterface.observed(
16961704
ofc.sys, obsvar; eval_expression = ofc.eval_expression,
1697-
eval_module = ofc.eval_module)
1705+
eval_module = ofc.eval_module, checkbounds = ofc.checkbounds)
16981706
end
16991707
if ofc.steady_state
17001708
obs = let fn = obs

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,9 @@ function generate_jacobian(sys::AbstractODESystem, dvs = unknowns(sys),
137137
else
138138
(ps,)
139139
end
140+
if !get(kwargs, :checkbounds, false)
141+
wrap_code = wrap_code .∘ wrap_inbounds(false)
142+
end
140143
wrap_code = wrap_code .∘ wrap_array_vars(sys, jac; dvs, ps) .∘
141144
wrap_parameter_dependencies(sys, false)
142145
return build_function(jac,
@@ -208,6 +211,10 @@ function generate_function(sys::AbstractODESystem, dvs = unknowns(sys),
208211
p = map.(x -> time_varying_as_func(value(x), sys), reorder_parameters(sys, ps))
209212
t = get_iv(sys)
210213

214+
if !get(kwargs, :checkbounds, false)
215+
wrap_code = wrap_code .∘ wrap_inbounds(false)
216+
end
217+
211218
if isdde
212219
build_function(rhss, u, DDE_HISTORY_FUN, p..., t; kwargs...,
213220
wrap_code = wrap_code .∘ wrap_mtkparameters(sys, false, 3) .∘
@@ -439,7 +446,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
439446
ArrayInterface.restructure(u0 .* u0', M)
440447
end
441448

442-
observedfun = ObservedFunctionCache(sys; steady_state, eval_expression, eval_module)
449+
observedfun = ObservedFunctionCache(sys; steady_state, eval_expression, eval_module, checkbounds)
443450

444451
jac_prototype = if sparse
445452
uElType = u0 === nothing ? Float64 : eltype(u0)

src/systems/diffeqs/odesystem.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,12 @@ function build_explicit_observed_function(sys, ts;
629629
oop_mtkp_wrapper = mtkparams_wrapper
630630
end
631631

632+
if !checkbounds
633+
inbounds_wrapper = wrap_inbounds(false)
634+
else
635+
inbounds_wrapper = (identity, identity)
636+
end
637+
632638
# Need to keep old method of building the function since it uses `output_type`,
633639
# which can't be provided to `build_function`
634640
return_value = if isscalar
@@ -641,14 +647,14 @@ function build_explicit_observed_function(sys, ts;
641647
oop_fn = Func(args, [],
642648
pre(Let(obsexprs,
643649
return_value,
644-
false))) |> array_wrapper[1] |> oop_mtkp_wrapper |> toexpr
650+
false))) |> array_wrapper[1] |> oop_mtkp_wrapper |> inbounds_wrapper[1] |> toexpr
645651
oop_fn = expression ? oop_fn : eval_or_rgf(oop_fn; eval_expression, eval_module)
646652

647653
if !isscalar
648654
iip_fn = build_function(ts,
649655
args...;
650656
postprocess_fbody = pre,
651-
wrap_code = mtkparams_wrapper .∘ array_wrapper .∘
657+
wrap_code = inbounds_wrapper .∘ mtkparams_wrapper .∘ array_wrapper .∘
652658
wrap_assignments(isscalar, obsexprs),
653659
expression = Val{true})[2]
654660
if !expression

0 commit comments

Comments
 (0)