Skip to content

Commit 3e43896

Browse files
avoid a bunch of multithreading overheads
1 parent b1c553c commit 3e43896

File tree

1 file changed

+53
-11
lines changed

1 file changed

+53
-11
lines changed

src/ensemble/basic_ensemble_solve.jl

Lines changed: 53 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -168,10 +168,21 @@ function solve_batch(prob,alg,::EnsembleSerial,II,pmap_batch_size;kwargs...)
168168
end
169169

170170
function solve_batch(prob,alg,ensemblealg::EnsembleThreads,II,pmap_batch_size;kwargs...)
171+
172+
if typeof(prob.prob) <: AbstractJumpProblem && length(II) != 1
173+
probs = [deepcopy(prob.prob) for i in 1:Threads.nthreads()]
174+
else
175+
probs = prob.prob
176+
end
177+
171178
function multithreaded_batch(batch_idx)
172179
i = II[batch_idx]
173180
iter = 1
174-
_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
181+
_prob = if prob.safetycopy
182+
probs isa Vector ? deepcopy(probs[Threads.threadid()]) : probs
183+
else
184+
probs isa Vector ? probs[Threads.threadid()] : probs
185+
end
175186
new_prob = prob.prob_func(_prob,i,iter)
176187
x = prob.output_func(solve(new_prob,alg;kwargs...),i)
177188
if !(typeof(x) <: Tuple)
@@ -184,9 +195,13 @@ function solve_batch(prob,alg,ensemblealg::EnsembleThreads,II,pmap_batch_size;kw
184195

185196
while rerun
186197
iter += 1
187-
_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
198+
_prob = if prob.safetycopy
199+
probs isa Vector ? deepcopy(probs[Threads.threadid()]) : probs
200+
else
201+
probs isa Vector ? probs[Threads.threadid()] : probs
202+
end
188203
new_prob = prob.prob_func(_prob,i,iter)
189-
x = prob.output_func(solve(new_prob,alg;kwargs...),i)
204+
x = prob.output_func(solve(new_prob,alg;alias_jumps=true,kwargs...),i)
190205
if !(typeof(x) <: Tuple)
191206
@warn("output_func should return (out,rerun). See docs for updated details")
192207
_x = (x,false)
@@ -200,8 +215,14 @@ function solve_batch(prob,alg,ensemblealg::EnsembleThreads,II,pmap_batch_size;kw
200215

201216
batch_data = Vector{Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(II))})}(undef,length(II))
202217
let
203-
Threads.@threads for batch_idx in axes(batch_data, 1)
204-
batch_data[batch_idx] = multithreaded_batch(batch_idx)
218+
if length(II) == 1
219+
for batch_idx in axes(batch_data, 1)
220+
batch_data[batch_idx] = multithreaded_batch(batch_idx)
221+
end
222+
else
223+
Threads.@threads for batch_idx in axes(batch_data, 1)
224+
batch_data[batch_idx] = multithreaded_batch(batch_idx)
225+
end
205226
end
206227
end
207228
batch_data
@@ -225,13 +246,24 @@ function solve_batch(prob,alg,::EnsembleSplitThreads,II,pmap_batch_size;kwargs..
225246
end
226247

227248
function thread_monte(prob,II,alg,procid;kwargs...)
249+
250+
if typeof(prob.prob) <: AbstractJumpProblem && length(II) != 1
251+
probs = [deepcopy(prob.prob) for i in 1:Threads.nthreads()]
252+
else
253+
probs = prob.prob
254+
end
255+
228256
function multithreaded_batch(j)
229257
i = II[j]
230258
iter = 1
231-
_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
259+
_prob = if prob.safetycopy
260+
probs isa Vector ? deepcopy(probs[Threads.threadid()]) : probs
261+
else
262+
probs isa Vector ? probs[Threads.threadid()] : probs
263+
end
232264
new_prob = prob.prob_func(_prob,i,iter)
233265
rerun = true
234-
x = prob.output_func(solve(new_prob,alg;kwargs...),i)
266+
x = prob.output_func(solve(new_prob,alg;alias_jumps=true,kwargs...),i)
235267
if !(typeof(x) <: Tuple)
236268
@warn("output_func should return (out,rerun). See docs for updated details")
237269
_x = (x,false)
@@ -241,9 +273,13 @@ function thread_monte(prob,II,alg,procid;kwargs...)
241273
rerun = _x[2]
242274
while rerun
243275
iter += 1
244-
_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
276+
_prob = if prob.safetycopy
277+
probs isa Vector ? deepcopy(probs[Threads.threadid()]) : probs
278+
else
279+
probs isa Vector ? probs[Threads.threadid()] : probs
280+
end
245281
new_prob = prob.prob_func(_prob,i,iter)
246-
x = prob.output_func(solve(new_prob,alg;kwargs...),i)
282+
x = prob.output_func(solve(new_prob,alg;alias_jumps=true,kwargs...),i)
247283
if !(typeof(x) <: Tuple)
248284
@warn("output_func should return (out,rerun). See docs for updated details")
249285
_x = (x,false)
@@ -258,8 +294,14 @@ function thread_monte(prob,II,alg,procid;kwargs...)
258294
batch_data = Vector{Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(II))})}(undef,length(II))
259295

260296
let
261-
Threads.@threads for j in 1:length(II)
262-
batch_data[j] = multithreaded_batch(j)
297+
if length(II) == 1
298+
for batch_idx in axes(batch_data, 1)
299+
batch_data[batch_idx] = multithreaded_batch(batch_idx)
300+
end
301+
else
302+
Threads.@threads for batch_idx in axes(batch_data, 1)
303+
batch_data[batch_idx] = multithreaded_batch(batch_idx)
304+
end
263305
end
264306
end
265307
batch_data

0 commit comments

Comments
 (0)