Skip to content

Commit eb655ed

Browse files
Merge pull request #500 from SciML/ensemble
Fix splitthreads trajectory choices and test it better
2 parents 173b3df + beb62f7 commit eb655ed

File tree

5 files changed

+66
-28
lines changed

5 files changed

+66
-28
lines changed

src/ensemble/basic_ensemble_solve.jl

Lines changed: 32 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -55,6 +55,10 @@ function __solve(prob::AbstractEnsembleProblem,
5555
else
5656
@error "parallel_type value not recognized"
5757
end
58+
elseif alg isa EnsembleAlgorithm
59+
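# Only an ensemble algorithm was passed, so `alg` holds it: assume DifferentialEquations.jl is being used and leave `alg = nothing` so the default solver algorithm is chosen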
60+
ensemblealg = alg
61+
alg = nothing
5862
else
5963
ensemblealg = EnsembleThreads()
6064
end
@@ -102,7 +106,8 @@ end
102106

103107
function batch_func(i,prob,alg,I,kwargs...)
104108
iter = 1
105-
new_prob = prob.prob_func(deepcopy(prob.prob),i,iter)
109+
_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
110+
new_prob = prob.prob_func(_prob,i,iter)
106111
rerun = true
107112
x = prob.output_func(solve(new_prob,alg;kwargs...),i)
108113
if !(typeof(x) <: Tuple)
@@ -114,7 +119,8 @@ function batch_func(i,prob,alg,I,kwargs...)
114119
rerun = _x[2]
115120
while rerun
116121
iter += 1
117-
new_prob = prob.prob_func(deepcopy(prob.prob),i,iter)
122+
_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
123+
new_prob = prob.prob_func(_prob,i,iter)
118124
x = prob.output_func(solve(new_prob,alg;kwargs...),i)
119125
if !(typeof(x) <: Tuple)
120126
@warn("output_func should return (out,rerun). See docs for updated details")
@@ -127,7 +133,7 @@ function batch_func(i,prob,alg,I,kwargs...)
127133
_x[1]
128134
end
129135

130-
function solve_batch(prob,alg,::EnsembleDistributed,I,pmap_batch_size,kwargs...)
136+
function solve_batch(prob,alg,ensemblealg::EnsembleDistributed,I,pmap_batch_size,kwargs...)
131137
wp=CachingPool(workers())
132138
batch_data = let
133139
pmap(wp,I,batch_size=pmap_batch_size) do i
@@ -146,13 +152,14 @@ function solve_batch(prob,alg,::EnsembleSerial,I,pmap_batch_size,kwargs...)
146152
map(i->batch_data[i],1:length(batch_data))
147153
end
148154

149-
function solve_batch(prob,alg,::EnsembleThreads,I,pmap_batch_size,kwargs...)
155+
function solve_batch(prob,alg,ensemblealg::EnsembleThreads,I,pmap_batch_size,kwargs...)
150156
batch_data = Vector{Any}(undef,length(I))
151157
let
152158
Threads.@threads for batch_idx in axes(batch_data, 1)
153159
i = I[batch_idx]
154160
iter = 1
155-
new_prob = prob.prob_func(deepcopy(prob.prob),i,iter)
161+
_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
162+
new_prob = prob.prob_func(_prob,i,iter)
156163
x = prob.output_func(solve(new_prob,alg;kwargs...),i)
157164
if !(typeof(x) <: Tuple)
158165
@warn("output_func should return (out,rerun). See docs for updated details")
@@ -164,7 +171,8 @@ function solve_batch(prob,alg,::EnsembleThreads,I,pmap_batch_size,kwargs...)
164171

165172
while rerun
166173
iter += 1
167-
new_prob = prob.prob_func(deepcopy(prob.prob),i,iter)
174+
_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
175+
new_prob = prob.prob_func(_prob,i,iter)
168176
x = prob.output_func(solve(new_prob,alg;kwargs...),i)
169177
if !(typeof(x) <: Tuple)
170178
@warn("output_func should return (out,rerun). See docs for updated details")
@@ -182,23 +190,30 @@ end
182190

183191
function solve_batch(prob,alg,::EnsembleSplitThreads,I,pmap_batch_size,kwargs...)
184192
wp=CachingPool(workers())
193+
N = nworkers()
194+
batch_size = length(I)÷N
185195
batch_data = let
186-
pmap(wp,1:nprocs(),batch_size=pmap_batch_size) do i
187-
thread_monte(prob,I,alg,i,kwargs...)
196+
pmap(wp,1:N,batch_size=pmap_batch_size) do i
197+
if i == N
198+
I_local = I[(batch_size*(i-1)+1):end]
199+
else
200+
I_local = I[(batch_size*(i-1)+1):(batch_size*i)]
201+
end
202+
thread_monte(prob,I_local,alg,i,kwargs...)
188203
end
189204
end
190205
_batch_data = vector_batch_data_to_arr(batch_data)
191206
end
192207

193208
function thread_monte(prob,I,alg,procid,kwargs...)
194-
start = I[1]+(procid-1)*length(I)
195-
stop = I[1]+procid*length(I)-1
196-
portion = start:stop
197-
batch_data = Vector{Any}(undef,length(portion))
209+
batch_data = Vector{Any}(undef,length(I))
198210
let
199-
Threads.@threads for i in portion
211+
j = 0
212+
Threads.@threads for i in I
213+
j += 1
200214
iter = 1
201-
new_prob = prob.prob_func(deepcopy(prob.prob),i,iter)
215+
_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
216+
new_prob = prob.prob_func(_prob,i,iter)
202217
rerun = true
203218
x = prob.output_func(solve(new_prob,alg;kwargs...),i)
204219
if !(typeof(x) <: Tuple)
@@ -210,7 +225,8 @@ function thread_monte(prob,I,alg,procid,kwargs...)
210225
rerun = _x[2]
211226
while rerun
212227
iter += 1
213-
new_prob = prob.prob_func(deepcopy(prob.prob),i,iter)
228+
_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
229+
new_prob = prob.prob_func(_prob,i,iter)
214230
x = prob.output_func(solve(new_prob,alg;kwargs...),i)
215231
if !(typeof(x) <: Tuple)
216232
@warn("output_func should return (out,rerun). See docs for updated details")
@@ -220,7 +236,7 @@ function thread_monte(prob,I,alg,procid,kwargs...)
220236
end
221237
rerun = _x[2]
222238
end
223-
batch_data[i - start + 1] = _x[1]
239+
batch_data[j] = _x[1]
224240
end
225241
end
226242
batch_data

src/ensemble/ensemble_problems.jl

Lines changed: 15 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -7,18 +7,23 @@ struct EnsembleProblem{T,T2,T3,T4,T5} <: AbstractEnsembleProblem
77
output_func::T3
88
reduction::T4
99
u_init::T5
10+
safetycopy::Bool
1011
end
1112

13+
DEFAULT_PROB_FUNC(prob,i,repeat) = prob
14+
DEFAULT_OUTPUT_FUNC(sol,i) = (sol,false)
15+
DEFAULT_REDUCTION(u,data,I) = (append!(u,data),false)
1216
EnsembleProblem(prob;
13-
output_func = (sol,i)-> (sol,false),
14-
prob_func= (prob,i,repeat)->prob,
15-
reduction = (u,data,I)->(append!(u,data),false),
16-
u_init = []) =
17-
EnsembleProblem(prob,prob_func,output_func,reduction,u_init)
17+
output_func = DEFAULT_OUTPUT_FUNC,
18+
prob_func= DEFAULT_PROB_FUNC,
19+
reduction = DEFAULT_REDUCTION,
20+
u_init = [],
21+
safetycopy = prob_func !== DEFAULT_PROB_FUNC) =
22+
EnsembleProblem(prob,prob_func,output_func,reduction,u_init,safetycopy)
1823

1924
EnsembleProblem(;prob,
20-
output_func = (sol,i)-> (sol,false),
21-
prob_func= (prob,i,repeat)->prob,
22-
reduction = (u,data,I)->(append!(u,data),false),
23-
u_init = [], p = nothing) =
24-
EnsembleProblem(prob,prob_func,output_func,reduction,u_init)
25+
output_func = DEFAULT_OUTPUT_FUNC,
26+
prob_func= DEFAULT_PROB_FUNC,
27+
reduction = DEFAULT_REDUCTION,
28+
u_init = [], p = nothing, safetycopy = prob_func !== DEFAULT_PROB_FUNC) =
29+
EnsembleProblem(prob,prob_func,output_func,reduction,u_init,safetycopy)

test/downstream/distributed_ensemble.jl

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,14 @@
1+
using Distributed
2+
addprocs(2)
3+
println("There are $(nprocs()) processes")
4+
@everywhere using OrdinaryDiffEq
5+
6+
@everywhere prob = ODEProblem((u,p,t)->1.01u,0.5,(0.0,1.0))
7+
@everywhere u0s = [rand()*prob.u0 for i in 1:2]
8+
@everywhere function prob_func(prob,i,repeat)
9+
println("Running trajectory $i")
10+
ODEProblem(prob.f,u0s[i],prob.tspan)
11+
end
12+
13+
ensemble_prob = EnsembleProblem(prob, prob_func=prob_func)
14+
sim = solve(ensemble_prob,Tsit5(),EnsembleSplitThreads(),trajectories=2)

test/downstream/ensemble.jl

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -73,7 +73,9 @@ reduction = function (u,batch,I)
7373
u,((var(u)/sqrt(last(I)))/mean(u)<0.5) ? true : false
7474
end
7575

76-
prob2 = EnsembleProblem(prob,prob_func=prob_func,output_func=output_func,reduction=reduction,u_init=Vector{Float64}())
76+
prob2 = EnsembleProblem(prob,prob_func=prob_func,output_func=output_func,
77+
reduction=reduction,u_init=Vector{Float64}(),
78+
safetycopy=false)
7779
sim = solve(prob2,Tsit5(),trajectories=10000,batch_size=20)
7880
@test sim.converged == true
7981

test/runtests.jl

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
using SafeTestsets
1+
using SafeTestsets, Test
22

33
const GROUP = get(ENV, "GROUP", "All")
44
const is_APPVEYOR = ( Sys.iswindows() && haskey(ENV,"APPVEYOR") )
@@ -48,6 +48,7 @@ if !is_APPVEYOR && GROUP == "Downstream"
4848
@time @safetestset "DEDataArray" begin include("downstream/data_array_regression_tests.jl") end
4949
@time @safetestset "Concrete_solve Tests" begin include("downstream/concrete_solve_tests.jl") end
5050
@time @safetestset "AD Tests" begin include("downstream/ad_tests.jl") end
51+
@time @testset "Distributed Ensemble Tests" begin include("downstream/distributed_ensemble.jl") end
5152
end
5253

5354
if !is_APPVEYOR && GROUP == "GPU"

0 commit comments

Comments
 (0)