@@ -174,27 +174,56 @@ function solve_batch(prob,alg,ensemblealg::EnsembleThreads,II,pmap_batch_size;kw
174
174
probs = prob. prob
175
175
end
176
176
177
+ function multithreaded_batch (batch_idx)
178
+ i = II[batch_idx]
179
+ iter = 1
180
+ _prob = if prob. safetycopy
181
+ probs isa Vector ? deepcopy (probs[Threads. threadid ()]) : probs
182
+ else
183
+ probs isa Vector ? probs[Threads. threadid ()] : probs
184
+ end
185
+ new_prob = prob. prob_func (_prob,i,iter)
186
+ x = prob. output_func (solve (new_prob,alg;kwargs... ),i)
187
+ if ! (typeof (x) <: Tuple )
188
+ @warn (" output_func should return (out,rerun). See docs for updated details" )
189
+ _x = (x,false )
190
+ else
191
+ _x = x
192
+ end
193
+ rerun = _x[2 ]
194
+
195
+ while rerun
196
+ iter += 1
197
+ _prob = if prob. safetycopy
198
+ probs isa Vector ? deepcopy (probs[Threads. threadid ()]) : probs
199
+ else
200
+ probs isa Vector ? probs[Threads. threadid ()] : probs
201
+ end
202
+ new_prob = prob. prob_func (_prob,i,iter)
203
+ x = prob. output_func (solve (new_prob,alg;alias_jumps= true ,kwargs... ),i)
204
+ if ! (typeof (x) <: Tuple )
205
+ @warn (" output_func should return (out,rerun). See docs for updated details" )
206
+ _x = (x,false )
207
+ else
208
+ _x = x
209
+ end
210
+ rerun = _x[2 ]
211
+ end
212
+ _x[1 ]
213
+ end
214
+
177
215
# batch_data = Vector{Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(II))})}(undef,length(II))
216
+ batch_data = Vector {Any} (undef,length (II))
178
217
179
- local batch_data
180
218
let
181
219
if length (II) == 1 || Threads. nthreads () == 1
182
- batch_data = Vector {Any} (undef,length (II))
183
220
for batch_idx in axes (batch_data, 1 )
184
- batch_data[batch_idx] = multithreaded_batch (batch_idx,probs,alg,II )
221
+ batch_data[batch_idx] = multithreaded_batch (batch_idx)
185
222
end
186
223
else
187
- batch_data = Vector {Any} (undef,Threads. nthreads ())
188
- batch_size = length (II)÷ Threads. nthreads ()
189
- Threads. @threads for i in 1 : Threads. nthreads ()
190
- if i == Threads. nthreads ()
191
- I_local = II[(batch_size* (i- 1 )+ 1 ): end ]
192
- else
193
- I_local = II[(batch_size* (i- 1 )+ 1 ): (batch_size* i)]
194
- end
195
- batch_data[i] = solve_batch (prob,alg,EnsembleSerial (),I_local,pmap_batch_size;kwargs... )
224
+ Threads. @threads for batch_idx in axes (batch_data, 1 )
225
+ batch_data[batch_idx] = multithreaded_batch (batch_idx)
196
226
end
197
- batch_data = reduce (vcat,batch_data)
198
227
end
199
228
end
200
229
tighten_container_eltype (batch_data)
@@ -211,8 +240,71 @@ function solve_batch(prob,alg,::EnsembleSplitThreads,II,pmap_batch_size;kwargs..
211
240
else
212
241
I_local = II[(batch_size* (i- 1 )+ 1 ): (batch_size* i)]
213
242
end
214
- solve_batch (prob,alg, EnsembleThreads (), I_local,pmap_batch_size ;kwargs... )
243
+ thread_monte (prob,I_local,alg,i ;kwargs... )
215
244
end
216
245
end
217
246
reduce (vcat,batch_data)
218
247
end
248
+
249
+ function thread_monte (prob,II,alg,procid;kwargs... )
250
+
251
+ if typeof (prob. prob) <: AbstractJumpProblem && length (II) != 1
252
+ probs = [deepcopy (prob. prob) for i in 1 : Threads. nthreads ()]
253
+ else
254
+ probs = prob. prob
255
+ end
256
+
257
+ function multithreaded_batch (j)
258
+ i = II[j]
259
+ iter = 1
260
+ _prob = if prob. safetycopy
261
+ probs isa Vector ? deepcopy (probs[Threads. threadid ()]) : probs
262
+ else
263
+ probs isa Vector ? probs[Threads. threadid ()] : probs
264
+ end
265
+ new_prob = prob. prob_func (_prob,i,iter)
266
+ rerun = true
267
+ x = prob. output_func (solve (new_prob,alg;alias_jumps= true ,kwargs... ),i)
268
+ if ! (typeof (x) <: Tuple )
269
+ @warn (" output_func should return (out,rerun). See docs for updated details" )
270
+ _x = (x,false )
271
+ else
272
+ _x = x
273
+ end
274
+ rerun = _x[2 ]
275
+ while rerun
276
+ iter += 1
277
+ _prob = if prob. safetycopy
278
+ probs isa Vector ? deepcopy (probs[Threads. threadid ()]) : probs
279
+ else
280
+ probs isa Vector ? probs[Threads. threadid ()] : probs
281
+ end
282
+ new_prob = prob. prob_func (_prob,i,iter)
283
+ x = prob. output_func (solve (new_prob,alg;alias_jumps= true ,kwargs... ),i)
284
+ if ! (typeof (x) <: Tuple )
285
+ @warn (" output_func should return (out,rerun). See docs for updated details" )
286
+ _x = (x,false )
287
+ else
288
+ _x = x
289
+ end
290
+ rerun = _x[2 ]
291
+ end
292
+ _x[1 ]
293
+ end
294
+
295
+ # batch_data = Vector{Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(II))})}(undef,length(II))
296
+ batch_data = Vector {Any} (undef,length (II))
297
+
298
+ let
299
+ if length (II) == 1 || Threads. nthreads () == 1
300
+ for batch_idx in axes (batch_data, 1 )
301
+ batch_data[batch_idx] = multithreaded_batch (batch_idx)
302
+ end
303
+ else
304
+ Threads. @threads for batch_idx in axes (batch_data, 1 )
305
+ batch_data[batch_idx] = multithreaded_batch (batch_idx)
306
+ end
307
+ end
308
+ end
309
+ tighten_container_eltype (batch_data)
310
+ end
0 commit comments