Skip to content

Commit 080af2c

Browse files
fix: propagate checkbounds information to observed function generation
1 parent 0648e5c commit 080af2c

File tree

7 files changed

+29
-15
lines changed

7 files changed

+29
-15
lines changed

src/systems/abstractsystem.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -785,7 +785,7 @@ end
785785
SymbolicIndexingInterface.supports_tuple_observed(::AbstractSystem) = true
786786

787787
function SymbolicIndexingInterface.observed(
788-
sys::AbstractSystem, sym; eval_expression = false, eval_module = @__MODULE__)
788+
sys::AbstractSystem, sym; eval_expression = false, eval_module = @__MODULE__, checkbounds = true)
789789
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
790790
if sym isa Symbol
791791
_sym = get(ic.symbol_to_variable, sym, nothing)
@@ -808,7 +808,8 @@ function SymbolicIndexingInterface.observed(
808808
end
809809
end
810810
end
811-
_fn = build_explicit_observed_function(sys, sym; eval_expression, eval_module)
811+
_fn = build_explicit_observed_function(
812+
sys, sym; eval_expression, eval_module, checkbounds)
812813

813814
if is_time_dependent(sys)
814815
return _fn
@@ -1671,11 +1672,14 @@ struct ObservedFunctionCache{S}
16711672
steady_state::Bool
16721673
eval_expression::Bool
16731674
eval_module::Module
1675+
checkbounds::Bool
16741676
end
16751677

16761678
function ObservedFunctionCache(
1677-
sys; steady_state = false, eval_expression = false, eval_module = @__MODULE__)
1678-
return ObservedFunctionCache(sys, Dict(), steady_state, eval_expression, eval_module)
1679+
sys; steady_state = false, eval_expression = false,
1680+
eval_module = @__MODULE__, checkbounds = true)
1681+
return ObservedFunctionCache(
1682+
sys, Dict(), steady_state, eval_expression, eval_module, checkbounds)
16791683
end
16801684

16811685
# This is hit because ensemble problems do a deepcopy
@@ -1685,7 +1689,9 @@ function Base.deepcopy_internal(ofc::ObservedFunctionCache, stackdict::IdDict)
16851689
steady_state = ofc.steady_state
16861690
eval_expression = ofc.eval_expression
16871691
eval_module = ofc.eval_module
1688-
newofc = ObservedFunctionCache(sys, dict, steady_state, eval_expression, eval_module)
1692+
checkbounds = ofc.checkbounds
1693+
newofc = ObservedFunctionCache(
1694+
sys, dict, steady_state, eval_expression, eval_module, checkbounds)
16891695
stackdict[ofc] = newofc
16901696
return newofc
16911697
end
@@ -1694,7 +1700,7 @@ function (ofc::ObservedFunctionCache)(obsvar, args...)
16941700
obs = get!(ofc.dict, value(obsvar)) do
16951701
SymbolicIndexingInterface.observed(
16961702
ofc.sys, obsvar; eval_expression = ofc.eval_expression,
1697-
eval_module = ofc.eval_module)
1703+
eval_module = ofc.eval_module, checkbounds = ofc.checkbounds)
16981704
end
16991705
if ofc.steady_state
17001706
obs = let fn = obs

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,8 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
436436
ArrayInterface.restructure(u0 .* u0', M)
437437
end
438438

439-
observedfun = ObservedFunctionCache(sys; steady_state, eval_expression, eval_module)
439+
observedfun = ObservedFunctionCache(
440+
sys; steady_state, eval_expression, eval_module, checkbounds)
440441

441442
jac_prototype = if sparse
442443
uElType = u0 === nothing ? Float64 : eltype(u0)
@@ -522,7 +523,8 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
522523
_jac = nothing
523524
end
524525

525-
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
526+
observedfun = ObservedFunctionCache(
527+
sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))
526528

527529
jac_prototype = if sparse
528530
uElType = u0 === nothing ? Float64 : eltype(u0)

src/systems/diffeqs/sdesystem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -613,7 +613,8 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
613613
M = calculate_massmatrix(sys)
614614
_M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0', M)
615615

616-
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
616+
observedfun = ObservedFunctionCache(
617+
sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))
617618

618619
SDEFunction{iip, specialize}(f, g;
619620
sys = sys,

src/systems/discrete_system/discrete_system.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,8 @@ function SciMLBase.DiscreteFunction{iip, specialize}(
353353
f = SciMLBase.wrapfun_iip(f, (u0, u0, p, t))
354354
end
355355

356-
observedfun = ObservedFunctionCache(sys)
356+
observedfun = ObservedFunctionCache(
357+
sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))
357358

358359
DiscreteFunction{iip, specialize}(f;
359360
sys = sys,

src/systems/jumps/jumpsystem.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,8 @@ function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple,
429429
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false, check_length = false)
430430
f = DiffEqBase.DISCRETE_INPLACE_DEFAULT
431431

432-
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
432+
observedfun = ObservedFunctionCache(
433+
sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))
433434

434435
df = DiscreteFunction{true, true}(f; sys = sys, observed = observedfun)
435436
DiscreteProblem(df, u0, tspan, p; kwargs...)
@@ -527,7 +528,8 @@ function DiffEqBase.ODEProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothi
527528
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false,
528529
check_length = false)
529530
f = (du, u, p, t) -> (du .= 0; nothing)
530-
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
531+
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module,
532+
checkbounds = get(kwargs, :checkbounds, false))
531533
df = ODEFunction(f; sys, observed = observedfun)
532534
return ODEProblem(df, u0, tspan, p; kwargs...)
533535
end

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,8 @@ function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(s
369369
_jac = nothing
370370
end
371371

372-
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
372+
observedfun = ObservedFunctionCache(
373+
sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))
373374

374375
if length(dvs) == length(equations(sys))
375376
resid_prototype = nothing
@@ -411,7 +412,8 @@ function SciMLBase.IntervalNonlinearFunction(
411412
f(u, p) = f_oop(u, p)
412413
f(u, p::MTKParameters) = f_oop(u, p...)
413414

414-
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
415+
observedfun = ObservedFunctionCache(
416+
sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))
415417

416418
IntervalNonlinearFunction{false}(
417419
f; observed = observedfun, sys = sys, initialization_data)

src/systems/optimization/optimizationsystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
420420
hess_prototype = nothing
421421
end
422422

423-
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
423+
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds)
424424

425425
if length(cstr) > 0
426426
@named cons_sys = ConstraintsSystem(cstr, dvs, ps; checks)

0 commit comments

Comments
 (0)