@@ -55,6 +55,10 @@ function __solve(prob::AbstractEnsembleProblem,
55
55
else
56
56
@error " parallel_type value not recognized"
57
57
end
58
+ elseif alg isa EnsembleAlgorithm
59
+ # Assume DifferentialEquations.jl is being used, so default alg
60
+ ensemblealg = alg
61
+ alg = nothing
58
62
else
59
63
ensemblealg = EnsembleThreads ()
60
64
end
102
106
103
107
function batch_func (i,prob,alg,I,kwargs... )
104
108
iter = 1
105
- new_prob = prob. prob_func (deepcopy (prob. prob),i,iter)
109
+ _prob = prob. safetycopy ? deepcopy (prob. prob) : prob. prob
110
+ new_prob = prob. prob_func (_prob,i,iter)
106
111
rerun = true
107
112
x = prob. output_func (solve (new_prob,alg;kwargs... ),i)
108
113
if ! (typeof (x) <: Tuple )
@@ -114,7 +119,8 @@ function batch_func(i,prob,alg,I,kwargs...)
114
119
rerun = _x[2 ]
115
120
while rerun
116
121
iter += 1
117
- new_prob = prob. prob_func (deepcopy (prob. prob),i,iter)
122
+ _prob = prob. safetycopy ? deepcopy (prob. prob) : prob. prob
123
+ new_prob = prob. prob_func (_prob,i,iter)
118
124
x = prob. output_func (solve (new_prob,alg;kwargs... ),i)
119
125
if ! (typeof (x) <: Tuple )
120
126
@warn (" output_func should return (out,rerun). See docs for updated details" )
@@ -127,7 +133,7 @@ function batch_func(i,prob,alg,I,kwargs...)
127
133
_x[1 ]
128
134
end
129
135
130
- function solve_batch (prob,alg,:: EnsembleDistributed ,I,pmap_batch_size,kwargs... )
136
+ function solve_batch (prob,alg,ensemblealg :: EnsembleDistributed ,I,pmap_batch_size,kwargs... )
131
137
wp= CachingPool (workers ())
132
138
batch_data = let
133
139
pmap (wp,I,batch_size= pmap_batch_size) do i
@@ -146,13 +152,14 @@ function solve_batch(prob,alg,::EnsembleSerial,I,pmap_batch_size,kwargs...)
146
152
map (i-> batch_data[i],1 : length (batch_data))
147
153
end
148
154
149
- function solve_batch (prob,alg,:: EnsembleThreads ,I,pmap_batch_size,kwargs... )
155
+ function solve_batch (prob,alg,ensemblealg :: EnsembleThreads ,I,pmap_batch_size,kwargs... )
150
156
batch_data = Vector {Any} (undef,length (I))
151
157
let
152
158
Threads. @threads for batch_idx in axes (batch_data, 1 )
153
159
i = I[batch_idx]
154
160
iter = 1
155
- new_prob = prob. prob_func (deepcopy (prob. prob),i,iter)
161
+ _prob = prob. safetycopy ? deepcopy (prob. prob) : prob. prob
162
+ new_prob = prob. prob_func (_prob,i,iter)
156
163
x = prob. output_func (solve (new_prob,alg;kwargs... ),i)
157
164
if ! (typeof (x) <: Tuple )
158
165
@warn (" output_func should return (out,rerun). See docs for updated details" )
@@ -164,7 +171,8 @@ function solve_batch(prob,alg,::EnsembleThreads,I,pmap_batch_size,kwargs...)
164
171
165
172
while rerun
166
173
iter += 1
167
- new_prob = prob. prob_func (deepcopy (prob. prob),i,iter)
174
+ _prob = prob. safetycopy ? deepcopy (prob. prob) : prob. prob
175
+ new_prob = prob. prob_func (_prob,i,iter)
168
176
x = prob. output_func (solve (new_prob,alg;kwargs... ),i)
169
177
if ! (typeof (x) <: Tuple )
170
178
@warn (" output_func should return (out,rerun). See docs for updated details" )
@@ -182,23 +190,30 @@ end
182
190
183
191
"""
    solve_batch(prob, alg, ensemblealg::EnsembleSplitThreads, I, pmap_batch_size, kwargs...)

Distribute the trajectory indices `I` across worker processes (one contiguous
chunk per worker via `pmap`), then run each chunk multithreaded on its worker
through `thread_monte`.  Returns the flattened per-worker results via
`vector_batch_data_to_arr`.
"""
function solve_batch(prob, alg, ensemblealg::EnsembleSplitThreads, I, pmap_batch_size, kwargs...)
    wp = CachingPool(workers())
    N = nworkers()
    # Balanced partition: `base` indices per worker, with the first `extra`
    # workers taking one more.  A plain `length(I) ÷ N` chunk size degenerates
    # when `length(I) < N` (chunk size 0 → all work lands on the last worker)
    # and dumps the whole remainder on the last worker otherwise.
    base, extra = divrem(length(I), N)
    batch_data = let
        pmap(wp, 1:N, batch_size = pmap_batch_size) do i
            # Contiguous slice for worker i; slices are in index order, so the
            # concatenated output preserves the ordering of `I`.
            lo = (i - 1) * base + min(i - 1, extra) + 1
            hi = i * base + min(i, extra)
            I_local = I[lo:hi]
            thread_monte(prob, I_local, alg, i, kwargs...)
        end
    end
    # Flatten the vector-of-vectors of per-trajectory results.
    vector_batch_data_to_arr(batch_data)
end
192
207
193
208
function thread_monte (prob,I,alg,procid,kwargs... )
194
- start = I[1 ]+ (procid- 1 )* length (I)
195
- stop = I[1 ]+ procid* length (I)- 1
196
- portion = start: stop
197
- batch_data = Vector {Any} (undef,length (portion))
209
+ batch_data = Vector {Any} (undef,length (I))
198
210
let
199
- Threads. @threads for i in portion
211
+ j = 0
212
+ Threads. @threads for i in I
213
+ j += 1
200
214
iter = 1
201
- new_prob = prob. prob_func (deepcopy (prob. prob),i,iter)
215
+ _prob = prob. safetycopy ? deepcopy (prob. prob) : prob. prob
216
+ new_prob = prob. prob_func (_prob,i,iter)
202
217
rerun = true
203
218
x = prob. output_func (solve (new_prob,alg;kwargs... ),i)
204
219
if ! (typeof (x) <: Tuple )
@@ -210,7 +225,8 @@ function thread_monte(prob,I,alg,procid,kwargs...)
210
225
rerun = _x[2 ]
211
226
while rerun
212
227
iter += 1
213
- new_prob = prob. prob_func (deepcopy (prob. prob),i,iter)
228
+ _prob = prob. safetycopy ? deepcopy (prob. prob) : prob. prob
229
+ new_prob = prob. prob_func (_prob,i,iter)
214
230
x = prob. output_func (solve (new_prob,alg;kwargs... ),i)
215
231
if ! (typeof (x) <: Tuple )
216
232
@warn (" output_func should return (out,rerun). See docs for updated details" )
@@ -220,7 +236,7 @@ function thread_monte(prob,I,alg,procid,kwargs...)
220
236
end
221
237
rerun = _x[2 ]
222
238
end
223
- batch_data[i - start + 1 ] = _x[1 ]
239
+ batch_data[j ] = _x[1 ]
224
240
end
225
241
end
226
242
batch_data
0 commit comments