@@ -174,56 +174,27 @@ function solve_batch(prob,alg,ensemblealg::EnsembleThreads,II,pmap_batch_size;kw
174
174
probs = prob. prob
175
175
end
176
176
177
- function multithreaded_batch (batch_idx)
178
- i = II[batch_idx]
179
- iter = 1
180
- _prob = if prob. safetycopy
181
- probs isa Vector ? deepcopy (probs[Threads. threadid ()]) : probs
182
- else
183
- probs isa Vector ? probs[Threads. threadid ()] : probs
184
- end
185
- new_prob = prob. prob_func (_prob,i,iter)
186
- x = prob. output_func (solve (new_prob,alg;kwargs... ),i)
187
- if ! (typeof (x) <: Tuple )
188
- @warn (" output_func should return (out,rerun). See docs for updated details" )
189
- _x = (x,false )
190
- else
191
- _x = x
192
- end
193
- rerun = _x[2 ]
194
-
195
- while rerun
196
- iter += 1
197
- _prob = if prob. safetycopy
198
- probs isa Vector ? deepcopy (probs[Threads. threadid ()]) : probs
199
- else
200
- probs isa Vector ? probs[Threads. threadid ()] : probs
201
- end
202
- new_prob = prob. prob_func (_prob,i,iter)
203
- x = prob. output_func (solve (new_prob,alg;alias_jumps= true ,kwargs... ),i)
204
- if ! (typeof (x) <: Tuple )
205
- @warn (" output_func should return (out,rerun). See docs for updated details" )
206
- _x = (x,false )
207
- else
208
- _x = x
209
- end
210
- rerun = _x[2 ]
211
- end
212
- _x[1 ]
213
- end
214
-
215
177
# batch_data = Vector{Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(II))})}(undef,length(II))
216
- batch_data = Vector {Any} (undef,length (II))
217
178
179
+ local batch_data
218
180
let
219
181
if length (II) == 1 || Threads. nthreads () == 1
182
+ batch_data = Vector {Any} (undef,length (II))
220
183
for batch_idx in axes (batch_data, 1 )
221
- batch_data[batch_idx] = multithreaded_batch (batch_idx)
184
+ batch_data[batch_idx] = multithreaded_batch (batch_idx,probs,alg,II )
222
185
end
223
186
else
224
- Threads. @threads for batch_idx in axes (batch_data, 1 )
225
- batch_data[batch_idx] = multithreaded_batch (batch_idx)
187
+ batch_data = Vector {Any} (undef,Threads. nthreads ())
188
+ batch_size = length (II)÷ Threads. nthreads ()
189
+ Threads. @threads for i in 1 : Threads. nthreads ()
190
+ if i == Threads. nthreads ()
191
+ I_local = II[(batch_size* (i- 1 )+ 1 ): end ]
192
+ else
193
+ I_local = II[(batch_size* (i- 1 )+ 1 ): (batch_size* i)]
194
+ end
195
+ batch_data[i] = solve_batch (prob,alg,EnsembleSerial (),I_local,pmap_batch_size;kwargs... )
226
196
end
197
+ batch_data = reduce (vcat,batch_data)
227
198
end
228
199
end
229
200
tighten_container_eltype (batch_data)
@@ -240,71 +211,8 @@ function solve_batch(prob,alg,::EnsembleSplitThreads,II,pmap_batch_size;kwargs..
240
211
else
241
212
I_local = II[(batch_size* (i- 1 )+ 1 ): (batch_size* i)]
242
213
end
243
- thread_monte (prob,I_local, alg,i ;kwargs... )
214
+ solve_batch (prob,alg,EnsembleThreads (),I_local,pmap_batch_size ;kwargs... )
244
215
end
245
216
end
246
217
reduce (vcat,batch_data)
247
218
end
248
-
249
- function thread_monte (prob,II,alg,procid;kwargs... )
250
-
251
- if typeof (prob. prob) <: AbstractJumpProblem && length (II) != 1
252
- probs = [deepcopy (prob. prob) for i in 1 : Threads. nthreads ()]
253
- else
254
- probs = prob. prob
255
- end
256
-
257
- function multithreaded_batch (j)
258
- i = II[j]
259
- iter = 1
260
- _prob = if prob. safetycopy
261
- probs isa Vector ? deepcopy (probs[Threads. threadid ()]) : probs
262
- else
263
- probs isa Vector ? probs[Threads. threadid ()] : probs
264
- end
265
- new_prob = prob. prob_func (_prob,i,iter)
266
- rerun = true
267
- x = prob. output_func (solve (new_prob,alg;alias_jumps= true ,kwargs... ),i)
268
- if ! (typeof (x) <: Tuple )
269
- @warn (" output_func should return (out,rerun). See docs for updated details" )
270
- _x = (x,false )
271
- else
272
- _x = x
273
- end
274
- rerun = _x[2 ]
275
- while rerun
276
- iter += 1
277
- _prob = if prob. safetycopy
278
- probs isa Vector ? deepcopy (probs[Threads. threadid ()]) : probs
279
- else
280
- probs isa Vector ? probs[Threads. threadid ()] : probs
281
- end
282
- new_prob = prob. prob_func (_prob,i,iter)
283
- x = prob. output_func (solve (new_prob,alg;alias_jumps= true ,kwargs... ),i)
284
- if ! (typeof (x) <: Tuple )
285
- @warn (" output_func should return (out,rerun). See docs for updated details" )
286
- _x = (x,false )
287
- else
288
- _x = x
289
- end
290
- rerun = _x[2 ]
291
- end
292
- _x[1 ]
293
- end
294
-
295
- # batch_data = Vector{Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(II))})}(undef,length(II))
296
- batch_data = Vector {Any} (undef,length (II))
297
-
298
- let
299
- if length (II) == 1 || Threads. nthreads () == 1
300
- for batch_idx in axes (batch_data, 1 )
301
- batch_data[batch_idx] = multithreaded_batch (batch_idx)
302
- end
303
- else
304
- Threads. @threads for batch_idx in axes (batch_data, 1 )
305
- batch_data[batch_idx] = multithreaded_batch (batch_idx)
306
- end
307
- end
308
- end
309
- tighten_container_eltype (batch_data)
310
- end
0 commit comments