@@ -104,9 +104,10 @@ function __solve(prob::AbstractEnsembleProblem,
104
104
return EnsembleSolution (_u,elapsed_time,converged)
105
105
end
106
106
107
- function batch_func (i,prob,alg,I,kwargs... )
107
+ function batch_func (i,prob,alg,I,safetycopy, kwargs... )
108
108
iter = 1
109
- new_prob = prob. prob_func (deepcopy (prob. prob),i,iter)
109
+ _prob = safetycopy ? deepcopy (prob. prob) : prob. prob
110
+ new_prob = prob. prob_func (_prob,i,iter)
110
111
rerun = true
111
112
x = prob. output_func (solve (new_prob,alg;kwargs... ),i)
112
113
if ! (typeof (x) <: Tuple )
@@ -118,7 +119,8 @@ function batch_func(i,prob,alg,I,kwargs...)
118
119
rerun = _x[2 ]
119
120
while rerun
120
121
iter += 1
121
- new_prob = prob. prob_func (deepcopy (prob. prob),i,iter)
122
+ _prob = safetycopy ? deepcopy (prob. prob) : prob. prob
123
+ new_prob = prob. prob_func (_prob,i,iter)
122
124
x = prob. output_func (solve (new_prob,alg;kwargs... ),i)
123
125
if ! (typeof (x) <: Tuple )
124
126
@warn (" output_func should return (out,rerun). See docs for updated details" )
@@ -131,11 +133,11 @@ function batch_func(i,prob,alg,I,kwargs...)
131
133
_x[1 ]
132
134
end
133
135
134
- function solve_batch (prob,alg,:: EnsembleDistributed ,I,pmap_batch_size,kwargs... )
136
+ function solve_batch (prob,alg,ensemblealg :: EnsembleDistributed ,I,pmap_batch_size,kwargs... )
135
137
wp= CachingPool (workers ())
136
138
batch_data = let
137
139
pmap (wp,I,batch_size= pmap_batch_size) do i
138
- batch_func (i,prob,alg,I,kwargs... )
140
+ batch_func (i,prob,alg,I,ensemblealg . safetycopy, kwargs... )
139
141
end
140
142
end
141
143
map (i-> batch_data[i],1 : length (batch_data))
@@ -150,13 +152,14 @@ function solve_batch(prob,alg,::EnsembleSerial,I,pmap_batch_size,kwargs...)
150
152
map (i-> batch_data[i],1 : length (batch_data))
151
153
end
152
154
153
- function solve_batch (prob,alg,:: EnsembleThreads ,I,pmap_batch_size,kwargs... )
155
+ function solve_batch (prob,alg,ensemblealg :: EnsembleThreads ,I,pmap_batch_size,kwargs... )
154
156
batch_data = Vector {Any} (undef,length (I))
155
157
let
156
158
Threads. @threads for batch_idx in axes (batch_data, 1 )
157
159
i = I[batch_idx]
158
160
iter = 1
159
- new_prob = prob. prob_func (deepcopy (prob. prob),i,iter)
161
+ _prob = ensemblealg. safetycopy ? deepcopy (prob. prob) : prob. prob
162
+ new_prob = prob. prob_func (_prob,i,iter)
160
163
x = prob. output_func (solve (new_prob,alg;kwargs... ),i)
161
164
if ! (typeof (x) <: Tuple )
162
165
@warn (" output_func should return (out,rerun). See docs for updated details" )
@@ -168,7 +171,8 @@ function solve_batch(prob,alg,::EnsembleThreads,I,pmap_batch_size,kwargs...)
168
171
169
172
while rerun
170
173
iter += 1
171
- new_prob = prob. prob_func (deepcopy (prob. prob),i,iter)
174
+ _prob = ensemblealg. safetycopy ? deepcopy (prob. prob) : prob. prob
175
+ new_prob = prob. prob_func (_prob,i,iter)
172
176
x = prob. output_func (solve (new_prob,alg;kwargs... ),i)
173
177
if ! (typeof (x) <: Tuple )
174
178
@warn (" output_func should return (out,rerun). See docs for updated details" )
@@ -195,20 +199,21 @@ function solve_batch(prob,alg,::EnsembleSplitThreads,I,pmap_batch_size,kwargs...
195
199
else
196
200
I_local = I[(batch_size* (i- 1 )+ 1 ): (batch_size* i)]
197
201
end
198
- thread_monte (prob,I_local,alg,i,kwargs... )
202
+ thread_monte (prob,I_local,alg,i,ensemblealg . safetycopy, kwargs... )
199
203
end
200
204
end
201
205
_batch_data = vector_batch_data_to_arr (batch_data)
202
206
end
203
207
204
- function thread_monte (prob,I,alg,procid,kwargs... )
208
+ function thread_monte (prob,I,alg,procid,safetycopy, kwargs... )
205
209
batch_data = Vector {Any} (undef,length (I))
206
210
let
207
211
j = 0
208
212
Threads. @threads for i in I
209
213
j += 1
210
214
iter = 1
211
- new_prob = prob. prob_func (deepcopy (prob. prob),i,iter)
215
+ _prob = safetycopy ? deepcopy (prob. prob) : prob. prob
216
+ new_prob = prob. prob_func (_prob,i,iter)
212
217
rerun = true
213
218
x = prob. output_func (solve (new_prob,alg;kwargs... ),i)
214
219
if ! (typeof (x) <: Tuple )
@@ -220,7 +225,8 @@ function thread_monte(prob,I,alg,procid,kwargs...)
220
225
rerun = _x[2 ]
221
226
while rerun
222
227
iter += 1
223
- new_prob = prob. prob_func (deepcopy (prob. prob),i,iter)
228
+ _prob = safetycopy ? deepcopy (prob. prob) : prob. prob
229
+ new_prob = prob. prob_func (_prob,i,iter)
224
230
x = prob. output_func (solve (new_prob,alg;kwargs... ),i)
225
231
if ! (typeof (x) <: Tuple )
226
232
@warn (" output_func should return (out,rerun). See docs for updated details" )
0 commit comments