@@ -84,24 +84,42 @@ function __solve(prob::AbstractEnsembleProblem,
84
84
num_batches = trajectories ÷ batch_size
85
85
num_batches * batch_size != trajectories && (num_batches += 1 )
86
86
87
- u = deepcopy (prob. u_init)
87
+ function batch_function (I)
88
+ batch_data = solve_batch (prob,alg,ensemblealg,I,pmap_batch_size,kwargs... )
89
+ end
90
+
91
+ if num_batches == 1 && prob. reduction === DEFAULT_REDUCTION
92
+ elapsed_time = @elapsed batch_data = batch_function (1 : trajectories)
93
+ return EnsembleSolution (batch_data,elapsed_time,true )
94
+ end
95
+
96
+ if prob. u_init === nothing && prob. reduction === DEFAULT_REDUCTION
97
+ batchrt = Core. Compiler. return_type (batch_function,Tuple{UnitRange{Int64}})
98
+ u = Vector {batchrt} (undef,0 )
99
+ else
100
+ u = []
101
+ end
102
+
88
103
converged = false
104
+
89
105
elapsed_time = @elapsed for i in 1 : num_batches
90
106
if i == num_batches
91
107
I = (batch_size* (i- 1 )+ 1 ): trajectories
92
108
else
93
109
I = (batch_size* (i- 1 )+ 1 ): batch_size* i
94
110
end
95
- batch_data = solve_batch (prob,alg,ensemblealg,I,pmap_batch_size,kwargs ... )
111
+ batch_data = batch_function (I )
96
112
u,converged = prob. reduction (u,batch_data,I)
97
113
converged && break
98
114
end
115
+
99
116
if typeof (u) <: Vector{Any}
100
117
_u = map (i-> u[i],1 : length (u))
101
118
else
102
119
_u = u
103
120
end
104
- return EnsembleSolution (_u,elapsed_time,converged)
121
+
122
+ return EnsembleSolution (u,elapsed_time,converged)
105
123
end
106
124
107
125
function batch_func (i,prob,alg,I,kwargs... )
@@ -153,11 +171,22 @@ function solve_batch(prob,alg,::EnsembleSerial,I,pmap_batch_size,kwargs...)
153
171
end
154
172
155
173
function solve_batch (prob,alg,ensemblealg:: EnsembleThreads ,I,pmap_batch_size,kwargs... )
156
- batch_data = Vector {Any} (undef,length (I))
157
- let
158
- Threads. @threads for batch_idx in axes (batch_data, 1 )
159
- i = I[batch_idx]
160
- iter = 1
174
+ function multithreaded_batch (batch_idx)
175
+ i = I[batch_idx]
176
+ iter = 1
177
+ _prob = prob. safetycopy ? deepcopy (prob. prob) : prob. prob
178
+ new_prob = prob. prob_func (_prob,i,iter)
179
+ x = prob. output_func (solve (new_prob,alg;kwargs... ),i)
180
+ if ! (typeof (x) <: Tuple )
181
+ @warn (" output_func should return (out,rerun). See docs for updated details" )
182
+ _x = (x,false )
183
+ else
184
+ _x = x
185
+ end
186
+ rerun = _x[2 ]
187
+
188
+ while rerun
189
+ iter += 1
161
190
_prob = prob. safetycopy ? deepcopy (prob. prob) : prob. prob
162
191
new_prob = prob. prob_func (_prob,i,iter)
163
192
x = prob. output_func (solve (new_prob,alg;kwargs... ),i)
@@ -168,24 +197,15 @@ function solve_batch(prob,alg,ensemblealg::EnsembleThreads,I,pmap_batch_size,kwa
168
197
_x = x
169
198
end
170
199
rerun = _x[2 ]
171
-
172
- while rerun
173
- iter += 1
174
- _prob = prob. safetycopy ? deepcopy (prob. prob) : prob. prob
175
- new_prob = prob. prob_func (_prob,i,iter)
176
- x = prob. output_func (solve (new_prob,alg;kwargs... ),i)
177
- if ! (typeof (x) <: Tuple )
178
- @warn (" output_func should return (out,rerun). See docs for updated details" )
179
- _x = (x,false )
180
- else
181
- _x = x
182
- end
183
- rerun = _x[2 ]
184
- end
185
- batch_data[batch_idx] = _x[1 ]
186
200
end
201
+ _x[1 ]
187
202
end
188
- map (i-> batch_data[i],1 : length (batch_data))
203
+
204
+ batch_data = Vector {Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(I))})} (undef,length (I))
205
+ Threads. @threads for batch_idx in axes (batch_data, 1 )
206
+ batch_data[batch_idx] = multithreaded_batch (batch_idx)
207
+ end
208
+ batch_data
189
209
end
190
210
191
211
function solve_batch (prob,alg,:: EnsembleSplitThreads ,I,pmap_batch_size,kwargs... )
0 commit comments