Skip to content

Commit 45e4cb0

Browse files
handle the other multithreaded inference case
1 parent 58aa699 commit 45e4cb0

File tree

1 file changed

+38
-31
lines changed

1 file changed

+38
-31
lines changed

src/ensemble/basic_ensemble_solve.jl

Lines changed: 38 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -197,8 +197,10 @@ function solve_batch(prob,alg,ensemblealg::EnsembleThreads,I,pmap_batch_size;kwa
197197
end
198198

199199
batch_data = Vector{Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(I))})}(undef,length(I))
200-
Threads.@threads for batch_idx in axes(batch_data, 1)
201-
batch_data[batch_idx] = multithreaded_batch(batch_idx)
200+
let
201+
Threads.@threads for batch_idx in axes(batch_data, 1)
202+
batch_data[batch_idx] = multithreaded_batch(batch_idx)
203+
end
202204
end
203205
batch_data
204206
end
@@ -217,40 +219,45 @@ function solve_batch(prob,alg,::EnsembleSplitThreads,I,pmap_batch_size;kwargs...
217219
thread_monte(prob,I_local,alg,i;kwargs...)
218220
end
219221
end
220-
_batch_data = vector_batch_data_to_arr(batch_data)
222+
reduce(vcat,batch_data)
221223
end
222224

223225
function thread_monte(prob,I,alg,procid;kwargs...)
224-
batch_data = Vector{Any}(undef,length(I))
226+
function multithreaded_batch(j)
227+
i = I[j]
228+
iter = 1
229+
_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
230+
new_prob = prob.prob_func(_prob,i,iter)
231+
rerun = true
232+
x = prob.output_func(solve(new_prob,alg;kwargs...),i)
233+
if !(typeof(x) <: Tuple)
234+
@warn("output_func should return (out,rerun). See docs for updated details")
235+
_x = (x,false)
236+
else
237+
_x = x
238+
end
239+
rerun = _x[2]
240+
while rerun
241+
iter += 1
242+
_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
243+
new_prob = prob.prob_func(_prob,i,iter)
244+
x = prob.output_func(solve(new_prob,alg;kwargs...),i)
245+
if !(typeof(x) <: Tuple)
246+
@warn("output_func should return (out,rerun). See docs for updated details")
247+
_x = (x,false)
248+
else
249+
_x = x
250+
end
251+
rerun = _x[2]
252+
end
253+
_x[1]
254+
end
255+
256+
batch_data = Vector{Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(I))})}(undef,length(I))
257+
225258
let
226259
Threads.@threads for j in 1:length(I)
227-
i = I[j]
228-
iter = 1
229-
_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
230-
new_prob = prob.prob_func(_prob,i,iter)
231-
rerun = true
232-
x = prob.output_func(solve(new_prob,alg;kwargs...),i)
233-
if !(typeof(x) <: Tuple)
234-
@warn("output_func should return (out,rerun). See docs for updated details")
235-
_x = (x,false)
236-
else
237-
_x = x
238-
end
239-
rerun = _x[2]
240-
while rerun
241-
iter += 1
242-
_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
243-
new_prob = prob.prob_func(_prob,i,iter)
244-
x = prob.output_func(solve(new_prob,alg;kwargs...),i)
245-
if !(typeof(x) <: Tuple)
246-
@warn("output_func should return (out,rerun). See docs for updated details")
247-
_x = (x,false)
248-
else
249-
_x = x
250-
end
251-
rerun = _x[2]
252-
end
253-
batch_data[j] = _x[1]
260+
batch_data[j] = multithreaded_batch(j)
254261
end
255262
end
256263
batch_data

0 commit comments

Comments
 (0)