Skip to content

Commit eb35f47

Browse files
reduce threading overheads
1 parent 03be321 commit eb35f47

File tree

2 files changed

+47
-27
lines changed

2 files changed

+47
-27
lines changed

src/ensemble/basic_ensemble_solve.jl

Lines changed: 44 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -84,24 +84,42 @@ function __solve(prob::AbstractEnsembleProblem,
8484
num_batches = trajectories ÷ batch_size
8585
num_batches * batch_size != trajectories && (num_batches += 1)
8686

87-
u = deepcopy(prob.u_init)
87+
function batch_function(I)
88+
batch_data = solve_batch(prob,alg,ensemblealg,I,pmap_batch_size,kwargs...)
89+
end
90+
91+
if num_batches == 1 && prob.reduction === DEFAULT_REDUCTION
92+
elapsed_time = @elapsed batch_data = batch_function(1:trajectories)
93+
return EnsembleSolution(batch_data,elapsed_time,true)
94+
end
95+
96+
if prob.u_init === nothing && prob.reduction === DEFAULT_REDUCTION
97+
batchrt = Core.Compiler.return_type(batch_function,Tuple{UnitRange{Int64}})
98+
u = Vector{batchrt}(undef,0)
99+
else
100+
u = []
101+
end
102+
88103
converged = false
104+
89105
elapsed_time = @elapsed for i in 1:num_batches
90106
if i == num_batches
91107
I = (batch_size*(i-1)+1):trajectories
92108
else
93109
I = (batch_size*(i-1)+1):batch_size*i
94110
end
95-
batch_data = solve_batch(prob,alg,ensemblealg,I,pmap_batch_size,kwargs...)
111+
batch_data = batch_function(I)
96112
u,converged = prob.reduction(u,batch_data,I)
97113
converged && break
98114
end
115+
99116
if typeof(u) <: Vector{Any}
100117
_u = map(i->u[i],1:length(u))
101118
else
102119
_u = u
103120
end
104-
return EnsembleSolution(_u,elapsed_time,converged)
121+
122+
return EnsembleSolution(u,elapsed_time,converged)
105123
end
106124

107125
function batch_func(i,prob,alg,I,kwargs...)
@@ -153,11 +171,22 @@ function solve_batch(prob,alg,::EnsembleSerial,I,pmap_batch_size,kwargs...)
153171
end
154172

155173
function solve_batch(prob,alg,ensemblealg::EnsembleThreads,I,pmap_batch_size,kwargs...)
156-
batch_data = Vector{Any}(undef,length(I))
157-
let
158-
Threads.@threads for batch_idx in axes(batch_data, 1)
159-
i = I[batch_idx]
160-
iter = 1
174+
function multithreaded_batch(batch_idx)
175+
i = I[batch_idx]
176+
iter = 1
177+
_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
178+
new_prob = prob.prob_func(_prob,i,iter)
179+
x = prob.output_func(solve(new_prob,alg;kwargs...),i)
180+
if !(typeof(x) <: Tuple)
181+
@warn("output_func should return (out,rerun). See docs for updated details")
182+
_x = (x,false)
183+
else
184+
_x = x
185+
end
186+
rerun = _x[2]
187+
188+
while rerun
189+
iter += 1
161190
_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
162191
new_prob = prob.prob_func(_prob,i,iter)
163192
x = prob.output_func(solve(new_prob,alg;kwargs...),i)
@@ -168,24 +197,15 @@ function solve_batch(prob,alg,ensemblealg::EnsembleThreads,I,pmap_batch_size,kwa
168197
_x = x
169198
end
170199
rerun = _x[2]
171-
172-
while rerun
173-
iter += 1
174-
_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
175-
new_prob = prob.prob_func(_prob,i,iter)
176-
x = prob.output_func(solve(new_prob,alg;kwargs...),i)
177-
if !(typeof(x) <: Tuple)
178-
@warn("output_func should return (out,rerun). See docs for updated details")
179-
_x = (x,false)
180-
else
181-
_x = x
182-
end
183-
rerun = _x[2]
184-
end
185-
batch_data[batch_idx] = _x[1]
186200
end
201+
_x[1]
187202
end
188-
map(i->batch_data[i],1:length(batch_data))
203+
204+
batch_data = Vector{Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(I))})}(undef,length(I))
205+
Threads.@threads for batch_idx in axes(batch_data, 1)
206+
batch_data[batch_idx] = multithreaded_batch(batch_idx)
207+
end
208+
batch_data
189209
end
190210

191211
function solve_batch(prob,alg,::EnsembleSplitThreads,I,pmap_batch_size,kwargs...)

src/ensemble/ensemble_problems.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,18 @@ end
1212

1313
DEFAULT_PROB_FUNC(prob,i,repeat) = prob
1414
DEFAULT_OUTPUT_FUNC(sol,i) = (sol,false)
15-
DEFAULT_REDUCTION(u,data,I) = (append!(u,data),false)
15+
DEFAULT_REDUCTION(u,data,I) = (reduce(vcat,(u,data)),false)
1616
EnsembleProblem(prob;
1717
output_func = DEFAULT_OUTPUT_FUNC,
1818
prob_func= DEFAULT_PROB_FUNC,
1919
reduction = DEFAULT_REDUCTION,
20-
u_init = [],
20+
u_init = nothing,
2121
safetycopy = prob_func !== DEFAULT_PROB_FUNC) =
2222
EnsembleProblem(prob,prob_func,output_func,reduction,u_init,safetycopy)
2323

2424
EnsembleProblem(;prob,
2525
output_func = DEFAULT_OUTPUT_FUNC,
2626
prob_func= DEFAULT_PROB_FUNC,
2727
reduction = DEFAULT_REDUCTION,
28-
u_init = [], p = nothing, safetycopy = prob_func !== DEFAULT_PROB_FUNC) =
28+
u_init = nothing, p = nothing, safetycopy = prob_func !== DEFAULT_PROB_FUNC) =
2929
EnsembleProblem(prob,prob_func,output_func,reduction,u_init,safetycopy)

0 commit comments

Comments
 (0)