Skip to content

Commit 0be6652

Browse files
more cleanup
1 parent 7483cda commit 0be6652

File tree

1 file changed

+15
-18
lines changed

1 file changed

+15
-18
lines changed

src/ensemble/basic_ensemble_solve.jl

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,10 @@ end
168168

169169
function solve_batch(prob,alg,ensemblealg::EnsembleThreads,II,pmap_batch_size;kwargs...)
170170

171+
if length(II) == 1 || Threads.nthreads() == 1
172+
return solve_batch(prob,alg,EnsembleSerial(),II,pmap_batch_size;kwargs...)
173+
end
174+
171175
if typeof(prob.prob) <: AbstractJumpProblem && length(II) != 1
172176
probs = [deepcopy(prob.prob) for i in 1:Threads.nthreads()]
173177
else
@@ -176,26 +180,19 @@ function solve_batch(prob,alg,ensemblealg::EnsembleThreads,II,pmap_batch_size;kw
176180

177181
#batch_data = Vector{Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(II))})}(undef,length(II))
178182

179-
local batch_data
183+
batch_data = Vector{Any}(undef,Threads.nthreads())
184+
batch_size = length(II)÷Threads.nthreads()
185+
180186
let
181-
if length(II) == 1 || Threads.nthreads() == 1
182-
batch_data = Vector{Any}(undef,length(II))
183-
for batch_idx in axes(batch_data, 1)
184-
batch_data[batch_idx] = multithreaded_batch(batch_idx,probs,alg,II)
185-
end
186-
else
187-
batch_data = Vector{Any}(undef,Threads.nthreads())
188-
batch_size = length(II)÷Threads.nthreads()
189-
Threads.@threads for i in 1:Threads.nthreads()
190-
if i == Threads.nthreads()
191-
I_local = II[(batch_size*(i-1)+1):end]
192-
else
193-
I_local = II[(batch_size*(i-1)+1):(batch_size*i)]
194-
end
195-
batch_data[i] = solve_batch(prob,alg,EnsembleSerial(),I_local,pmap_batch_size;kwargs...)
196-
end
197-
batch_data = reduce(vcat,batch_data)
187+
Threads.@threads for i in 1:Threads.nthreads()
188+
if i == Threads.nthreads()
189+
I_local = II[(batch_size*(i-1)+1):end]
190+
else
191+
I_local = II[(batch_size*(i-1)+1):(batch_size*i)]
192+
end
193+
batch_data[i] = solve_batch(prob,alg,EnsembleSerial(),I_local,pmap_batch_size;kwargs...)
198194
end
195+
batch_data = reduce(vcat,batch_data)
199196
end
200197
tighten_container_eltype(batch_data)
201198
end

0 commit comments

Comments
 (0)