@@ -168,10 +168,21 @@ function solve_batch(prob,alg,::EnsembleSerial,II,pmap_batch_size;kwargs...)
168
168
end
169
169
170
170
function solve_batch (prob,alg,ensemblealg:: EnsembleThreads ,II,pmap_batch_size;kwargs... )
171
+
172
+ if typeof (prob. prob) <: AbstractJumpProblem && length (II) != 1
173
+ probs = [deepcopy (prob. prob) for i in 1 : Threads. nthreads ()]
174
+ else
175
+ probs = prob. prob
176
+ end
177
+
171
178
function multithreaded_batch (batch_idx)
172
179
i = II[batch_idx]
173
180
iter = 1
174
- _prob = prob. safetycopy ? deepcopy (prob. prob) : prob. prob
181
+ _prob = if prob. safetycopy
182
+ probs isa Vector ? deepcopy (probs[Threads. threadid ()]) : probs
183
+ else
184
+ probs isa Vector ? probs[Threads. threadid ()] : probs
185
+ end
175
186
new_prob = prob. prob_func (_prob,i,iter)
176
187
x = prob. output_func (solve (new_prob,alg;kwargs... ),i)
177
188
if ! (typeof (x) <: Tuple )
@@ -184,9 +195,13 @@ function solve_batch(prob,alg,ensemblealg::EnsembleThreads,II,pmap_batch_size;kw
184
195
185
196
while rerun
186
197
iter += 1
187
- _prob = prob. safetycopy ? deepcopy (prob. prob) : prob. prob
198
+ _prob = if prob. safetycopy
199
+ probs isa Vector ? deepcopy (probs[Threads. threadid ()]) : probs
200
+ else
201
+ probs isa Vector ? probs[Threads. threadid ()] : probs
202
+ end
188
203
new_prob = prob. prob_func (_prob,i,iter)
189
- x = prob. output_func (solve (new_prob,alg;kwargs... ),i)
204
+ x = prob. output_func (solve (new_prob,alg;alias_jumps = true , kwargs... ),i)
190
205
if ! (typeof (x) <: Tuple )
191
206
@warn (" output_func should return (out,rerun). See docs for updated details" )
192
207
_x = (x,false )
@@ -200,8 +215,14 @@ function solve_batch(prob,alg,ensemblealg::EnsembleThreads,II,pmap_batch_size;kw
200
215
201
216
batch_data = Vector {Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(II))})} (undef,length (II))
202
217
let
203
- Threads. @threads for batch_idx in axes (batch_data, 1 )
204
- batch_data[batch_idx] = multithreaded_batch (batch_idx)
218
+ if length (II) == 1
219
+ for batch_idx in axes (batch_data, 1 )
220
+ batch_data[batch_idx] = multithreaded_batch (batch_idx)
221
+ end
222
+ else
223
+ Threads. @threads for batch_idx in axes (batch_data, 1 )
224
+ batch_data[batch_idx] = multithreaded_batch (batch_idx)
225
+ end
205
226
end
206
227
end
207
228
batch_data
@@ -225,13 +246,24 @@ function solve_batch(prob,alg,::EnsembleSplitThreads,II,pmap_batch_size;kwargs..
225
246
end
226
247
227
248
function thread_monte (prob,II,alg,procid;kwargs... )
249
+
250
+ if typeof (prob. prob) <: AbstractJumpProblem && length (II) != 1
251
+ probs = [deepcopy (prob. prob) for i in 1 : Threads. nthreads ()]
252
+ else
253
+ probs = prob. prob
254
+ end
255
+
228
256
function multithreaded_batch (j)
229
257
i = II[j]
230
258
iter = 1
231
- _prob = prob. safetycopy ? deepcopy (prob. prob) : prob. prob
259
+ _prob = if prob. safetycopy
260
+ probs isa Vector ? deepcopy (probs[Threads. threadid ()]) : probs
261
+ else
262
+ probs isa Vector ? probs[Threads. threadid ()] : probs
263
+ end
232
264
new_prob = prob. prob_func (_prob,i,iter)
233
265
rerun = true
234
- x = prob. output_func (solve (new_prob,alg;kwargs... ),i)
266
+ x = prob. output_func (solve (new_prob,alg;alias_jumps = true , kwargs... ),i)
235
267
if ! (typeof (x) <: Tuple )
236
268
@warn (" output_func should return (out,rerun). See docs for updated details" )
237
269
_x = (x,false )
@@ -241,9 +273,13 @@ function thread_monte(prob,II,alg,procid;kwargs...)
241
273
rerun = _x[2 ]
242
274
while rerun
243
275
iter += 1
244
- _prob = prob. safetycopy ? deepcopy (prob. prob) : prob. prob
276
+ _prob = if prob. safetycopy
277
+ probs isa Vector ? deepcopy (probs[Threads. threadid ()]) : probs
278
+ else
279
+ probs isa Vector ? probs[Threads. threadid ()] : probs
280
+ end
245
281
new_prob = prob. prob_func (_prob,i,iter)
246
- x = prob. output_func (solve (new_prob,alg;kwargs... ),i)
282
+ x = prob. output_func (solve (new_prob,alg;alias_jumps = true , kwargs... ),i)
247
283
if ! (typeof (x) <: Tuple )
248
284
@warn (" output_func should return (out,rerun). See docs for updated details" )
249
285
_x = (x,false )
@@ -258,8 +294,14 @@ function thread_monte(prob,II,alg,procid;kwargs...)
258
294
batch_data = Vector {Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(II))})} (undef,length (II))
259
295
260
296
let
261
- Threads. @threads for j in 1 : length (II)
262
- batch_data[j] = multithreaded_batch (j)
297
+ if length (II) == 1
298
+ for batch_idx in axes (batch_data, 1 )
299
+ batch_data[batch_idx] = multithreaded_batch (batch_idx)
300
+ end
301
+ else
302
+ Threads. @threads for batch_idx in axes (batch_data, 1 )
303
+ batch_data[batch_idx] = multithreaded_batch (batch_idx)
304
+ end
263
305
end
264
306
end
265
307
batch_data
0 commit comments