Skip to content

Commit 6ebad1a

Browse files
Revert "fix multithreaded thread safety and clean up"
This reverts commit 004ac47.
1 parent 004ac47 commit 6ebad1a

File tree

3 files changed

+106
-41
lines changed

3 files changed

+106
-41
lines changed

src/ensemble/basic_ensemble_solve.jl

Lines changed: 106 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -174,27 +174,56 @@ function solve_batch(prob,alg,ensemblealg::EnsembleThreads,II,pmap_batch_size;kw
174174
probs = prob.prob
175175
end
176176

177+
function multithreaded_batch(batch_idx)
178+
i = II[batch_idx]
179+
iter = 1
180+
_prob = if prob.safetycopy
181+
probs isa Vector ? deepcopy(probs[Threads.threadid()]) : probs
182+
else
183+
probs isa Vector ? probs[Threads.threadid()] : probs
184+
end
185+
new_prob = prob.prob_func(_prob,i,iter)
186+
x = prob.output_func(solve(new_prob,alg;kwargs...),i)
187+
if !(typeof(x) <: Tuple)
188+
@warn("output_func should return (out,rerun). See docs for updated details")
189+
_x = (x,false)
190+
else
191+
_x = x
192+
end
193+
rerun = _x[2]
194+
195+
while rerun
196+
iter += 1
197+
_prob = if prob.safetycopy
198+
probs isa Vector ? deepcopy(probs[Threads.threadid()]) : probs
199+
else
200+
probs isa Vector ? probs[Threads.threadid()] : probs
201+
end
202+
new_prob = prob.prob_func(_prob,i,iter)
203+
x = prob.output_func(solve(new_prob,alg;alias_jumps=true,kwargs...),i)
204+
if !(typeof(x) <: Tuple)
205+
@warn("output_func should return (out,rerun). See docs for updated details")
206+
_x = (x,false)
207+
else
208+
_x = x
209+
end
210+
rerun = _x[2]
211+
end
212+
_x[1]
213+
end
214+
177215
#batch_data = Vector{Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(II))})}(undef,length(II))
216+
batch_data = Vector{Any}(undef,length(II))
178217

179-
local batch_data
180218
let
181219
if length(II) == 1 || Threads.nthreads() == 1
182-
batch_data = Vector{Any}(undef,length(II))
183220
for batch_idx in axes(batch_data, 1)
184-
batch_data[batch_idx] = multithreaded_batch(batch_idx,probs,alg,II)
221+
batch_data[batch_idx] = multithreaded_batch(batch_idx)
185222
end
186223
else
187-
batch_data = Vector{Any}(undef,Threads.nthreads())
188-
batch_size = length(II)÷Threads.nthreads()
189-
Threads.@threads for i in 1:Threads.nthreads()
190-
if i == Threads.nthreads()
191-
I_local = II[(batch_size*(i-1)+1):end]
192-
else
193-
I_local = II[(batch_size*(i-1)+1):(batch_size*i)]
194-
end
195-
batch_data[i] = solve_batch(prob,alg,EnsembleSerial(),I_local,pmap_batch_size;kwargs...)
224+
Threads.@threads for batch_idx in axes(batch_data, 1)
225+
batch_data[batch_idx] = multithreaded_batch(batch_idx)
196226
end
197-
batch_data = reduce(vcat,batch_data)
198227
end
199228
end
200229
tighten_container_eltype(batch_data)
@@ -211,8 +240,71 @@ function solve_batch(prob,alg,::EnsembleSplitThreads,II,pmap_batch_size;kwargs..
211240
else
212241
I_local = II[(batch_size*(i-1)+1):(batch_size*i)]
213242
end
214-
solve_batch(prob,alg,EnsembleThreads(),I_local,pmap_batch_size;kwargs...)
243+
thread_monte(prob,I_local,alg,i;kwargs...)
215244
end
216245
end
217246
reduce(vcat,batch_data)
218247
end
248+
249+
function thread_monte(prob,II,alg,procid;kwargs...)
250+
251+
if typeof(prob.prob) <: AbstractJumpProblem && length(II) != 1
252+
probs = [deepcopy(prob.prob) for i in 1:Threads.nthreads()]
253+
else
254+
probs = prob.prob
255+
end
256+
257+
function multithreaded_batch(j)
258+
i = II[j]
259+
iter = 1
260+
_prob = if prob.safetycopy
261+
probs isa Vector ? deepcopy(probs[Threads.threadid()]) : probs
262+
else
263+
probs isa Vector ? probs[Threads.threadid()] : probs
264+
end
265+
new_prob = prob.prob_func(_prob,i,iter)
266+
rerun = true
267+
x = prob.output_func(solve(new_prob,alg;alias_jumps=true,kwargs...),i)
268+
if !(typeof(x) <: Tuple)
269+
@warn("output_func should return (out,rerun). See docs for updated details")
270+
_x = (x,false)
271+
else
272+
_x = x
273+
end
274+
rerun = _x[2]
275+
while rerun
276+
iter += 1
277+
_prob = if prob.safetycopy
278+
probs isa Vector ? deepcopy(probs[Threads.threadid()]) : probs
279+
else
280+
probs isa Vector ? probs[Threads.threadid()] : probs
281+
end
282+
new_prob = prob.prob_func(_prob,i,iter)
283+
x = prob.output_func(solve(new_prob,alg;alias_jumps=true,kwargs...),i)
284+
if !(typeof(x) <: Tuple)
285+
@warn("output_func should return (out,rerun). See docs for updated details")
286+
_x = (x,false)
287+
else
288+
_x = x
289+
end
290+
rerun = _x[2]
291+
end
292+
_x[1]
293+
end
294+
295+
#batch_data = Vector{Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(II))})}(undef,length(II))
296+
batch_data = Vector{Any}(undef,length(II))
297+
298+
let
299+
if length(II) == 1 || Threads.nthreads() == 1
300+
for batch_idx in axes(batch_data, 1)
301+
batch_data[batch_idx] = multithreaded_batch(batch_idx)
302+
end
303+
else
304+
Threads.@threads for batch_idx in axes(batch_data, 1)
305+
batch_data[batch_idx] = multithreaded_batch(batch_idx)
306+
end
307+
end
308+
end
309+
tighten_container_eltype(batch_data)
310+
end

test/downstream/ensemble_thread_safety.jl

Lines changed: 0 additions & 26 deletions
This file was deleted.

test/runtests.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ if !is_APPVEYOR && GROUP == "Downstream"
3838
@time @safetestset "Null Parameters" begin include("downstream/null_params_test.jl") end
3939
@time @safetestset "Ensemble Simulations" begin include("downstream/ensemble.jl") end
4040
@time @safetestset "Ensemble Analysis" begin include("downstream/ensemble_analysis.jl") end
41-
@time @safetestset "Ensemble Thread Safety" begin include("downstream/ensemble_thread_safety.jl") end
4241
@time @safetestset "Inference Tests" begin include("downstream/inference.jl") end
4342
@time @safetestset "Default linsolve with structure" begin include("downstream/default_linsolve_structure.jl") end
4443
@time @safetestset "Callback Merging Tests" begin include("downstream/callback_merging.jl") end

0 commit comments

Comments
 (0)