Skip to content

Commit 8c55fe0

Browse files
fix: propagate checkbounds information to observed function generation
1 parent 0648e5c commit 8c55fe0

File tree

7 files changed

+15
-14
lines changed

7 files changed

+15
-14
lines changed

src/systems/abstractsystem.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -785,7 +785,7 @@ end
785785
SymbolicIndexingInterface.supports_tuple_observed(::AbstractSystem) = true
786786

787787
function SymbolicIndexingInterface.observed(
788-
sys::AbstractSystem, sym; eval_expression = false, eval_module = @__MODULE__)
788+
sys::AbstractSystem, sym; eval_expression = false, eval_module = @__MODULE__, checkbounds = true)
789789
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
790790
if sym isa Symbol
791791
_sym = get(ic.symbol_to_variable, sym, nothing)
@@ -808,7 +808,7 @@ function SymbolicIndexingInterface.observed(
808808
end
809809
end
810810
end
811-
_fn = build_explicit_observed_function(sys, sym; eval_expression, eval_module)
811+
_fn = build_explicit_observed_function(sys, sym; eval_expression, eval_module, checkbounds)
812812

813813
if is_time_dependent(sys)
814814
return _fn
@@ -1671,11 +1671,12 @@ struct ObservedFunctionCache{S}
16711671
steady_state::Bool
16721672
eval_expression::Bool
16731673
eval_module::Module
1674+
checkbounds::Bool
16741675
end
16751676

16761677
function ObservedFunctionCache(
1677-
sys; steady_state = false, eval_expression = false, eval_module = @__MODULE__)
1678-
return ObservedFunctionCache(sys, Dict(), steady_state, eval_expression, eval_module)
1678+
sys; steady_state = false, eval_expression = false, eval_module = @__MODULE__, checkbounds = true)
1679+
return ObservedFunctionCache(sys, Dict(), steady_state, eval_expression, eval_module, checkbounds)
16791680
end
16801681

16811682
# This is hit because ensemble problems do a deepcopy
@@ -1694,7 +1695,7 @@ function (ofc::ObservedFunctionCache)(obsvar, args...)
16941695
obs = get!(ofc.dict, value(obsvar)) do
16951696
SymbolicIndexingInterface.observed(
16961697
ofc.sys, obsvar; eval_expression = ofc.eval_expression,
1697-
eval_module = ofc.eval_module)
1698+
eval_module = ofc.eval_module, checkbounds = ofc.checkbounds)
16981699
end
16991700
if ofc.steady_state
17001701
obs = let fn = obs

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
436436
ArrayInterface.restructure(u0 .* u0', M)
437437
end
438438

439-
observedfun = ObservedFunctionCache(sys; steady_state, eval_expression, eval_module)
439+
observedfun = ObservedFunctionCache(sys; steady_state, eval_expression, eval_module, checkbounds)
440440

441441
jac_prototype = if sparse
442442
uElType = u0 === nothing ? Float64 : eltype(u0)
@@ -522,7 +522,7 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
522522
_jac = nothing
523523
end
524524

525-
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
525+
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))
526526

527527
jac_prototype = if sparse
528528
uElType = u0 === nothing ? Float64 : eltype(u0)

src/systems/diffeqs/sdesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -613,7 +613,7 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
613613
M = calculate_massmatrix(sys)
614614
_M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0', M)
615615

616-
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
616+
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))
617617

618618
SDEFunction{iip, specialize}(f, g;
619619
sys = sys,

src/systems/discrete_system/discrete_system.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ function SciMLBase.DiscreteFunction{iip, specialize}(
353353
f = SciMLBase.wrapfun_iip(f, (u0, u0, p, t))
354354
end
355355

356-
observedfun = ObservedFunctionCache(sys)
356+
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))
357357

358358
DiscreteFunction{iip, specialize}(f;
359359
sys = sys,

src/systems/jumps/jumpsystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple,
429429
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false, check_length = false)
430430
f = DiffEqBase.DISCRETE_INPLACE_DEFAULT
431431

432-
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
432+
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))
433433

434434
df = DiscreteFunction{true, true}(f; sys = sys, observed = observedfun)
435435
DiscreteProblem(df, u0, tspan, p; kwargs...)
@@ -527,7 +527,7 @@ function DiffEqBase.ODEProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothi
527527
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false,
528528
check_length = false)
529529
f = (du, u, p, t) -> (du .= 0; nothing)
530-
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
530+
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))
531531
df = ODEFunction(f; sys, observed = observedfun)
532532
return ODEProblem(df, u0, tspan, p; kwargs...)
533533
end

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(s
369369
_jac = nothing
370370
end
371371

372-
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
372+
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))
373373

374374
if length(dvs) == length(equations(sys))
375375
resid_prototype = nothing
@@ -411,7 +411,7 @@ function SciMLBase.IntervalNonlinearFunction(
411411
f(u, p) = f_oop(u, p)
412412
f(u, p::MTKParameters) = f_oop(u, p...)
413413

414-
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
414+
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))
415415

416416
IntervalNonlinearFunction{false}(
417417
f; observed = observedfun, sys = sys, initialization_data)

src/systems/optimization/optimizationsystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
420420
hess_prototype = nothing
421421
end
422422

423-
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
423+
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds)
424424

425425
if length(cstr) > 0
426426
@named cons_sys = ConstraintsSystem(cstr, dvs, ps; checks)

0 commit comments

Comments
 (0)