168
168
169
169
function solve_batch (prob,alg,ensemblealg:: EnsembleThreads ,II,pmap_batch_size;kwargs... )
170
170
171
+ if length (II) == 1 || Threads. nthreads () == 1
172
+ return solve_batch (prob,alg,EnsembleSerial (),II,pmap_batch_size;kwargs... )
173
+ end
174
+
171
175
if typeof (prob. prob) <: AbstractJumpProblem && length (II) != 1
172
176
probs = [deepcopy (prob. prob) for i in 1 : Threads. nthreads ()]
173
177
else
@@ -176,26 +180,19 @@ function solve_batch(prob,alg,ensemblealg::EnsembleThreads,II,pmap_batch_size;kw
176
180
177
181
# batch_data = Vector{Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(II))})}(undef,length(II))
178
182
179
- local batch_data
183
+ batch_data = Vector {Any} (undef,Threads. nthreads ())
184
+ batch_size = length (II)÷ Threads. nthreads ()
185
+
180
186
let
181
- if length (II) == 1 || Threads. nthreads () == 1
182
- batch_data = Vector {Any} (undef,length (II))
183
- for batch_idx in axes (batch_data, 1 )
184
- batch_data[batch_idx] = multithreaded_batch (batch_idx,probs,alg,II)
185
- end
186
- else
187
- batch_data = Vector {Any} (undef,Threads. nthreads ())
188
- batch_size = length (II)÷ Threads. nthreads ()
189
- Threads. @threads for i in 1 : Threads. nthreads ()
190
- if i == Threads. nthreads ()
191
- I_local = II[(batch_size* (i- 1 )+ 1 ): end ]
192
- else
193
- I_local = II[(batch_size* (i- 1 )+ 1 ): (batch_size* i)]
194
- end
195
- batch_data[i] = solve_batch (prob,alg,EnsembleSerial (),I_local,pmap_batch_size;kwargs... )
196
- end
197
- batch_data = reduce (vcat,batch_data)
187
+ Threads. @threads for i in 1 : Threads. nthreads ()
188
+ if i == Threads. nthreads ()
189
+ I_local = II[(batch_size* (i- 1 )+ 1 ): end ]
190
+ else
191
+ I_local = II[(batch_size* (i- 1 )+ 1 ): (batch_size* i)]
192
+ end
193
+ batch_data[i] = solve_batch (prob,alg,EnsembleSerial (),I_local,pmap_batch_size;kwargs... )
198
194
end
195
+ batch_data = reduce (vcat,batch_data)
199
196
end
200
197
tighten_container_eltype (batch_data)
201
198
end
0 commit comments