Skip to content

Commit f7a21e8

Browse files
committed
updating build_function for integrators
1 parent a40879a commit f7a21e8

File tree

1 file changed

+50
-20
lines changed

1 file changed

+50
-20
lines changed

src/build_function.jl

Lines changed: 50 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ end
6060
function _build_function(target::JuliaTarget, op::Operation, args...;
6161
conv = simplified_expr, expression = Val{true},
6262
checkbounds = false, constructor=nothing,
63-
linenumbers = true)
63+
linenumbers = true, integrator_args=false)
6464

6565
argnames = [gensym(:MTKArg) for i in 1:length(args)]
6666
arg_pairs = map(vars_to_pairs,zip(argnames,args))
@@ -75,12 +75,23 @@ function _build_function(target::JuliaTarget, op::Operation, args...;
7575

7676
fargs = Expr(:tuple,argnames...)
7777

78-
oop_ex = :(
78+
integrator = gensym(:MTKIntegrator)
79+
(integrator_args && !(length(args) == 3)) && error("Too many extra arguments given to build an integrator-based function; expected 3, i.e. (u,p,t), but received $(length(args)).")
80+
if integrator_args
81+
oop_ex = :(
82+
$integrator -> begin
83+
($(fargs.args...),) = (($integrator).u,($integrator).p,($integrator).t)
84+
$bounds_block
85+
end
86+
)
87+
else
88+
oop_ex = :(
7989
($(fargs.args...),) -> begin
80-
$bounds_block
90+
$bounds_block
8191
end
82-
)
83-
92+
)
93+
end
94+
8495
if !linenumbers
8596
oop_ex = striplines(oop_ex)
8697
end
@@ -95,8 +106,7 @@ end
95106
function _build_function(target::JuliaTarget, rhss, args...;
96107
conv = simplified_expr, expression = Val{true},
97108
checkbounds = false, constructor=nothing,
98-
linenumbers = false, multithread=false)
99-
109+
linenumbers = false, multithread=false, integrator_args=false)
100110
argnames = [gensym(:MTKArg) for i in 1:length(args)]
101111
arg_pairs = map(vars_to_pairs,zip(argnames,args))
102112
ls = reduce(vcat,first.(arg_pairs))
@@ -165,25 +175,45 @@ function _build_function(target::JuliaTarget, rhss, args...;
165175
arr_bounds_block = checkbounds ? arr_let_expr : :(@inbounds begin $arr_let_expr end)
166176
ip_bounds_block = checkbounds ? ip_let_expr : :(@inbounds begin $ip_let_expr end)
167177

168-
oop_ex = :(
178+
oop_body_block = :(
179+
# If u is a weird non-StaticArray type and we want a sparse matrix, just do the optimized sparse anyways
180+
if $(fargs.args[1]) isa Array || (!(typeof($(fargs.args[1])) <: StaticArray) && $(rhss isa SparseMatrixCSC))
181+
return $arr_bounds_block
182+
else
183+
X = $bounds_block
184+
construct = $_constructor
185+
return construct(X)
186+
end
187+
)
188+
integrator = gensym(:MTKIntegrator)
189+
(integrator_args && !(length(args) == 3)) && error("Too many extra arguments given to build an integrator-based function; expected 3, i.e. (u,p,t), but received $(length(args)).")
190+
if integrator_args
191+
oop_ex = :(
192+
$integrator -> begin
193+
($(fargs.args...),) = (($integrator).u,($integrator).p,($integrator).t)
194+
$oop_body_block
195+
end
196+
)
197+
iip_ex = :(
198+
$integrator -> begin
199+
($X,$(fargs.args...),) = (($integrator).u,($integrator).u,($integrator).p,($integrator).t)
200+
$ip_bounds_block
201+
nothing
202+
end
203+
)
204+
else
205+
oop_ex = :(
169206
($(fargs.args...),) -> begin
170-
# If u is a weird non-StaticArray type and we want a sparse matrix, just do the optimized sparse anyways
171-
if $(fargs.args[1]) isa Array || (!(typeof($(fargs.args[1])) <: StaticArray) && $(rhss isa SparseMatrixCSC))
172-
return $arr_bounds_block
173-
else
174-
X = $bounds_block
175-
construct = $_constructor
176-
return construct(X)
177-
end
207+
$oop_body_block
178208
end
179-
)
180-
181-
iip_ex = :(
209+
)
210+
iip_ex = :(
182211
($X,$(fargs.args...)) -> begin
183212
$ip_bounds_block
184213
nothing
185214
end
186-
)
215+
)
216+
end
187217

188218
if !linenumbers
189219
oop_ex = striplines(oop_ex)

0 commit comments

Comments
 (0)