Skip to content

Commit 35af497

Browse files
Merge pull request #595 from SciML/runtime_generated
replace GenearlizedGenerated with RuntimeGeneratedFunctions
2 parents ae8718b + 381a4d4 commit 35af497

File tree

7 files changed

+39
-51
lines changed

7 files changed

+39
-51
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ DiffEqJump = "c894b116-72e5-5b58-be3c-e6d8d4ac2b12"
1111
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
1212
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
1313
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
14-
GeneralizedGenerated = "6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb"
1514
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
1615
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
1716
Latexify = "23fbe1c1-3f47-55db-b15f-69d7ec21a316"
@@ -22,6 +21,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
2221
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
2322
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
2423
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
24+
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
2525
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
2626
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2727
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
@@ -38,7 +38,6 @@ DiffEqBase = "6.48"
3838
DiffEqJump = "6.7.5"
3939
DiffRules = "0.1, 1.0"
4040
DocStringExtensions = "0.7, 0.8"
41-
GeneralizedGenerated = "0.1.4, 0.2"
4241
IfElse = "0.1"
4342
LabelledArrays = "1.3"
4443
Latexify = "0.11, 0.12, 0.13, 0.14"
@@ -47,6 +46,7 @@ MacroTools = "0.5"
4746
NaNMath = "0.3"
4847
RecursiveArrayTools = "2.3"
4948
Requires = "1.0"
49+
RuntimeGeneratedFunctions = "0.4"
5050
SafeTestsets = "0.0.1"
5151
SpecialFunctions = "0.7, 0.8, 0.9, 0.10"
5252
StaticArrays = "0.10, 0.11, 0.12"

src/ModelingToolkit.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,16 @@ using UnPack: @unpack
88
using DiffEqJump
99
using DataStructures: OrderedDict, OrderedSet
1010
using SpecialFunctions, NaNMath
11+
using RuntimeGeneratedFunctions
1112
using Base.Threads
1213
import MacroTools: splitdef, combinedef, postwalk, striplines
13-
import GeneralizedGenerated
1414
import Libdl
1515
using DocStringExtensions
1616
using Base: RefValue
1717
import IfElse
1818

19+
RuntimeGeneratedFunctions.init(@__MODULE__)
20+
1921
using RecursiveArrayTools
2022

2123
import SymbolicUtils

src/build_function.jl

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -130,19 +130,7 @@ function _build_function(target::JuliaTarget, op::Operation, args...;
130130
end
131131

132132
function _build_and_inject_function(mod::Module, ex)
133-
# Generate the function, which will process the expression
134-
runtimefn = GeneralizedGenerated.mk_function(mod, ex)
135-
136-
# Extract the processed expression of the function body
137-
params = typeof(runtimefn).parameters
138-
fn_expr = GeneralizedGenerated.NGG.from_type(params[3])
139-
140-
# Inject our externally registered module functions
141-
new_expr = ModelingToolkit.inject_registered_module_functions(fn_expr)
142-
143-
# Reconstruct the RuntimeFn's Body
144-
new_body = GeneralizedGenerated.NGG.to_type(new_expr)
145-
return GeneralizedGenerated.RuntimeFn{params[1:2]..., new_body, params[4]}()
133+
@RuntimeGeneratedFunction(ModelingToolkit.inject_registered_module_functions(ex))
146134
end
147135

148136
# Detect heterogeneous element types of "arrays of matrices/sparce matrices"
@@ -186,7 +174,7 @@ function _build_function(target::JuliaTarget, rhss, args...;
186174
187175
Generates a Julia function which can then be utilized for further evaluations.
188176
If expression=Val{false}, the return is a Julia function which utilizes
189-
GeneralizedGenerated.jl in order to be free of world-age issues.
177+
RuntimeGeneratedFunctions.jl in order to be free of world-age issues.
190178
191179
If the `Expression` is an `Operation`, the generated function is a function
192180
with a scalar output, otherwise if it's an `AbstractArray{Operation}`, the output
@@ -550,7 +538,7 @@ function _build_function(target::CTarget, eqs::Array{<:Equation}, args...;
550538
open(`gcc -fPIC -O3 -msse3 -xc -shared -o $(libpath * "." * Libdl.dlext) -`, "w") do f
551539
print(f, ex)
552540
end
553-
eval(:((du::Array{Float64},u::Array{Float64},p::Array{Float64},t::Float64) -> ccall(("diffeqf", $libpath), Cvoid, (Ptr{Float64}, Ptr{Float64}, Ptr{Float64}, Float64), du, u, p, t)))
541+
@RuntimeGeneratedFunction(:((du::Array{Float64},u::Array{Float64},p::Array{Float64},t::Float64) -> ccall(("diffeqf", $libpath), Cvoid, (Ptr{Float64}, Ptr{Float64}, Ptr{Float64}, Float64), du, u, p, t)))
554542
end
555543
end
556544

src/systems/control/controlsystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,8 @@ function runge_kutta_discretize(sys::ControlSystem,dt,tspan;
131131
n = length(tspan[1]:dt:tspan[2]) - 1
132132
m = length(tab.α)
133133

134-
f = eval(build_function([x.rhs for x in equations(sys)],sys.states,sys.controls,sys.ps,sys.iv,conv = ModelingToolkit.ControlToExpr(sys))[1])
135-
L = eval(build_function(sys.loss,sys.states,sys.controls,sys.ps,sys.iv,conv = ModelingToolkit.ControlToExpr(sys)))
134+
f = @RuntimeGeneratedFunction(build_function([x.rhs for x in equations(sys)],sys.states,sys.controls,sys.ps,sys.iv,conv = ModelingToolkit.ControlToExpr(sys))[1])
135+
L = @RuntimeGeneratedFunction(build_function(sys.loss,sys.states,sys.controls,sys.ps,sys.iv,conv = ModelingToolkit.ControlToExpr(sys)))
136136

137137
# Expand out all of the variables in time and by stages
138138
timed_vars = [[Variable(x.name,i)(sys.iv()) for i in 1:n+1] for x in states(sys)]

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -119,15 +119,15 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
119119
kwargs...) where {iip}
120120

121121
f_gen = generate_function(sys, dvs, ps; expression=Val{eval_expression}, kwargs...)
122-
f_oop,f_iip = eval_expression ? ModelingToolkit.eval.(f_gen) : f_gen
122+
f_oop,f_iip = eval_expression ? (@RuntimeGeneratedFunction(ex) for ex in f_gen) : f_gen
123123
f(u,p,t) = f_oop(u,p,t)
124124
f(du,u,p,t) = f_iip(du,u,p,t)
125125

126126
if tgrad
127127
tgrad_gen = generate_tgrad(sys, dvs, ps;
128128
simplify=simplify,
129129
expression=Val{eval_expression}, kwargs...)
130-
tgrad_oop,tgrad_iip = eval_expression ? ModelingToolkit.eval.(tgrad_gen) : tgrad_gen
130+
tgrad_oop,tgrad_iip = eval_expression ? (@RuntimeGeneratedFunction(ex) for ex in tgrad_gen) : tgrad_gen
131131
_tgrad(u,p,t) = tgrad_oop(u,p,t)
132132
_tgrad(J,u,p,t) = tgrad_iip(J,u,p,t)
133133
else
@@ -138,7 +138,7 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
138138
jac_gen = generate_jacobian(sys, dvs, ps;
139139
simplify=simplify, sparse = sparse,
140140
expression=Val{eval_expression}, kwargs...)
141-
jac_oop,jac_iip = eval_expression ? ModelingToolkit.eval.(jac_gen) : jac_gen
141+
jac_oop,jac_iip = eval_expression ? (@RuntimeGeneratedFunction(ex) for ex in jac_gen) : jac_gen
142142
_jac(u,p,t) = jac_oop(u,p,t)
143143
_jac(J,u,p,t) = jac_iip(J,u,p,t)
144144
else
@@ -149,12 +149,12 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
149149

150150
_M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0',M)
151151

152-
ODEFunction{iip}(DiffEqBase.EvalFunc(f),
153-
jac = _jac === nothing ? nothing : DiffEqBase.EvalFunc(_jac),
154-
tgrad = _tgrad === nothing ? nothing : DiffEqBase.EvalFunc(_tgrad),
155-
mass_matrix = _M,
156-
jac_prototype = sparse ? similar(sys.jac[],Float64) : nothing,
157-
syms = Symbol.(states(sys)))
152+
ODEFunction{iip}(f,
153+
jac = _jac === nothing ? nothing : _jac,
154+
tgrad = _tgrad === nothing ? nothing : _tgrad,
155+
mass_matrix = _M,
156+
jac_prototype = sparse ? similar(sys.jac[],Float64) : nothing,
157+
syms = Symbol.(states(sys)))
158158
end
159159

160160
"""

src/systems/diffeqs/sdesystem.jl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,9 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = sys.states, ps = sys.
150150
version = nothing, tgrad=false, sparse = false,
151151
jac = false, Wfact = false, eval_expression = true, kwargs...) where {iip}
152152
f_gen = generate_function(sys, dvs, ps; expression=Val{eval_expression}, kwargs...)
153-
f_oop,f_iip = eval_expression ? ModelingToolkit.eval.(f_gen) : f_gen
153+
f_oop,f_iip = eval_expression ? (@RuntimeGeneratedFunction(ex) for ex in f_gen) : f_gen
154154
g_gen = generate_diffusion_function(sys, dvs, ps; expression=Val{eval_expression}, kwargs...)
155-
g_oop,g_iip = eval_expression ? ModelingToolkit.eval.(g_gen) : g_gen
155+
g_oop,g_iip = eval_expression ? (@RuntimeGeneratedFunction(ex) for ex in g_gen) : g_gen
156156

157157
f(u,p,t) = f_oop(u,p,t)
158158
f(du,u,p,t) = f_iip(du,u,p,t)
@@ -161,7 +161,7 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = sys.states, ps = sys.
161161

162162
if tgrad
163163
tgrad_gen = generate_tgrad(sys, dvs, ps; expression=Val{eval_expression}, kwargs...)
164-
tgrad_oop,tgrad_iip = eval_expression ? ModelingToolkit.eval.(tgrad_gen) : tgrad_gen
164+
tgrad_oop,tgrad_iip = eval_expression ? (@RuntimeGeneratedFunction(ex) for ex in tgrad_gen) : tgrad_gen
165165
_tgrad(u,p,t) = tgrad_oop(u,p,t)
166166
_tgrad(J,u,p,t) = tgrad_iip(J,u,p,t)
167167
else
@@ -170,7 +170,7 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = sys.states, ps = sys.
170170

171171
if jac
172172
jac_gen = generate_jacobian(sys, dvs, ps; expression=Val{eval_expression}, sparse=sparse, kwargs...)
173-
jac_oop,jac_iip = eval_expression ? ModelingToolkit.eval.(jac_gen) : jac_gen
173+
jac_oop,jac_iip = eval_expression ? (@RuntimeGeneratedFunction(ex) for ex in jac_gen) : jac_gen
174174
_jac(u,p,t) = jac_oop(u,p,t)
175175
_jac(J,u,p,t) = jac_iip(J,u,p,t)
176176
else
@@ -179,8 +179,8 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = sys.states, ps = sys.
179179

180180
if Wfact
181181
tmp_Wfact,tmp_Wfact_t = generate_factorized_W(sys, dvs, ps, true; expression=Val{true}, kwargs...)
182-
Wfact_oop, Wfact_iip = eval_expression ? ModelingToolkit.eval.(tmp_Wfact) : tmp_Wfact
183-
Wfact_oop_t, Wfact_iip_t = eval_expression ? ModelingToolkit.eval.(tmp_Wfact_t) : tmp_Wfact_t
182+
Wfact_oop, Wfact_iip = eval_expression ? (@RuntimeGeneratedFunction(ex) for ex in tmp_Wfact) : tmp_Wfact
183+
Wfact_oop_t, Wfact_iip_t = eval_expression ? (@RuntimeGeneratedFunction(ex) for ex in tmp_Wfact_t) : tmp_Wfact_t
184184
_Wfact(u,p,dtgamma,t) = Wfact_oop(u,p,dtgamma,t)
185185
_Wfact(W,u,p,dtgamma,t) = Wfact_iip(W,u,p,dtgamma,t)
186186
_Wfact_t(u,p,dtgamma,t) = Wfact_oop_t(u,p,dtgamma,t)
@@ -192,13 +192,13 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = sys.states, ps = sys.
192192
M = calculate_massmatrix(sys)
193193
_M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0',M)
194194

195-
SDEFunction{iip}(DiffEqBase.EvalFunc(f),DiffEqBase.EvalFunc(g),
196-
jac = _jac === nothing ? nothing : DiffEqBase.EvalFunc(_jac),
197-
tgrad = _tgrad === nothing ? nothing : DiffEqBase.EvalFunc(_tgrad),
198-
Wfact = _Wfact === nothing ? nothing : DiffEqBase.EvalFunc(_Wfact),
199-
Wfact_t = _Wfact_t === nothing ? nothing : DiffEqBase.EvalFunc(_Wfact_t),
200-
mass_matrix = _M,
201-
syms = Symbol.(sys.states))
195+
SDEFunction{iip}(f,g,
196+
jac = _jac === nothing ? nothing : _jac,
197+
tgrad = _tgrad === nothing ? nothing : _tgrad,
198+
Wfact = _Wfact === nothing ? nothing : _Wfact,
199+
Wfact_t = _Wfact_t === nothing ? nothing : _Wfact_t,
200+
mass_matrix = _M,
201+
syms = Symbol.(sys.states))
202202
end
203203

204204
function DiffEqBase.SDEFunction(sys::SDESystem, args...; kwargs...)

src/systems/jumps/jumpsystem.jl

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,10 @@ generate_affect_function(js, affect, outputidxs) = build_function(affect, states
7979
outputidxs=outputidxs)[2]
8080

8181
function assemble_vrj(js, vrj, statetoid)
82-
rate = eval(generate_rate_function(js, vrj.rate))
82+
rate = @RuntimeGeneratedFunction(generate_rate_function(js, vrj.rate))
8383
outputvars = (convert(Variable,affect.lhs) for affect in vrj.affect!)
8484
outputidxs = ((statetoid[var] for var in outputvars)...,)
85-
affect = eval(generate_affect_function(js, vrj.affect!, outputidxs))
85+
affect = @RuntimeGeneratedFunction(generate_affect_function(js, vrj.affect!, outputidxs))
8686
VariableRateJump(rate, affect)
8787
end
8888

@@ -99,10 +99,10 @@ function assemble_vrj_expr(js, vrj, statetoid)
9999
end
100100

101101
function assemble_crj(js, crj, statetoid)
102-
rate = eval(generate_rate_function(js, crj.rate))
102+
rate = @RuntimeGeneratedFunction(generate_rate_function(js, crj.rate))
103103
outputvars = (convert(Variable,affect.lhs) for affect in crj.affect!)
104104
outputidxs = ((statetoid[var] for var in outputvars)...,)
105-
affect = eval(generate_affect_function(js, crj.affect!, outputidxs))
105+
affect = @RuntimeGeneratedFunction(generate_affect_function(js, crj.affect!, outputidxs))
106106
ConstantRateJump(rate, affect)
107107
end
108108

@@ -201,8 +201,7 @@ function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Tuple,
201201
else
202202
p = parammap
203203
end
204-
# EvalFunc because we know that the jump functions are generated via eval
205-
f = DiffEqBase.EvalFunc(DiffEqBase.DISCRETE_INPLACE_DEFAULT)
204+
f = DiffEqBase.DISCRETE_INPLACE_DEFAULT
206205
df = DiscreteFunction(f, syms=Symbol.(states(sys)))
207206
DiscreteProblem(df, u0, tspan, p; kwargs...)
208207
end
@@ -231,9 +230,8 @@ function DiscreteProblemExpr(sys::JumpSystem, u0map, tspan::Tuple,
231230
u0 = varmap_to_vars(u0map, states(sys))
232231
p = varmap_to_vars(parammap, parameters(sys))
233232
# identity function to make syms works
234-
# EvalFunc because we know that the jump functions are generated via eval
235233
quote
236-
f = DiffEqBase.EvalFunc(DiffEqBase.DISCRETE_INPLACE_DEFAULT)
234+
f = DiffEqBase.DISCRETE_INPLACE_DEFAULT
237235
u0 = $u0
238236
p = $p
239237
tspan = $tspan
@@ -287,7 +285,7 @@ end
287285

288286

289287
### Functions to determine which states a jump depends on
290-
function get_variables!(dep, jump::Union{ConstantRateJump,VariableRateJump}, variables)
288+
function get_variables!(dep, jump::Union{ConstantRateJump,VariableRateJump}, variables)
291289
(jump.rate isa Operation) && get_variables!(dep, jump.rate, variables)
292290
dep
293291
end

0 commit comments

Comments
 (0)