Skip to content

Commit 39e7e2b

Browse files
Merge pull request #953 from matthieubulte/master
Update usage of DestructuredArgs to use inbounds keyword
2 parents 7d698f1 + 5626867 commit 39e7e2b

File tree

4 files changed

+24
-19
lines changed

4 files changed

+24
-19
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ SciMLBase = "1.3"
6565
Setfield = "0.7"
6666
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0"
6767
StaticArrays = "0.10, 0.11, 0.12, 1.0"
68-
SymbolicUtils = "0.10"
68+
SymbolicUtils = "0.10.1"
6969
Symbolics = "0.1.14"
7070
UnPack = "0.1, 1.0"
7171
Unitful = "1.1"

src/structural_transformation/codegen.jl

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ function partitions_dag(s::SystemStructure)
123123
sparse(I, J, true, n, n)
124124
end
125125

126-
function gen_nlsolve(sys, eqs, vars)
126+
function gen_nlsolve(sys, eqs, vars; checkbounds=true)
127127
@assert !isempty(vars)
128128
@assert length(eqs) == length(vars)
129129
rhss = map(x->x.rhs, eqs)
@@ -141,8 +141,8 @@ function gen_nlsolve(sys, eqs, vars)
141141
fname = gensym("fun")
142142
f = Func(
143143
[
144-
DestructuredArgs(vars)
145-
DestructuredArgs(params)
144+
DestructuredArgs(vars, inbounds=!checkbounds)
145+
DestructuredArgs(params, inbounds=!checkbounds)
146146
],
147147
[],
148148
isscalar ? rhss[1] : MakeArray(rhss, SVector)
@@ -160,7 +160,7 @@ function gen_nlsolve(sys, eqs, vars)
160160

161161
[
162162
fname @RuntimeGeneratedFunction(f)
163-
DestructuredArgs(vars) solver_call
163+
DestructuredArgs(vars, inbounds=!checkbounds) solver_call
164164
]
165165
end
166166

@@ -204,8 +204,8 @@ function build_torn_function(
204204
Func(
205205
[
206206
out
207-
DestructuredArgs(states)
208-
DestructuredArgs(parameters(sys))
207+
DestructuredArgs(states, inbounds=!checkbounds)
208+
DestructuredArgs(parameters(sys), inbounds=!checkbounds)
209209
independent_variable(sys)
210210
],
211211
[],
@@ -221,7 +221,7 @@ function build_torn_function(
221221
observedfun = let sys = sys, dict = Dict()
222222
function generated_observed(obsvar, u, p, t)
223223
obs = get!(dict, value(obsvar)) do
224-
build_observed_function(sys, obsvar)
224+
build_observed_function(sys, obsvar, checkbounds=checkbounds)
225225
end
226226
obs(u, p, t)
227227
end
@@ -256,7 +256,8 @@ end
256256
function build_observed_function(
257257
sys, syms;
258258
expression=false,
259-
output_type=Array
259+
output_type=Array,
260+
checkbounds=true
260261
)
261262

262263
if (isscalar = !(syms isa Vector))
@@ -292,7 +293,7 @@ function build_observed_function(
292293
torn_eqs = map(idxs-> eqs[idxs[3]], subset)
293294
torn_vars = map(idxs->fullvars[idxs[4]], subset)
294295

295-
solves = gen_nlsolve.((sys,), torn_eqs, torn_vars)
296+
solves = gen_nlsolve.((sys,), torn_eqs, torn_vars; checkbounds=checkbounds)
296297
else
297298
solves = []
298299
end
@@ -307,8 +308,8 @@ function build_observed_function(
307308

308309
ex = Func(
309310
[
310-
DestructuredArgs(diffvars)
311-
DestructuredArgs(parameters(sys))
311+
DestructuredArgs(diffvars, inbounds=!checkbounds)
312+
DestructuredArgs(parameters(sys), inbounds=!checkbounds)
312313
independent_variable(sys)
313314
],
314315
[],

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,17 +158,19 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
158158
eval_expression = true,
159159
sparse = false, simplify=false,
160160
eval_module = @__MODULE__,
161+
checkbounds=false,
161162
kwargs...) where {iip}
162163

163-
f_gen = generate_function(sys, dvs, ps; expression=Val{eval_expression}, expression_module=eval_module, kwargs...)
164+
f_gen = generate_function(sys, dvs, ps; expression=Val{eval_expression}, expression_module=eval_module, checkbounds=checkbounds, kwargs...)
164165
f_oop,f_iip = eval_expression ? (@RuntimeGeneratedFunction(eval_module, ex) for ex in f_gen) : f_gen
165166
f(u,p,t) = f_oop(u,p,t)
166167
f(du,u,p,t) = f_iip(du,u,p,t)
167168

168169
if tgrad
169170
tgrad_gen = generate_tgrad(sys, dvs, ps;
170171
simplify=simplify,
171-
expression=Val{eval_expression}, expression_module=eval_module, kwargs...)
172+
expression=Val{eval_expression}, expression_module=eval_module,
173+
checkbounds=checkbounds, kwargs...)
172174
tgrad_oop,tgrad_iip = eval_expression ? (@RuntimeGeneratedFunction(eval_module, ex) for ex in tgrad_gen) : tgrad_gen
173175
_tgrad(u,p,t) = tgrad_oop(u,p,t)
174176
_tgrad(J,u,p,t) = tgrad_iip(J,u,p,t)
@@ -179,7 +181,8 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
179181
if jac
180182
jac_gen = generate_jacobian(sys, dvs, ps;
181183
simplify=simplify, sparse = sparse,
182-
expression=Val{eval_expression}, expression_module=eval_module, kwargs...)
184+
expression=Val{eval_expression}, expression_module=eval_module,
185+
checkbounds=checkbounds, kwargs...)
183186
jac_oop,jac_iip = eval_expression ? (@RuntimeGeneratedFunction(eval_module, ex) for ex in jac_gen) : jac_gen
184187
_jac(u,p,t) = jac_oop(u,p,t)
185188
_jac(J,u,p,t) = jac_iip(J,u,p,t)
@@ -194,7 +197,7 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
194197
observedfun = let sys = sys, dict = Dict()
195198
function generated_observed(obsvar, u, p, t)
196199
obs = get!(dict, value(obsvar)) do
197-
build_explicit_observed_function(sys, obsvar)
200+
build_explicit_observed_function(sys, obsvar; checkbounds=checkbounds)
198201
end
199202
obs(u, p, t)
200203
end

src/systems/diffeqs/odesystem.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,8 @@ i.e. there are no cycles.
234234
function build_explicit_observed_function(
235235
sys, syms;
236236
expression=false,
237-
output_type=Array)
237+
output_type=Array,
238+
checkbounds=true)
238239

239240
if (isscalar = !(syms isa Vector))
240241
syms = [syms]
@@ -254,8 +255,8 @@ function build_explicit_observed_function(
254255

255256
ex = Func(
256257
[
257-
DestructuredArgs(states(sys))
258-
DestructuredArgs(parameters(sys))
258+
DestructuredArgs(states(sys), inbounds=!checkbounds)
259+
DestructuredArgs(parameters(sys), inbounds=!checkbounds)
259260
independent_variable(sys)
260261
],
261262
[],

0 commit comments

Comments
 (0)