@@ -85,7 +85,7 @@ function __solve(prob::AbstractEnsembleProblem,
85
85
num_batches * batch_size != trajectories && (num_batches += 1 )
86
86
87
87
function batch_function (I)
88
- batch_data = solve_batch (prob,alg,ensemblealg,I,pmap_batch_size, kwargs... )
88
+ batch_data = solve_batch (prob,alg,ensemblealg,I,pmap_batch_size; kwargs... )
89
89
end
90
90
91
91
if num_batches == 1 && prob. reduction === DEFAULT_REDUCTION
@@ -127,7 +127,7 @@ function __solve(prob::AbstractEnsembleProblem,
127
127
return EnsembleSolution (u,elapsed_time,converged)
128
128
end
129
129
130
- function batch_func (i,prob,alg,I, kwargs... )
130
+ function batch_func (i,prob,alg; kwargs... )
131
131
iter = 1
132
132
_prob = prob. safetycopy ? deepcopy (prob. prob) : prob. prob
133
133
new_prob = prob. prob_func (_prob,i,iter)
@@ -156,23 +156,19 @@ function batch_func(i,prob,alg,I,kwargs...)
156
156
_x[1 ]
157
157
end
158
158
159
- function solve_batch (prob,alg,ensemblealg:: EnsembleDistributed ,I,pmap_batch_size, kwargs... )
159
+ function solve_batch (prob,alg,ensemblealg:: EnsembleDistributed ,I,pmap_batch_size; kwargs... )
160
160
wp= CachingPool (workers ())
161
- batch_data = let
162
- pmap (wp,I,batch_size= pmap_batch_size) do i
163
- batch_func (i,prob,alg,I,kwargs... )
164
- end
161
+ batch_data = pmap (wp,I,batch_size= pmap_batch_size) do i
162
+ batch_func (i,prob,alg;kwargs... )
165
163
end
166
- map (i -> batch_data[i], 1 : length ( batch_data) )
164
+ map (identity, batch_data)
167
165
end
168
166
169
- function solve_batch (prob,alg,:: EnsembleSerial ,I,pmap_batch_size,kwargs... )
170
- batch_data = let
171
- map (I) do i
172
- batch_func (i,prob,alg,I,kwargs... )
173
- end
167
+ function solve_batch (prob,alg,:: EnsembleSerial ,I,pmap_batch_size;kwargs... )
168
+ batch_data = map (I) do i
169
+ batch_func (i,prob,alg;kwargs... )
174
170
end
175
- map (i -> batch_data[i], 1 : length (batch_data))
171
+ batch_data
176
172
end
177
173
178
174
function solve_batch (prob,alg,ensemblealg:: EnsembleThreads ,I,pmap_batch_size,kwargs... )
@@ -213,7 +209,7 @@ function solve_batch(prob,alg,ensemblealg::EnsembleThreads,I,pmap_batch_size,kwa
213
209
batch_data
214
210
end
215
211
216
- function solve_batch (prob,alg,:: EnsembleSplitThreads ,I,pmap_batch_size, kwargs... )
212
+ function solve_batch (prob,alg,:: EnsembleSplitThreads ,I,pmap_batch_size; kwargs... )
217
213
wp= CachingPool (workers ())
218
214
N = nworkers ()
219
215
batch_size = length (I)÷ N
@@ -224,13 +220,13 @@ function solve_batch(prob,alg,::EnsembleSplitThreads,I,pmap_batch_size,kwargs...
224
220
else
225
221
I_local = I[(batch_size* (i- 1 )+ 1 ): (batch_size* i)]
226
222
end
227
- thread_monte (prob,I_local,alg,i, kwargs... )
223
+ thread_monte (prob,I_local,alg,i; kwargs... )
228
224
end
229
225
end
230
226
_batch_data = vector_batch_data_to_arr (batch_data)
231
227
end
232
228
233
- function thread_monte (prob,I,alg,procid, kwargs... )
229
+ function thread_monte (prob,I,alg,procid; kwargs... )
234
230
batch_data = Vector {Any} (undef,length (I))
235
231
let
236
232
Threads. @threads for j in 1 : length (I)
0 commit comments