Skip to content

Commit 2fad705

Browse files
distributed works with both dispatches
1 parent e1d4d06 commit 2fad705

File tree

4 files changed

+56
-6
lines changed

4 files changed

+56
-6
lines changed

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -220,21 +220,32 @@ respectively.
220220
function DiffEqBase.ODEFunction{iip}(sys::ODESystem, dvs, ps;
221221
version = nothing,
222222
jac = false, Wfact = false) where {iip}
223-
_f = generate_function(sys, dvs, ps)
223+
f_oop,f_iip = generate_function(sys, dvs, ps)
224+
225+
f(u,p,t) = f_oop(u,p,t)
226+
f(du,u,p,t) = f_iip(du,u,p,t)
224227

225228
if jac
226-
_jac = generate_jacobian(sys, dvs, ps)
229+
jac_oop,jac_iip = generate_jacobian(sys, dvs, ps)
230+
_jac(u,p,t) = jac_oop(u,p,t)
231+
_jac(J,u,p,t) = jac_iip(J,u,p,t)
227232
else
228233
_jac = nothing
229234
end
230235

231236
if Wfact
232-
_Wfact,_Wfact_t = generate_factorized_W(sys, dvs, ps)
237+
tmp_Wfact,tmp_Wfact_t = generate_factorized_W(sys, dvs, ps)
238+
Wfact_oop, Wfact_iip = tmp_Wfact
239+
Wfact_oop_t, Wfact_iip_t = tmp_Wfact_t
240+
_Wfact(u,p,t) = Wfact_oop(u,p,t)
241+
_Wfact(W,u,p,t) = Wfact_iip(W,u,p,t)
242+
_Wfact_t(u,p,t) = Wfact_oop_t(u,p,t)
243+
_Wfact_t(W,u,p,t) = Wfact_iip_t(W,u,p,t)
233244
else
234245
_Wfact,_Wfact_t = nothing,nothing
235246
end
236247

237-
ODEFunction{iip}(_f,jac=_jac,
248+
ODEFunction{iip}(f,jac=_jac,
238249
Wfact = _Wfact,
239250
Wfact_t = _Wfact_t)
240251
end

src/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,15 @@ function build_function(rhss, vs, ps = (), args = (), conv = simplified_expr; co
4949

5050
fname = gensym(:ModelingToolkitFunction)
5151

52-
X = gensym()
52+
X = gensym(:MTIIPVar)
5353
ip_sys_exprs = [:($X[$i] = $(conv(rhs))) for (i, rhs) enumerate(rhss)]
5454
ip_let_expr = Expr(:let, var_eqs, build_expr(:block, ip_sys_exprs))
5555

5656
sys_expr = build_expr(:tuple, [conv(rhs) for rhs rhss])
5757
let_expr = Expr(:let, var_eqs, sys_expr)
5858

5959
fargs = ps == () ? :(u,$(args...)) : :(u,p,$(args...))
60-
mk_function(fargs,:(),let_expr)
60+
mk_function(fargs,:(),let_expr), mk_function(:($X,$(fargs.args...)),:(),ip_let_expr)
6161
end
6262

6363
is_constant(::Constant) = true

test/distributed.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
using Distributed
2+
# add processes to workspace
3+
addprocs(2)
4+
5+
using ModelingToolkit
6+
using OrdinaryDiffEq
7+
8+
# create the Lorenz system
9+
@parameters t σ ρ β
10+
@variables x(t) y(t) z(t)
11+
@derivatives D'~t
12+
13+
eqs = [D(x) ~ σ*(y-x),
14+
D(y) ~ x*-z)-y,
15+
D(z) ~ x*y - β*z]
16+
17+
de = ODESystem(eqs)
18+
ode_func = ODEFunction(de, [x,y,z], [σ, ρ, β])
19+
20+
u0 = [19.,20.,50.]
21+
params = [16.,45.92,4]
22+
23+
ode_prob = ODEProblem(ode_func, u0, (0., 10.),params)
24+
25+
@everywhere begin
26+
27+
using DifferentialEquations
28+
using ModelingToolkit
29+
30+
function solve_lorenz(ode_problem)
31+
print(solve(ode_problem,Tsit5()))
32+
end
33+
end
34+
35+
solve_lorenz(ode_prob)
36+
37+
future = @spawn solve_lorenz(ode_prob)
38+
fetch(future)

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ using ModelingToolkit, Test
55
@testset "Simplify Test" begin include("simplify.jl") end
66
@testset "Direct Usage Test" begin include("direct.jl") end
77
@testset "System Construction Test" begin include("system_construction.jl") end
8+
@testset "Distributed Test" begin include("distributed.jl") end

0 commit comments

Comments
 (0)