@@ -168,63 +168,31 @@ end
168
168
169
169
function solve_batch (prob,alg,ensemblealg:: EnsembleThreads ,II,pmap_batch_size;kwargs... )
170
170
171
+ if length (II) == 1 || Threads. nthreads () == 1
172
+ return solve_batch (prob,alg,EnsembleSerial (),II,pmap_batch_size;kwargs... )
173
+ end
174
+
171
175
if typeof (prob. prob) <: AbstractJumpProblem && length (II) != 1
172
176
probs = [deepcopy (prob. prob) for i in 1 : Threads. nthreads ()]
173
177
else
174
178
probs = prob. prob
175
179
end
176
180
177
- function multithreaded_batch (batch_idx)
178
- i = II[batch_idx]
179
- iter = 1
180
- _prob = if prob. safetycopy
181
- probs isa Vector ? deepcopy (probs[Threads. threadid ()]) : probs
182
- else
183
- probs isa Vector ? probs[Threads. threadid ()] : probs
184
- end
185
- new_prob = prob. prob_func (_prob,i,iter)
186
- x = prob. output_func (solve (new_prob,alg;kwargs... ),i)
187
- if ! (typeof (x) <: Tuple )
188
- @warn (" output_func should return (out,rerun). See docs for updated details" )
189
- _x = (x,false )
190
- else
191
- _x = x
192
- end
193
- rerun = _x[2 ]
194
-
195
- while rerun
196
- iter += 1
197
- _prob = if prob. safetycopy
198
- probs isa Vector ? deepcopy (probs[Threads. threadid ()]) : probs
199
- else
200
- probs isa Vector ? probs[Threads. threadid ()] : probs
201
- end
202
- new_prob = prob. prob_func (_prob,i,iter)
203
- x = prob. output_func (solve (new_prob,alg;alias_jumps= true ,kwargs... ),i)
204
- if ! (typeof (x) <: Tuple )
205
- @warn (" output_func should return (out,rerun). See docs for updated details" )
206
- _x = (x,false )
207
- else
208
- _x = x
209
- end
210
- rerun = _x[2 ]
211
- end
212
- _x[1 ]
213
- end
214
-
215
181
# batch_data = Vector{Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(II))})}(undef,length(II))
216
- batch_data = Vector {Any} (undef,length (II))
182
+
183
+ batch_data = Vector {Any} (undef,Threads. nthreads ())
184
+ batch_size = length (II)÷ Threads. nthreads ()
217
185
218
186
let
219
- if length (II) == 1 || Threads. nthreads () == 1
220
- for batch_idx in axes (batch_data, 1 )
221
- batch_data[batch_idx] = multithreaded_batch (batch_idx)
222
- end
223
- else
224
- Threads. @threads for batch_idx in axes (batch_data, 1 )
225
- batch_data[batch_idx] = multithreaded_batch (batch_idx)
226
- end
187
+ Threads. @threads for i in 1 : Threads. nthreads ()
188
+ if i == Threads. nthreads ()
189
+ I_local = II[(batch_size* (i- 1 )+ 1 ): end ]
190
+ else
191
+ I_local = II[(batch_size* (i- 1 )+ 1 ): (batch_size* i)]
192
+ end
193
+ batch_data[i] = solve_batch (prob,alg,EnsembleSerial (),I_local,pmap_batch_size;kwargs... )
227
194
end
195
+ batch_data = reduce (vcat,batch_data)
228
196
end
229
197
tighten_container_eltype (batch_data)
230
198
end
@@ -240,71 +208,8 @@ function solve_batch(prob,alg,::EnsembleSplitThreads,II,pmap_batch_size;kwargs..
240
208
else
241
209
I_local = II[(batch_size* (i- 1 )+ 1 ): (batch_size* i)]
242
210
end
243
- thread_monte (prob,I_local, alg,i ;kwargs... )
211
+ solve_batch (prob,alg,EnsembleThreads (),I_local,pmap_batch_size ;kwargs... )
244
212
end
245
213
end
246
214
reduce (vcat,batch_data)
247
215
end
248
-
249
- function thread_monte (prob,II,alg,procid;kwargs... )
250
-
251
- if typeof (prob. prob) <: AbstractJumpProblem && length (II) != 1
252
- probs = [deepcopy (prob. prob) for i in 1 : Threads. nthreads ()]
253
- else
254
- probs = prob. prob
255
- end
256
-
257
- function multithreaded_batch (j)
258
- i = II[j]
259
- iter = 1
260
- _prob = if prob. safetycopy
261
- probs isa Vector ? deepcopy (probs[Threads. threadid ()]) : probs
262
- else
263
- probs isa Vector ? probs[Threads. threadid ()] : probs
264
- end
265
- new_prob = prob. prob_func (_prob,i,iter)
266
- rerun = true
267
- x = prob. output_func (solve (new_prob,alg;alias_jumps= true ,kwargs... ),i)
268
- if ! (typeof (x) <: Tuple )
269
- @warn (" output_func should return (out,rerun). See docs for updated details" )
270
- _x = (x,false )
271
- else
272
- _x = x
273
- end
274
- rerun = _x[2 ]
275
- while rerun
276
- iter += 1
277
- _prob = if prob. safetycopy
278
- probs isa Vector ? deepcopy (probs[Threads. threadid ()]) : probs
279
- else
280
- probs isa Vector ? probs[Threads. threadid ()] : probs
281
- end
282
- new_prob = prob. prob_func (_prob,i,iter)
283
- x = prob. output_func (solve (new_prob,alg;alias_jumps= true ,kwargs... ),i)
284
- if ! (typeof (x) <: Tuple )
285
- @warn (" output_func should return (out,rerun). See docs for updated details" )
286
- _x = (x,false )
287
- else
288
- _x = x
289
- end
290
- rerun = _x[2 ]
291
- end
292
- _x[1 ]
293
- end
294
-
295
- # batch_data = Vector{Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(II))})}(undef,length(II))
296
- batch_data = Vector {Any} (undef,length (II))
297
-
298
- let
299
- if length (II) == 1 || Threads. nthreads () == 1
300
- for batch_idx in axes (batch_data, 1 )
301
- batch_data[batch_idx] = multithreaded_batch (batch_idx)
302
- end
303
- else
304
- Threads. @threads for batch_idx in axes (batch_data, 1 )
305
- batch_data[batch_idx] = multithreaded_batch (batch_idx)
306
- end
307
- end
308
- end
309
- tighten_container_eltype (batch_data)
310
- end
0 commit comments