@@ -29,8 +29,8 @@ if (kwargs[:parallel_type] != :none && kwargs[:parallel_type] != :threads)
   end
   (batch_size != trajectories) && warn("batch_size and reductions are ignored when !collect_result")
 
-  elapsed_time = @elapsed u = DArray((trajectories,)) do I
-    solve_batch(prob,alg,kwargs[:parallel_type],I[1],pmap_batch_size,kwargs...)
+  elapsed_time = @elapsed u = DArray((trajectories,)) do II
+    solve_batch(prob,alg,kwargs[:parallel_type],II[1],pmap_batch_size,kwargs...)
   end
   return EnsembleSolution(u,elapsed_time,false)
 =#
@@ -75,36 +75,55 @@ function __solve(prob::AbstractEnsembleProblem,
   __solve(prob,alg,ensemblealg;trajectories=trajectories,kwargs...)
 end
 
+tighten_container_eltype(u::Vector{Any}) = map(identity, u)
+tighten_container_eltype(u) = u
+
 function __solve(prob::AbstractEnsembleProblem,
                  alg::Union{DEAlgorithm,Nothing},
                  ensemblealg::BasicEnsembleAlgorithm;
                  trajectories, batch_size = trajectories,
                  pmap_batch_size = batch_size÷100 > 0 ? batch_size÷100 : 1, kwargs...)
 
   num_batches = trajectories ÷ batch_size
+  num_batches < 1 && error("trajectories ÷ batch_size cannot be less than 1, got $num_batches")
   num_batches * batch_size != trajectories && (num_batches += 1)
 
-  u = deepcopy(prob.u_init)
-  converged = false
-  elapsed_time = @elapsed for i in 1:num_batches
+  function batch_function(II)
+    batch_data = solve_batch(prob,alg,ensemblealg,II,pmap_batch_size;kwargs...)
+  end
+
+  if num_batches == 1 && prob.reduction === DEFAULT_REDUCTION
+    elapsed_time = @elapsed u = batch_function(1:trajectories)
+    _u = tighten_container_eltype(u)
+    return EnsembleSolution(_u,elapsed_time,true)
+  end
+
+  converged::Bool = false
+  i = 1
+  II = (batch_size*(i-1)+1):batch_size*i
+
+  batch_data = batch_function(II)
+  u = prob.u_init === nothing ? similar(batch_data, 0) : prob.u_init
+  u,converged = prob.reduction(u,batch_data,II)
+  elapsed_time = @elapsed for i in 2:num_batches
+    converged && break
     if i == num_batches
-      I = (batch_size*(i-1)+1):trajectories
+      II = (batch_size*(i-1)+1):trajectories
     else
-      I = (batch_size*(i-1)+1):batch_size*i
+      II = (batch_size*(i-1)+1):batch_size*i
     end
-    batch_data = solve_batch(prob,alg,ensemblealg,I,pmap_batch_size,kwargs...)
-    u,converged = prob.reduction(u,batch_data,I)
-    converged && break
-  end
-  if typeof(u) <: Vector{Any}
-    _u = map(i->u[i],1:length(u))
-  else
-    _u = u
+    batch_data = batch_function(II)
+    u,converged = prob.reduction(u,batch_data,II)
   end
+
+  u = reduce(vcat, u)
+  _u = tighten_container_eltype(u)
+
   return EnsembleSolution(_u,elapsed_time,converged)
+
 end
 
-function batch_func(i,prob,alg,I,kwargs...)
+function batch_func(i,prob,alg;kwargs...)
   iter = 1
   _prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
   new_prob = prob.prob_func(_prob,i,iter)
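Aside from the editor, not part of the patch: the new tighten_container_eltype helper leans on the fact that map(identity, v) re-collects a Vector{Any} and infers a concrete element type from the values it actually holds. A minimal sketch with made-up values:

    v = Any[1.0, 2.0, 3.0]       # e.g. batch output collected as a Vector{Any}
    tight = map(identity, v)     # re-collects into a Vector{Float64}
    eltype(v), eltype(tight)     # (Any, Float64)

The same idiom replaces the old map(i -> batch_data[i], 1:length(batch_data)) copies in the hunks below.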
@@ -133,31 +152,38 @@ function batch_func(i,prob,alg,I,kwargs...)
   _x[1]
 end
 
-function solve_batch(prob,alg,ensemblealg::EnsembleDistributed,I,pmap_batch_size,kwargs...)
+function solve_batch(prob,alg,ensemblealg::EnsembleDistributed,II,pmap_batch_size;kwargs...)
   wp=CachingPool(workers())
-  batch_data = let
-    pmap(wp,I,batch_size=pmap_batch_size) do i
-      batch_func(i,prob,alg,I,kwargs...)
-    end
+  batch_data = pmap(wp,II,batch_size=pmap_batch_size) do i
+    batch_func(i,prob,alg;kwargs...)
   end
-  map(i -> batch_data[i], 1:length(batch_data))
+  map(identity, batch_data)
 end
 
-function solve_batch(prob,alg,::EnsembleSerial,I,pmap_batch_size,kwargs...)
-  batch_data = let
-    map(I) do i
-      batch_func(i,prob,alg,I,kwargs...)
-    end
+function solve_batch(prob,alg,::EnsembleSerial,II,pmap_batch_size;kwargs...)
+  batch_data = map(II) do i
+    batch_func(i,prob,alg;kwargs...)
   end
-  map(i -> batch_data[i], 1:length(batch_data))
+  batch_data
 end
 
-function solve_batch(prob,alg,ensemblealg::EnsembleThreads,I,pmap_batch_size,kwargs...)
-  batch_data = Vector{Any}(undef,length(I))
-  let
-    Threads.@threads for batch_idx in axes(batch_data, 1)
-      i = I[batch_idx]
-      iter = 1
+function solve_batch(prob,alg,ensemblealg::EnsembleThreads,II,pmap_batch_size;kwargs...)
+  function multithreaded_batch(batch_idx)
+    i = II[batch_idx]
+    iter = 1
+    _prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
+    new_prob = prob.prob_func(_prob,i,iter)
+    x = prob.output_func(solve(new_prob,alg;kwargs...),i)
+    if !(typeof(x) <: Tuple)
+      @warn("output_func should return (out,rerun). See docs for updated details")
+      _x = (x,false)
+    else
+      _x = x
+    end
+    rerun = _x[2]
+
+    while rerun
+      iter += 1
       _prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
       new_prob = prob.prob_func(_prob,i,iter)
       x = prob.output_func(solve(new_prob,alg;kwargs...),i)
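Aside from the editor, not part of the patch: throughout these hunks the trailing ,kwargs... in the solve_batch and batch_func signatures becomes ;kwargs..., i.e. the solver options are forwarded as keyword arguments instead of being splatted positionally. A minimal sketch with hypothetical names showing why the `;` matters:

    inner(x; saveat = nothing) = (x, saveat)
    forward_kw(x; kwargs...) = inner(x; kwargs...)   # `;` passes keyword args through to `inner`
    forward_kw(1; saveat = 0.1)                      # (1, 0.1)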
@@ -168,87 +194,73 @@ function solve_batch(prob,alg,ensemblealg::EnsembleThreads,I,pmap_batch_size,kwargs...)
         _x = x
       end
       rerun = _x[2]
+    end
+    _x[1]
+  end
 
-      while rerun
-        iter += 1
-        _prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
-        new_prob = prob.prob_func(_prob,i,iter)
-        x = prob.output_func(solve(new_prob,alg;kwargs...),i)
-        if !(typeof(x) <: Tuple)
-          @warn("output_func should return (out,rerun). See docs for updated details")
-          _x = (x,false)
-        else
-          _x = x
-        end
-        rerun = _x[2]
-      end
-      batch_data[batch_idx] = _x[1]
+  batch_data = Vector{Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(II))})}(undef,length(II))
+  let
+    Threads.@threads for batch_idx in axes(batch_data, 1)
+      batch_data[batch_idx] = multithreaded_batch(batch_idx)
     end
   end
-  map(i -> batch_data[i], 1:length(batch_data))
+  batch_data
 end
 
-function solve_batch(prob,alg,::EnsembleSplitThreads,I,pmap_batch_size,kwargs...)
+function solve_batch(prob,alg,::EnsembleSplitThreads,II,pmap_batch_size;kwargs...)
   wp=CachingPool(workers())
   N = nworkers()
-  batch_size = length(I)÷N
+  batch_size = length(II)÷N
   batch_data = let
     pmap(wp,1:N,batch_size=pmap_batch_size) do i
       if i == N
-        I_local = I[(batch_size*(i-1)+1):end]
+        I_local = II[(batch_size*(i-1)+1):end]
       else
-        I_local = I[(batch_size*(i-1)+1):(batch_size*i)]
+        I_local = II[(batch_size*(i-1)+1):(batch_size*i)]
       end
-      thread_monte(prob,I_local,alg,i,kwargs...)
+      thread_monte(prob,I_local,alg,i;kwargs...)
     end
   end
-  _batch_data = vector_batch_data_to_arr(batch_data)
+  reduce(vcat, batch_data)
 end
 
-function thread_monte(prob,I,alg,procid,kwargs...)
-  batch_data = Vector{Any}(undef,length(I))
-  let
-    Threads.@threads for j in 1:length(I)
-      i = I[j]
-      iter = 1
-      _prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
-      new_prob = prob.prob_func(_prob,i,iter)
-      rerun = true
-      x = prob.output_func(solve(new_prob,alg;kwargs...),i)
-      if !(typeof(x) <: Tuple)
-        @warn("output_func should return (out,rerun). See docs for updated details")
-        _x = (x,false)
-      else
-        _x = x
-      end
-      rerun = _x[2]
-      while rerun
-        iter += 1
-        _prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
-        new_prob = prob.prob_func(_prob,i,iter)
-        x = prob.output_func(solve(new_prob,alg;kwargs...),i)
-        if !(typeof(x) <: Tuple)
-          @warn("output_func should return (out,rerun). See docs for updated details")
-          _x = (x,false)
-        else
-          _x = x
-        end
-        rerun = _x[2]
-      end
-      batch_data[j] = _x[1]
+function thread_monte(prob,II,alg,procid;kwargs...)
+  function multithreaded_batch(j)
+    i = II[j]
+    iter = 1
+    _prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
+    new_prob = prob.prob_func(_prob,i,iter)
+    rerun = true
+    x = prob.output_func(solve(new_prob,alg;kwargs...),i)
+    if !(typeof(x) <: Tuple)
+      @warn("output_func should return (out,rerun). See docs for updated details")
+      _x = (x,false)
+    else
+      _x = x
     end
+    rerun = _x[2]
+    while rerun
+      iter += 1
+      _prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
+      new_prob = prob.prob_func(_prob,i,iter)
+      x = prob.output_func(solve(new_prob,alg;kwargs...),i)
+      if !(typeof(x) <: Tuple)
+        @warn("output_func should return (out,rerun). See docs for updated details")
+        _x = (x,false)
+      else
+        _x = x
+      end
+      rerun = _x[2]
+    end
+    _x[1]
   end
-  batch_data
-end
 
-function vector_batch_data_to_arr(batch_data)
-  _batch_data = Vector{Any}(undef,sum((length(x) for x in batch_data)))
-  idx = 0
-  @inbounds for a in batch_data
-    for x in a
-      idx += 1
-      _batch_data[idx] = x
+  batch_data = Vector{Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(II))})}(undef,length(II))
+
+  let
+    Threads.@threads for j in 1:length(II)
+      batch_data[j] = multithreaded_batch(j)
     end
   end
-  map(i -> _batch_data[i], 1:length(_batch_data))
+  batch_data
 end
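Aside from the editor, not part of the patch: the threaded paths now allocate a concretely typed result buffer by asking inference for the return type of the per-trajectory closure instead of filling a Vector{Any}, and per-worker chunks are flattened with reduce(vcat, ...) rather than vector_batch_data_to_arr. A standalone sketch of that pattern, with a hypothetical work function standing in for multithreaded_batch:

    work(i) = float(i)^2
    II = 1:8
    T = Core.Compiler.return_type(work, Tuple{typeof(first(II))})   # Float64 if inference succeeds, a wider type otherwise
    buf = Vector{T}(undef, length(II))
    Threads.@threads for k in eachindex(II)
        buf[k] = work(II[k])
    end
    reduce(vcat, [buf[1:4], buf[5:8]])   # flattening chunk results, as the EnsembleSplitThreads path does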