Skip to content

Commit 891aa5e

Browse files
try to do nice things to inference and hope it does nice things for you
1 parent 558f970 commit 891aa5e

File tree

1 file changed

+13
-17
lines changed

1 file changed

+13
-17
lines changed

src/ensemble/basic_ensemble_solve.jl

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ function __solve(prob::AbstractEnsembleProblem,
8585
num_batches * batch_size != trajectories && (num_batches += 1)
8686

8787
function batch_function(I)
88-
batch_data = solve_batch(prob,alg,ensemblealg,I,pmap_batch_size,kwargs...)
88+
batch_data = solve_batch(prob,alg,ensemblealg,I,pmap_batch_size;kwargs...)
8989
end
9090

9191
if num_batches == 1 && prob.reduction === DEFAULT_REDUCTION
@@ -127,7 +127,7 @@ function __solve(prob::AbstractEnsembleProblem,
127127
return EnsembleSolution(u,elapsed_time,converged)
128128
end
129129

130-
function batch_func(i,prob,alg,I,kwargs...)
130+
function batch_func(i,prob,alg;kwargs...)
131131
iter = 1
132132
_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
133133
new_prob = prob.prob_func(_prob,i,iter)
@@ -156,23 +156,19 @@ function batch_func(i,prob,alg,I,kwargs...)
156156
_x[1]
157157
end
158158

159-
function solve_batch(prob,alg,ensemblealg::EnsembleDistributed,I,pmap_batch_size,kwargs...)
159+
function solve_batch(prob,alg,ensemblealg::EnsembleDistributed,I,pmap_batch_size;kwargs...)
160160
wp=CachingPool(workers())
161-
batch_data = let
162-
pmap(wp,I,batch_size=pmap_batch_size) do i
163-
batch_func(i,prob,alg,I,kwargs...)
164-
end
161+
batch_data = pmap(wp,I,batch_size=pmap_batch_size) do i
162+
batch_func(i,prob,alg;kwargs...)
165163
end
166-
map(i->batch_data[i],1:length(batch_data))
164+
map(identity,batch_data)
167165
end
168166

169-
function solve_batch(prob,alg,::EnsembleSerial,I,pmap_batch_size,kwargs...)
170-
batch_data = let
171-
map(I) do i
172-
batch_func(i,prob,alg,I,kwargs...)
173-
end
167+
function solve_batch(prob,alg,::EnsembleSerial,I,pmap_batch_size;kwargs...)
168+
batch_data = map(I) do i
169+
batch_func(i,prob,alg;kwargs...)
174170
end
175-
map(i->batch_data[i],1:length(batch_data))
171+
batch_data
176172
end
177173

178174
function solve_batch(prob,alg,ensemblealg::EnsembleThreads,I,pmap_batch_size,kwargs...)
@@ -213,7 +209,7 @@ function solve_batch(prob,alg,ensemblealg::EnsembleThreads,I,pmap_batch_size,kwa
213209
batch_data
214210
end
215211

216-
function solve_batch(prob,alg,::EnsembleSplitThreads,I,pmap_batch_size,kwargs...)
212+
function solve_batch(prob,alg,::EnsembleSplitThreads,I,pmap_batch_size;kwargs...)
217213
wp=CachingPool(workers())
218214
N = nworkers()
219215
batch_size = length(I)÷N
@@ -224,13 +220,13 @@ function solve_batch(prob,alg,::EnsembleSplitThreads,I,pmap_batch_size,kwargs...
224220
else
225221
I_local = I[(batch_size*(i-1)+1):(batch_size*i)]
226222
end
227-
thread_monte(prob,I_local,alg,i,kwargs...)
223+
thread_monte(prob,I_local,alg,i;kwargs...)
228224
end
229225
end
230226
_batch_data = vector_batch_data_to_arr(batch_data)
231227
end
232228

233-
function thread_monte(prob,I,alg,procid,kwargs...)
229+
function thread_monte(prob,I,alg,procid;kwargs...)
234230
batch_data = Vector{Any}(undef,length(I))
235231
let
236232
Threads.@threads for j in 1:length(I)

0 commit comments

Comments
 (0)