Skip to content

Commit e1d4d06

Browse files
using GG it works
1 parent d8304e8 commit e1d4d06

File tree

4 files changed

+19
-50
lines changed

4 files changed

+19
-50
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
1414
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1515
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1616
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
17+
GG = "6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb"
1718

1819
[compat]
1920
julia = "1"

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ using StaticArrays, LinearAlgebra
1313

1414
using MacroTools
1515
import MacroTools: splitdef, combinedef
16-
16+
import GG
1717
using DocStringExtensions
1818

1919
"""

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 9 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -217,53 +217,26 @@ Create an `ODEFunction` from the [`ODESystem`](@ref). The arguments `dvs` and `p
217217
are used to set the order of the dependent variable and parameter vectors,
218218
respectively.
219219
"""
220-
function DiffEqBase.ODEFunction{iip}(sys::ODESystem, dvs, ps,
221-
safe = Val{true};
220+
function DiffEqBase.ODEFunction{iip}(sys::ODESystem, dvs, ps;
222221
version = nothing,
223222
jac = false, Wfact = false) where {iip}
224-
_f = eval(generate_function(sys, dvs, ps))
225-
out_f_safe(u,p,t) = ModelingToolkit.fast_invokelatest(_f,typeof(u),u,p,t)
226-
out_f_safe(du,u,p,t) = ModelingToolkit.fast_invokelatest(_f,Nothing,du,u,p,t)
227-
out_f(u,p,t) = _f(u,p,t)
228-
out_f(du,u,p,t) = _f(du,u,p,t)
223+
_f = generate_function(sys, dvs, ps)
229224

230225
if jac
231-
_jac = eval(generate_jacobian(sys, dvs, ps))
232-
jac_f_safe(u,p,t) = ModelingToolkit.fast_invokelatest(_jac,Matrix{eltype(u)},u,p,t)
233-
jac_f_safe(J,u,p,t) = ModelingToolkit.fast_invokelatest(_jac,Nothing,J,u,p,t)
234-
jac_f(u,p,t) = _jac(u,p,t)
235-
jac_f(J,u,p,t) = _jac(J,u,p,t)
226+
_jac = generate_jacobian(sys, dvs, ps)
236227
else
237-
jac_f_safe = nothing
238-
jac_f = nothing
228+
_jac = nothing
239229
end
240230

241231
if Wfact
242-
_Wfact,_Wfact_t = eval.(generate_factorized_W(sys, dvs, ps))
243-
Wfact_f_safe(u,p,gam,t) = ModelingToolkit.fast_invokelatest(_Wfact,Matrix{eltype(u)},u,p,gam,t)
244-
Wfact_f_safe(J,u,p,gam,t) = ModelingToolkit.fast_invokelatest(_Wfact,Nothing,J,u,p,gam,t)
245-
Wfact_f_t_safe(u,p,gam,t) = ModelingToolkit.fast_invokelatest(_Wfact_t,Matrix{eltype(u)},u,p,gam,t)
246-
Wfact_f_t_safe(J,u,p,gam,t) = ModelingToolkit.fast_invokelatest(_Wfact_t,Nothing,J,u,p,gam,t)
247-
Wfact_f(u,p,gam,t) = _Wfact(u,p,gam,t)
248-
Wfact_f(J,u,p,gam,t) = _Wfact(J,u,p,gam,t)
249-
Wfact_f_t(u,p,gam,t) = _Wfact_t(u,p,gam,t)
250-
Wfact_f_t(J,u,p,gam,t) = _Wfact_t(J,u,p,gam,t)
232+
_Wfact,_Wfact_t = generate_factorized_W(sys, dvs, ps)
251233
else
252-
Wfact_f_safe = nothing
253-
Wfact_f_t_safe = nothing
254-
Wfact_f = nothing
255-
Wfact_f_t = nothing
234+
_Wfact,_Wfact_t = nothing,nothing
256235
end
257236

258-
if safe === Val{true}
259-
ODEFunction{iip}(out_f_safe,jac=jac_f_safe,
260-
Wfact = Wfact_f_safe,
261-
Wfact_t = Wfact_f_t_safe)
262-
else
263-
ODEFunction{iip}(out_f,jac=jac_f,
264-
Wfact = Wfact_f,
265-
Wfact_t = Wfact_f_t)
266-
end
237+
ODEFunction{iip}(_f,jac=_jac,
238+
Wfact = _Wfact,
239+
Wfact_t = _Wfact_t)
267240
end
268241

269242
function DiffEqBase.ODEFunction(sys::ODESystem, args...; kwargs...)

src/utils.jl

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,13 @@ function flatten_expr!(x)
3131
x
3232
end
3333

34+
mk_function(args, kwargs, body) =
35+
let Args = args |> GG.expr2typelevel,
36+
Kwargs = kwargs |> GG.expr2typelevel,
37+
Body = body |> GG.expr2typelevel
38+
GG.RuntimeFn{Args, Kwargs, Body}()
39+
end
40+
3441
function build_function(rhss, vs, ps = (), args = (), conv = simplified_expr; constructor=nothing)
3542
_vs = map(x-> x isa Operation ? x.op : x, vs)
3643
_ps = map(x-> x isa Operation ? x.op : x, ps)
@@ -50,19 +57,7 @@ function build_function(rhss, vs, ps = (), args = (), conv = simplified_expr; co
5057
let_expr = Expr(:let, var_eqs, sys_expr)
5158

5259
fargs = ps == () ? :(u,$(args...)) : :(u,p,$(args...))
53-
quote
54-
@everywhere function $fname($X,$(fargs.args...))
55-
$ip_let_expr
56-
nothing
57-
end
58-
@everywhere function $fname($(fargs.args...))
59-
X = $let_expr
60-
T = promote_type(map(typeof,X)...)
61-
convert.(T,X)
62-
construct = $(constructor === nothing ? :(u isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.similar_type(typeof(u), eltype(X)) : x->(du=similar(u, T, $(size(rhss)...)); vec(du) .= x; du)) : constructor)
63-
construct(X)
64-
end
65-
end
60+
mk_function(fargs,:(),let_expr)
6661
end
6762

6863
is_constant(::Constant) = true

0 commit comments

Comments
 (0)