@@ -29,8 +29,8 @@ if (kwargs[:parallel_type] != :none && kwargs[:parallel_type] != :threads)
   end
   (batch_size != trajectories) && warn("batch_size and reductions are ignored when !collect_result")

-  elapsed_time = @elapsed u = DArray((trajectories,)) do I
-    solve_batch(prob,alg,kwargs[:parallel_type],I[1],pmap_batch_size,kwargs...)
+  elapsed_time = @elapsed u = DArray((trajectories,)) do II
+    solve_batch(prob,alg,kwargs[:parallel_type],II[1],pmap_batch_size,kwargs...)
   end
   return EnsembleSolution(u,elapsed_time,false)
 =#
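Throughout this diff the batch index range is renamed from `I` to `II`. The likely motivation (an assumption, not stated in the diff itself) is that `I` already names the `UniformScaling` identity exported by the `LinearAlgebra` standard library, so reusing it as a local index range shadows that binding. A minimal sketch of the clash, using only the standard library; `index_demo` is a hypothetical example, not part of this code:

```julia
using LinearAlgebra

A = rand(3, 3)
A + I                      # `I` is LinearAlgebra's UniformScaling identity

function index_demo(batch_size, i)
    I = (batch_size*(i-1)+1):batch_size*i   # local `I` shadows LinearAlgebra.I here
    # Anything below that meant the identity now sees a UnitRange instead.
    return I
end

II = index_demo(4, 2)      # renaming the range to `II` sidesteps the shadowing
```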
@@ -88,8 +88,8 @@ function __solve(prob::AbstractEnsembleProblem,
   num_batches < 1 && error("trajectories ÷ batch_size cannot be less than 1, got $num_batches")
   num_batches * batch_size != trajectories && (num_batches += 1)

-  function batch_function(I)
-    batch_data = solve_batch(prob,alg,ensemblealg,I,pmap_batch_size;kwargs...)
+  function batch_function(II)
+    batch_data = solve_batch(prob,alg,ensemblealg,II,pmap_batch_size;kwargs...)
   end

   if num_batches == 1 && prob.reduction === DEFAULT_REDUCTION
@@ -99,19 +99,21 @@ function __solve(prob::AbstractEnsembleProblem,
   end

   converged::Bool = false
+  i = 1
+  II = (batch_size*(i-1)+1):batch_size*i

-  batch_data = batch_function(I)
+  batch_data = batch_function(II)
   u = prob.u_init === nothing ? similar(batch_data, 0) : prob.u_init
-  u,converged = prob.reduction(u,batch_data,I)
+  u,converged = prob.reduction(u,batch_data,II)
   elapsed_time = @elapsed for i in 2:num_batches
     converged && break
     if i == num_batches
-      I = (batch_size*(i-1)+1):trajectories
+      II = (batch_size*(i-1)+1):trajectories
     else
-      I = (batch_size*(i-1)+1):batch_size*i
+      II = (batch_size*(i-1)+1):batch_size*i
     end
-    batch_data = batch_function(I)
-    u,converged = prob.reduction(u,batch_data,I)
+    batch_data = batch_function(II)
+    u,converged = prob.reduction(u,batch_data,II)
   end

   u = reduce(vcat, u)
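The loop above splits `1:trajectories` into `num_batches` ranges of `batch_size`, truncating the final range to `trajectories` when the division is not exact. A standalone sketch of that partitioning; the helper name `batch_ranges` is hypothetical and only mirrors the index arithmetic shown above:

```julia
# Hypothetical helper mirroring the batch index arithmetic from __solve above.
function batch_ranges(trajectories, batch_size)
    num_batches = trajectories ÷ batch_size
    num_batches < 1 && error("trajectories ÷ batch_size cannot be less than 1, got $num_batches")
    num_batches * batch_size != trajectories && (num_batches += 1)
    map(1:num_batches) do i
        i == num_batches ? ((batch_size*(i-1)+1):trajectories) :
                           ((batch_size*(i-1)+1):batch_size*i)
    end
end

batch_ranges(10, 4)  # => [1:4, 5:8, 9:10]; the last batch picks up the remainder
```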
@@ -150,24 +152,24 @@ function batch_func(i,prob,alg;kwargs...)
   _x[1]
 end

-function solve_batch(prob,alg,ensemblealg::EnsembleDistributed,I,pmap_batch_size;kwargs...)
+function solve_batch(prob,alg,ensemblealg::EnsembleDistributed,II,pmap_batch_size;kwargs...)
   wp = CachingPool(workers())
-  batch_data = pmap(wp,I,batch_size=pmap_batch_size) do i
+  batch_data = pmap(wp,II,batch_size=pmap_batch_size) do i
     batch_func(i,prob,alg;kwargs...)
   end
   map(identity,batch_data)
 end

-function solve_batch(prob,alg,::EnsembleSerial,I,pmap_batch_size;kwargs...)
-  batch_data = map(I) do i
+function solve_batch(prob,alg,::EnsembleSerial,II,pmap_batch_size;kwargs...)
+  batch_data = map(II) do i
     batch_func(i,prob,alg;kwargs...)
   end
   batch_data
 end

-function solve_batch(prob,alg,ensemblealg::EnsembleThreads,I,pmap_batch_size;kwargs...)
+function solve_batch(prob,alg,ensemblealg::EnsembleThreads,II,pmap_batch_size;kwargs...)
   function multithreaded_batch(batch_idx)
-    i = I[batch_idx]
+    i = II[batch_idx]
     iter = 1
     _prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
     new_prob = prob.prob_func(_prob,i,iter)
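The `EnsembleDistributed` method above farms trajectory indices out to worker processes with `pmap` over a `CachingPool`, so the mapped closure is serialized to each worker only once, and `batch_size=pmap_batch_size` controls how many indices are sent per message. A minimal standalone sketch of that pattern, with `simulate` as a hypothetical stand-in for `batch_func`:

```julia
using Distributed
addprocs(2)                               # assumption: two local worker processes

@everywhere simulate(i) = (i, i^2)        # hypothetical stand-in for batch_func

wp = CachingPool(workers())               # workers cache the mapped closure
II = 1:10
batch_data = pmap(wp, II, batch_size = 2) do i
    simulate(i)
end
map(identity, batch_data)                 # narrow the element type, as above
```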
@@ -196,7 +198,7 @@ function solve_batch(prob,alg,ensemblealg::EnsembleThreads,I,pmap_batch_size;kwa
     _x[1]
   end

-  batch_data = Vector{Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(I))})}(undef,length(I))
+  batch_data = Vector{Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(II))})}(undef,length(II))
   let
     Threads.@threads for batch_idx in axes(batch_data,1)
       batch_data[batch_idx] = multithreaded_batch(batch_idx)
@@ -205,26 +207,26 @@ function solve_batch(prob,alg,ensemblealg::EnsembleThreads,I,pmap_batch_size;kwa
   batch_data
 end

-function solve_batch(prob,alg,::EnsembleSplitThreads,I,pmap_batch_size;kwargs...)
+function solve_batch(prob,alg,::EnsembleSplitThreads,II,pmap_batch_size;kwargs...)
   wp = CachingPool(workers())
   N = nworkers()
-  batch_size = length(I)÷N
+  batch_size = length(II)÷N
   batch_data = let
     pmap(wp,1:N,batch_size=pmap_batch_size) do i
       if i == N
-        I_local = I[(batch_size*(i-1)+1):end]
+        I_local = II[(batch_size*(i-1)+1):end]
       else
-        I_local = I[(batch_size*(i-1)+1):(batch_size*i)]
+        I_local = II[(batch_size*(i-1)+1):(batch_size*i)]
       end
       thread_monte(prob,I_local,alg,i;kwargs...)
     end
   end
   reduce(vcat,batch_data)
 end

-function thread_monte(prob,I,alg,procid;kwargs...)
+function thread_monte(prob,II,alg,procid;kwargs...)
   function multithreaded_batch(j)
-    i = I[j]
+    i = II[j]
     iter = 1
     _prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
     new_prob = prob.prob_func(_prob,i,iter)
@@ -253,10 +255,10 @@ function thread_monte(prob,I,alg,procid;kwargs...)
     _x[1]
   end

-  batch_data = Vector{Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(I))})}(undef,length(I))
+  batch_data = Vector{Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(II))})}(undef,length(II))

   let
-    Threads.@threads for j in 1:length(I)
+    Threads.@threads for j in 1:length(II)
       batch_data[j] = multithreaded_batch(j)
     end
   end
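Both threaded paths above preallocate `batch_data` with an element type obtained from `Core.Compiler.return_type` and then fill it in parallel with `Threads.@threads`. A minimal standalone sketch of that fill pattern, with `work` as a hypothetical stand-in for `multithreaded_batch`:

```julia
work(i) = (i, sin(i))                     # hypothetical stand-in for multithreaded_batch

II = 1:8
# Preallocate with the inferred return type so the array is concretely typed.
batch_data = Vector{Core.Compiler.return_type(work, Tuple{typeof(first(II))})}(undef, length(II))
Threads.@threads for j in 1:length(II)
    batch_data[j] = work(II[j])           # each thread writes a disjoint slot
end
batch_data                                # Vector{Tuple{Int64, Float64}} of length 8
```

Note that `Core.Compiler.return_type` is a compiler internal rather than a stable public API; it is used here only to avoid an abstractly typed output array.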