Skip to content

Commit 004ac47

Browse files
fix multithreaded thread safety and clean up
1 parent ae566c0 commit 004ac47

File tree

3 files changed

+41
-106
lines changed

3 files changed

+41
-106
lines changed

src/ensemble/basic_ensemble_solve.jl

Lines changed: 14 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -174,56 +174,27 @@ function solve_batch(prob,alg,ensemblealg::EnsembleThreads,II,pmap_batch_size;kw
174174
probs = prob.prob
175175
end
176176

177-
function multithreaded_batch(batch_idx)
178-
i = II[batch_idx]
179-
iter = 1
180-
_prob = if prob.safetycopy
181-
probs isa Vector ? deepcopy(probs[Threads.threadid()]) : probs
182-
else
183-
probs isa Vector ? probs[Threads.threadid()] : probs
184-
end
185-
new_prob = prob.prob_func(_prob,i,iter)
186-
x = prob.output_func(solve(new_prob,alg;kwargs...),i)
187-
if !(typeof(x) <: Tuple)
188-
@warn("output_func should return (out,rerun). See docs for updated details")
189-
_x = (x,false)
190-
else
191-
_x = x
192-
end
193-
rerun = _x[2]
194-
195-
while rerun
196-
iter += 1
197-
_prob = if prob.safetycopy
198-
probs isa Vector ? deepcopy(probs[Threads.threadid()]) : probs
199-
else
200-
probs isa Vector ? probs[Threads.threadid()] : probs
201-
end
202-
new_prob = prob.prob_func(_prob,i,iter)
203-
x = prob.output_func(solve(new_prob,alg;alias_jumps=true,kwargs...),i)
204-
if !(typeof(x) <: Tuple)
205-
@warn("output_func should return (out,rerun). See docs for updated details")
206-
_x = (x,false)
207-
else
208-
_x = x
209-
end
210-
rerun = _x[2]
211-
end
212-
_x[1]
213-
end
214-
215177
#batch_data = Vector{Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(II))})}(undef,length(II))
216-
batch_data = Vector{Any}(undef,length(II))
217178

179+
local batch_data
218180
let
219181
if length(II) == 1 || Threads.nthreads() == 1
182+
batch_data = Vector{Any}(undef,length(II))
220183
for batch_idx in axes(batch_data, 1)
221-
batch_data[batch_idx] = multithreaded_batch(batch_idx)
184+
batch_data[batch_idx] = multithreaded_batch(batch_idx,probs,alg,II)
222185
end
223186
else
224-
Threads.@threads for batch_idx in axes(batch_data, 1)
225-
batch_data[batch_idx] = multithreaded_batch(batch_idx)
187+
batch_data = Vector{Any}(undef,Threads.nthreads())
188+
batch_size = length(II)÷Threads.nthreads()
189+
Threads.@threads for i in 1:Threads.nthreads()
190+
if i == Threads.nthreads()
191+
I_local = II[(batch_size*(i-1)+1):end]
192+
else
193+
I_local = II[(batch_size*(i-1)+1):(batch_size*i)]
194+
end
195+
batch_data[i] = solve_batch(prob,alg,EnsembleSerial(),I_local,pmap_batch_size;kwargs...)
226196
end
197+
batch_data = reduce(vcat,batch_data)
227198
end
228199
end
229200
tighten_container_eltype(batch_data)
@@ -240,71 +211,8 @@ function solve_batch(prob,alg,::EnsembleSplitThreads,II,pmap_batch_size;kwargs..
240211
else
241212
I_local = II[(batch_size*(i-1)+1):(batch_size*i)]
242213
end
243-
thread_monte(prob,I_local,alg,i;kwargs...)
214+
solve_batch(prob,alg,EnsembleThreads(),I_local,pmap_batch_size;kwargs...)
244215
end
245216
end
246217
reduce(vcat,batch_data)
247218
end
248-
249-
function thread_monte(prob,II,alg,procid;kwargs...)
250-
251-
if typeof(prob.prob) <: AbstractJumpProblem && length(II) != 1
252-
probs = [deepcopy(prob.prob) for i in 1:Threads.nthreads()]
253-
else
254-
probs = prob.prob
255-
end
256-
257-
function multithreaded_batch(j)
258-
i = II[j]
259-
iter = 1
260-
_prob = if prob.safetycopy
261-
probs isa Vector ? deepcopy(probs[Threads.threadid()]) : probs
262-
else
263-
probs isa Vector ? probs[Threads.threadid()] : probs
264-
end
265-
new_prob = prob.prob_func(_prob,i,iter)
266-
rerun = true
267-
x = prob.output_func(solve(new_prob,alg;alias_jumps=true,kwargs...),i)
268-
if !(typeof(x) <: Tuple)
269-
@warn("output_func should return (out,rerun). See docs for updated details")
270-
_x = (x,false)
271-
else
272-
_x = x
273-
end
274-
rerun = _x[2]
275-
while rerun
276-
iter += 1
277-
_prob = if prob.safetycopy
278-
probs isa Vector ? deepcopy(probs[Threads.threadid()]) : probs
279-
else
280-
probs isa Vector ? probs[Threads.threadid()] : probs
281-
end
282-
new_prob = prob.prob_func(_prob,i,iter)
283-
x = prob.output_func(solve(new_prob,alg;alias_jumps=true,kwargs...),i)
284-
if !(typeof(x) <: Tuple)
285-
@warn("output_func should return (out,rerun). See docs for updated details")
286-
_x = (x,false)
287-
else
288-
_x = x
289-
end
290-
rerun = _x[2]
291-
end
292-
_x[1]
293-
end
294-
295-
#batch_data = Vector{Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(II))})}(undef,length(II))
296-
batch_data = Vector{Any}(undef,length(II))
297-
298-
let
299-
if length(II) == 1 || Threads.nthreads() == 1
300-
for batch_idx in axes(batch_data, 1)
301-
batch_data[batch_idx] = multithreaded_batch(batch_idx)
302-
end
303-
else
304-
Threads.@threads for batch_idx in axes(batch_data, 1)
305-
batch_data[batch_idx] = multithreaded_batch(batch_idx)
306-
end
307-
end
308-
end
309-
tighten_container_eltype(batch_data)
310-
end
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
using OrdinaryDiffEq
2+
function f(du,u,p,t)
3+
du[1] = 1.01*u[1]
4+
end
5+
u0 = [0.0, 0.0]
6+
tspan = (0.0,1.0)
7+
prob = ODEProblem(f,u0,tspan)
8+
n = 100
9+
10+
initial_conditions = range(0, stop=1, length=n)
11+
function prob_func(prob,i,repeat)
12+
prob.u0[1] = initial_conditions[i]
13+
prob
14+
end
15+
ensemble_prob = EnsembleProblem(prob,prob_func=prob_func)
16+
sim_1 = solve(ensemble_prob,Tsit5()
17+
,EnsembleThreads(),
18+
trajectories=100)
19+
sim_2 = solve(ensemble_prob,Tsit5()
20+
,EnsembleDistributed(),
21+
trajectories=100)
22+
ss_sol_1 = hcat(collect(EnsembleAnalysis.get_timepoint(sim_1,0))...);
23+
ss_sol_2 = hcat(collect(EnsembleAnalysis.get_timepoint(sim_2,0))...);
24+
25+
ss_sol_1[1,:] == initial_conditions
26+
ss_sol_2[1,:] == initial_conditions

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ if !is_APPVEYOR && GROUP == "Downstream"
3838
@time @safetestset "Null Parameters" begin include("downstream/null_params_test.jl") end
3939
@time @safetestset "Ensemble Simulations" begin include("downstream/ensemble.jl") end
4040
@time @safetestset "Ensemble Analysis" begin include("downstream/ensemble_analysis.jl") end
41+
@time @safetestset "Ensemble Thread Safety" begin include("downstream/ensemble_thread_safety.jl") end
4142
@time @safetestset "Inference Tests" begin include("downstream/inference.jl") end
4243
@time @safetestset "Default linsolve with structure" begin include("downstream/default_linsolve_structure.jl") end
4344
@time @safetestset "Callback Merging Tests" begin include("downstream/callback_merging.jl") end

0 commit comments

Comments
 (0)