|
1 | 1 | using Distributed
|
2 | 2 | addprocs(2)
|
3 | 3 | println("There are $(nprocs()) processes")
|
4 |
| -@everywhere using OrdinaryDiffEq |
5 | 4 |
|
6 |
| -@everywhere prob = ODEProblem((u,p,t)->1.01u,0.5,(0.0,1.0)) |
7 |
| -@everywhere u0s = [rand()*prob.u0 for i in 1:2] |
8 |
| -@everywhere function prob_func(prob,i,repeat) |
9 |
| - println("Running trajectory $i") |
10 |
| - ODEProblem(prob.f,u0s[i],prob.tspan) |
| 5 | +@everywhere begin |
| 6 | + using Pkg |
| 7 | + Pkg.activate("downstream") |
| 8 | + Pkg.develop(PackageSpec(path=joinpath(pwd(), ".."))) |
| 9 | + Pkg.instantiate() |
| 10 | + using OrdinaryDiffEq |
| 11 | + prob = ODEProblem((u,p,t)->1.01u,0.5,(0.0,1.0)) |
| 12 | + u0s = [rand()*prob.u0 for i in 1:2] |
| 13 | + function prob_func(prob,i,repeat) |
| 14 | + println("Running trajectory $i") |
| 15 | + ODEProblem(prob.f,u0s[i],prob.tspan) |
| 16 | + end |
11 | 17 | end
|
12 | 18 |
|
13 | 19 | ensemble_prob = EnsembleProblem(prob, prob_func=prob_func)
|
14 | 20 | sim = solve(ensemble_prob,Tsit5(),EnsembleSplitThreads(),trajectories=2)
|
| 21 | + |
| 22 | +@everywhere function lorenz!(du,u,p,t) |
| 23 | + du[1] = 10.0*(u[2]-u[1]) |
| 24 | + du[2] = u[1]*(28.0-u[3]) - u[2] |
| 25 | + du[3] = u[1]*u[2] - (8/3)*u[3] |
| 26 | +end |
| 27 | + |
| 28 | +u0 = [1.0, 0.0, 0.0] |
| 29 | +tspan = (0.0, 100.0) |
| 30 | +p = [1, 2.0, 3] |
| 31 | +prob = ODEProblem(lorenz!, u0, tspan, p) |
| 32 | + |
| 33 | +@everywhere function prob_func(prob,i,repeat) |
| 34 | + prob = remake(prob, tspan=(rand(), 100.0), p=rand(3)) |
| 35 | + return prob |
| 36 | +end |
| 37 | + |
| 38 | +ensemble_prob = EnsembleProblem(prob, prob_func = prob_func, safetycopy=true) |
| 39 | + |
| 40 | +println("Running EnsembleSerial()") |
| 41 | +@test length(solve(ensemble_prob, Tsit5(), EnsembleSerial(), trajectories=100)) == 100 |
| 42 | +println("Running EnsembleThreads()") |
| 43 | +@test length(solve(ensemble_prob, Tsit5(), EnsembleThreads(), trajectories=100)) == 100 |
| 44 | +println("Running EnsembleDistributed()") |
| 45 | +@test length(solve(ensemble_prob, Tsit5(), EnsembleDistributed(), trajectories=100)) == 100 |
| 46 | +println("Running EnsembleSplitThreads()") |
| 47 | +@test length(solve(ensemble_prob, Tsit5(), EnsembleSplitThreads(), trajectories=100)) == 100 |
0 commit comments