Skip to content

Commit 976c22e

Browse files
committed
switch to header function
1 parent 595a530 commit 976c22e

File tree

2 files changed

+48
-49
lines changed

2 files changed

+48
-49
lines changed

src/build_function.jl

Lines changed: 45 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,50 @@ function build_function(args...;target = JuliaTarget(),kwargs...)
5656
_build_function(target,args...;kwargs...)
5757
end
5858

59+
function addheader(ex, fargs, iip; X=gensym(:MTIIPVar))
60+
if iip
61+
wrappedex = :(
62+
($X,$(fargs.args...)) -> begin
63+
$ex
64+
nothing
65+
end
66+
)
67+
else
68+
wrappedex = :(
69+
($(fargs.args...),) -> begin
70+
$ex
71+
end
72+
)
73+
end
74+
wrappedex
75+
end
76+
77+
function add_integrator_header(ex, fargs, iip; X=gensym(:MTIIPVar))
78+
integrator = gensym(:MTKIntegrator)
79+
if iip
80+
wrappedex = :(
81+
$integrator -> begin
82+
($X,$(fargs.args...)) = (($integrator).u,($integrator).u,($integrator).p,($integrator).t)
83+
$ex
84+
nothing
85+
end
86+
)
87+
else
88+
wrappedex = :(
89+
$integrator -> begin
90+
($(fargs.args...),) = (($integrator).u,($integrator).p,($integrator).t)
91+
$ex
92+
end
93+
)
94+
end
95+
wrappedex
96+
end
97+
5998
# Scalar output
6099
function _build_function(target::JuliaTarget, op::Operation, args...;
61100
conv = simplified_expr, expression = Val{true},
62101
checkbounds = false, constructor=nothing,
63-
linenumbers = true, integrator_args=false)
102+
linenumbers = true, headerfun=addheader)
64103

65104
argnames = [gensym(:MTKArg) for i in 1:length(args)]
66105
arg_pairs = map(vars_to_pairs,zip(argnames,args))
@@ -74,23 +113,7 @@ function _build_function(target::JuliaTarget, op::Operation, args...;
74113
bounds_block = checkbounds ? let_expr : :(@inbounds begin $let_expr end)
75114

76115
fargs = Expr(:tuple,argnames...)
77-
78-
integrator = gensym(:MTKIntegrator)
79-
(integrator_args && !(length(args) == 3)) && error("Too many extra arguments given to build an integrator-based function; expected 3, i.e. (u,p,t), but received $(length(args)).")
80-
if integrator_args
81-
oop_ex = :(
82-
$integrator -> begin
83-
($(fargs.args...),) = (($integrator).u,($integrator).p,($integrator).t)
84-
$bounds_block
85-
end
86-
)
87-
else
88-
oop_ex = :(
89-
($(fargs.args...),) -> begin
90-
$bounds_block
91-
end
92-
)
93-
end
116+
oop_ex = headerfun(bounds_block, fargs, false)
94117

95118
if !linenumbers
96119
oop_ex = striplines(oop_ex)
@@ -106,7 +129,7 @@ end
106129
function _build_function(target::JuliaTarget, rhss, args...;
107130
conv = simplified_expr, expression = Val{true},
108131
checkbounds = false, constructor=nothing,
109-
linenumbers = false, multithread=false, integrator_args=false)
132+
linenumbers = false, multithread=false, headerfun=addheader)
110133
argnames = [gensym(:MTKArg) for i in 1:length(args)]
111134
arg_pairs = map(vars_to_pairs,zip(argnames,args))
112135
ls = reduce(vcat,first.(arg_pairs))
@@ -185,36 +208,10 @@ function _build_function(target::JuliaTarget, rhss, args...;
185208
return construct(X)
186209
end
187210
)
188-
integrator = gensym(:MTKIntegrator)
189-
(integrator_args && !(length(args) == 3)) && error("Too many extra arguments given to build an integrator-based function; expected 3, i.e. (u,p,t), but received $(length(args)).")
190-
if integrator_args
191-
oop_ex = :(
192-
$integrator -> begin
193-
($(fargs.args...),) = (($integrator).u,($integrator).p,($integrator).t)
194-
$oop_body_block
195-
end
196-
)
197-
iip_ex = :(
198-
$integrator -> begin
199-
($X,$(fargs.args...),) = (($integrator).u,($integrator).u,($integrator).p,($integrator).t)
200-
$ip_bounds_block
201-
nothing
202-
end
203-
)
204-
else
205-
oop_ex = :(
206-
($(fargs.args...),) -> begin
207-
$oop_body_block
208-
end
209-
)
210-
iip_ex = :(
211-
($X,$(fargs.args...)) -> begin
212-
$ip_bounds_block
213-
nothing
214-
end
215-
)
216-
end
217211

212+
oop_ex = headerfun(oop_body_block, fargs, false)
213+
iip_ex = headerfun(ip_bounds_block, fargs, true; X=X)
214+
218215
if !linenumbers
219216
oop_ex = striplines(oop_ex)
220217
iip_ex = striplines(iip_ex)

src/systems/jumps/jumpsystem.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ function JumpSystem(eqs, iv, states, ps; systems = JumpSystem[],
1414
JumpSystem(eqs, iv, convert.(Variable, states), convert.(Variable, ps), name, systems)
1515
end
1616

17+
18+
1719
generate_rate_function(js, rate) = build_function(rate, states(js), parameters(js),
1820
independent_variable(js),
1921
expression=Val{false})
@@ -22,7 +24,7 @@ generate_affect_function(js, affect) = build_function(affect, states(js),
2224
parameters(js),
2325
independent_variable(js),
2426
expression=Val{false},
25-
integrator_args=true)[2]
27+
headerfun=add_integrator_header)[2]
2628
function assemble_vrj(js, vrj)
2729
rate = generate_rate_function(js, vrj.rate)
2830
affect = generate_affect_function(js, vrj.affect!)

0 commit comments

Comments
 (0)