Skip to content

Commit 0ed23b0

Browse files
fix indexing and don't use I
1 parent 6d90a23 commit 0ed23b0

File tree

1 file changed

+27
-25
lines changed

1 file changed

+27
-25
lines changed

src/ensemble/basic_ensemble_solve.jl

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ if (kwargs[:parallel_type] == != :none && kwargs[:parallel_type] == != :threads)
2929
end
3030
(batch_size != trajectories) && warn("batch_size and reductions are ignored when !collect_result")
3131
32-
elapsed_time = @elapsed u = DArray((trajectories,)) do I
33-
solve_batch(prob,alg,kwargs[:parallel_type] ==,I[1],pmap_batch_size,kwargs...)
32+
elapsed_time = @elapsed u = DArray((trajectories,)) do II
33+
solve_batch(prob,alg,kwargs[:parallel_type] ==,II[1],pmap_batch_size,kwargs...)
3434
end
3535
return EnsembleSolution(u,elapsed_time,false)
3636
=#
@@ -88,8 +88,8 @@ function __solve(prob::AbstractEnsembleProblem,
8888
num_batches < 1 && error("trajectories ÷ batch_size cannot be less than 1, got $num_batches")
8989
num_batches * batch_size != trajectories && (num_batches += 1)
9090

91-
function batch_function(I)
92-
batch_data = solve_batch(prob,alg,ensemblealg,I,pmap_batch_size;kwargs...)
91+
function batch_function(II)
92+
batch_data = solve_batch(prob,alg,ensemblealg,II,pmap_batch_size;kwargs...)
9393
end
9494

9595
if num_batches == 1 && prob.reduction === DEFAULT_REDUCTION
@@ -99,19 +99,21 @@ function __solve(prob::AbstractEnsembleProblem,
9999
end
100100

101101
converged::Bool = false
102+
i = 1
103+
II = (batch_size*(i-1)+1):batch_size*i
102104

103-
batch_data = batch_function(I)
105+
batch_data = batch_function(II)
104106
u = prob.u_init === nothing ? similar(batch_data, 0) : prob.u_init
105-
u,converged = prob.reduction(u,batch_data,I)
107+
u,converged = prob.reduction(u,batch_data,II)
106108
elapsed_time = @elapsed for i in 2:num_batches
107109
converged && break
108110
if i == num_batches
109-
I = (batch_size*(i-1)+1):trajectories
111+
II = (batch_size*(i-1)+1):trajectories
110112
else
111-
I = (batch_size*(i-1)+1):batch_size*i
113+
II = (batch_size*(i-1)+1):batch_size*i
112114
end
113-
batch_data = batch_function(I)
114-
u,converged = prob.reduction(u,batch_data,I)
115+
batch_data = batch_function(II)
116+
u,converged = prob.reduction(u,batch_data,II)
115117
end
116118

117119
u = reduce(vcat, u)
@@ -150,24 +152,24 @@ function batch_func(i,prob,alg;kwargs...)
150152
_x[1]
151153
end
152154

153-
function solve_batch(prob,alg,ensemblealg::EnsembleDistributed,I,pmap_batch_size;kwargs...)
155+
function solve_batch(prob,alg,ensemblealg::EnsembleDistributed,II,pmap_batch_size;kwargs...)
154156
wp=CachingPool(workers())
155-
batch_data = pmap(wp,I,batch_size=pmap_batch_size) do i
157+
batch_data = pmap(wp,II,batch_size=pmap_batch_size) do i
156158
batch_func(i,prob,alg;kwargs...)
157159
end
158160
map(identity,batch_data)
159161
end
160162

161-
function solve_batch(prob,alg,::EnsembleSerial,I,pmap_batch_size;kwargs...)
162-
batch_data = map(I) do i
163+
function solve_batch(prob,alg,::EnsembleSerial,II,pmap_batch_size;kwargs...)
164+
batch_data = map(II) do i
163165
batch_func(i,prob,alg;kwargs...)
164166
end
165167
batch_data
166168
end
167169

168-
function solve_batch(prob,alg,ensemblealg::EnsembleThreads,I,pmap_batch_size;kwargs...)
170+
function solve_batch(prob,alg,ensemblealg::EnsembleThreads,II,pmap_batch_size;kwargs...)
169171
function multithreaded_batch(batch_idx)
170-
i = I[batch_idx]
172+
i = II[batch_idx]
171173
iter = 1
172174
_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
173175
new_prob = prob.prob_func(_prob,i,iter)
@@ -196,7 +198,7 @@ function solve_batch(prob,alg,ensemblealg::EnsembleThreads,I,pmap_batch_size;kwa
196198
_x[1]
197199
end
198200

199-
batch_data = Vector{Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(I))})}(undef,length(I))
201+
batch_data = Vector{Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(II))})}(undef,length(II))
200202
let
201203
Threads.@threads for batch_idx in axes(batch_data, 1)
202204
batch_data[batch_idx] = multithreaded_batch(batch_idx)
@@ -205,26 +207,26 @@ function solve_batch(prob,alg,ensemblealg::EnsembleThreads,I,pmap_batch_size;kwa
205207
batch_data
206208
end
207209

208-
function solve_batch(prob,alg,::EnsembleSplitThreads,I,pmap_batch_size;kwargs...)
210+
function solve_batch(prob,alg,::EnsembleSplitThreads,II,pmap_batch_size;kwargs...)
209211
wp=CachingPool(workers())
210212
N = nworkers()
211-
batch_size = length(I)÷N
213+
batch_size = length(II)÷N
212214
batch_data = let
213215
pmap(wp,1:N,batch_size=pmap_batch_size) do i
214216
if i == N
215-
I_local = I[(batch_size*(i-1)+1):end]
217+
I_local = II[(batch_size*(i-1)+1):end]
216218
else
217-
I_local = I[(batch_size*(i-1)+1):(batch_size*i)]
219+
I_local = II[(batch_size*(i-1)+1):(batch_size*i)]
218220
end
219221
thread_monte(prob,I_local,alg,i;kwargs...)
220222
end
221223
end
222224
reduce(vcat,batch_data)
223225
end
224226

225-
function thread_monte(prob,I,alg,procid;kwargs...)
227+
function thread_monte(prob,II,alg,procid;kwargs...)
226228
function multithreaded_batch(j)
227-
i = I[j]
229+
i = II[j]
228230
iter = 1
229231
_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
230232
new_prob = prob.prob_func(_prob,i,iter)
@@ -253,10 +255,10 @@ function thread_monte(prob,I,alg,procid;kwargs...)
253255
_x[1]
254256
end
255257

256-
batch_data = Vector{Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(I))})}(undef,length(I))
258+
batch_data = Vector{Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(II))})}(undef,length(II))
257259

258260
let
259-
Threads.@threads for j in 1:length(I)
261+
Threads.@threads for j in 1:length(II)
260262
batch_data[j] = multithreaded_batch(j)
261263
end
262264
end

0 commit comments

Comments
 (0)