Skip to content

Commit a33dbed

Browse files
add safetycopy argument
1 parent c1f4896 commit a33dbed

File tree

3 files changed

+36
-23
lines changed

3 files changed

+36
-23
lines changed

src/ensemble/basic_ensemble_solve.jl

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,10 @@ function __solve(prob::AbstractEnsembleProblem,
104104
return EnsembleSolution(_u,elapsed_time,converged)
105105
end
106106

107-
function batch_func(i,prob,alg,I,kwargs...)
107+
function batch_func(i,prob,alg,I,safetycopy,kwargs...)
108108
iter = 1
109-
new_prob = prob.prob_func(deepcopy(prob.prob),i,iter)
109+
_prob = safetycopy ? deepcopy(prob.prob) : prob.prob
110+
new_prob = prob.prob_func(_prob,i,iter)
110111
rerun = true
111112
x = prob.output_func(solve(new_prob,alg;kwargs...),i)
112113
if !(typeof(x) <: Tuple)
@@ -118,7 +119,8 @@ function batch_func(i,prob,alg,I,kwargs...)
118119
rerun = _x[2]
119120
while rerun
120121
iter += 1
121-
new_prob = prob.prob_func(deepcopy(prob.prob),i,iter)
122+
_prob = safetycopy ? deepcopy(prob.prob) : prob.prob
123+
new_prob = prob.prob_func(_prob,i,iter)
122124
x = prob.output_func(solve(new_prob,alg;kwargs...),i)
123125
if !(typeof(x) <: Tuple)
124126
@warn("output_func should return (out,rerun). See docs for updated details")
@@ -131,11 +133,11 @@ function batch_func(i,prob,alg,I,kwargs...)
131133
_x[1]
132134
end
133135

134-
function solve_batch(prob,alg,::EnsembleDistributed,I,pmap_batch_size,kwargs...)
136+
function solve_batch(prob,alg,ensemblealg::EnsembleDistributed,I,pmap_batch_size,kwargs...)
135137
wp=CachingPool(workers())
136138
batch_data = let
137139
pmap(wp,I,batch_size=pmap_batch_size) do i
138-
batch_func(i,prob,alg,I,kwargs...)
140+
batch_func(i,prob,alg,I,ensemblealg.safetycopy,kwargs...)
139141
end
140142
end
141143
map(i->batch_data[i],1:length(batch_data))
@@ -150,13 +152,14 @@ function solve_batch(prob,alg,::EnsembleSerial,I,pmap_batch_size,kwargs...)
150152
map(i->batch_data[i],1:length(batch_data))
151153
end
152154

153-
function solve_batch(prob,alg,::EnsembleThreads,I,pmap_batch_size,kwargs...)
155+
function solve_batch(prob,alg,ensemblealg::EnsembleThreads,I,pmap_batch_size,kwargs...)
154156
batch_data = Vector{Any}(undef,length(I))
155157
let
156158
Threads.@threads for batch_idx in axes(batch_data, 1)
157159
i = I[batch_idx]
158160
iter = 1
159-
new_prob = prob.prob_func(deepcopy(prob.prob),i,iter)
161+
_prob = ensemblealg.safetycopy ? deepcopy(prob.prob) : prob.prob
162+
new_prob = prob.prob_func(_prob,i,iter)
160163
x = prob.output_func(solve(new_prob,alg;kwargs...),i)
161164
if !(typeof(x) <: Tuple)
162165
@warn("output_func should return (out,rerun). See docs for updated details")
@@ -168,7 +171,8 @@ function solve_batch(prob,alg,::EnsembleThreads,I,pmap_batch_size,kwargs...)
168171

169172
while rerun
170173
iter += 1
171-
new_prob = prob.prob_func(deepcopy(prob.prob),i,iter)
174+
_prob = ensemblealg.safetycopy ? deepcopy(prob.prob) : prob.prob
175+
new_prob = prob.prob_func(_prob,i,iter)
172176
x = prob.output_func(solve(new_prob,alg;kwargs...),i)
173177
if !(typeof(x) <: Tuple)
174178
@warn("output_func should return (out,rerun). See docs for updated details")
@@ -195,20 +199,21 @@ function solve_batch(prob,alg,::EnsembleSplitThreads,I,pmap_batch_size,kwargs...
195199
else
196200
I_local = I[(batch_size*(i-1)+1):(batch_size*i)]
197201
end
198-
thread_monte(prob,I_local,alg,i,kwargs...)
202+
thread_monte(prob,I_local,alg,i,ensemblealg.safetycopy,kwargs...)
199203
end
200204
end
201205
_batch_data = vector_batch_data_to_arr(batch_data)
202206
end
203207

204-
function thread_monte(prob,I,alg,procid,kwargs...)
208+
function thread_monte(prob,I,alg,procid,safetycopy,kwargs...)
205209
batch_data = Vector{Any}(undef,length(I))
206210
let
207211
j = 0
208212
Threads.@threads for i in I
209213
j += 1
210214
iter = 1
211-
new_prob = prob.prob_func(deepcopy(prob.prob),i,iter)
215+
_prob = safetycopy ? deepcopy(prob.prob) : prob.prob
216+
new_prob = prob.prob_func(_prob,i,iter)
212217
rerun = true
213218
x = prob.output_func(solve(new_prob,alg;kwargs...),i)
214219
if !(typeof(x) <: Tuple)
@@ -220,7 +225,8 @@ function thread_monte(prob,I,alg,procid,kwargs...)
220225
rerun = _x[2]
221226
while rerun
222227
iter += 1
223-
new_prob = prob.prob_func(deepcopy(prob.prob),i,iter)
228+
_prob = safetycopy ? deepcopy(prob.prob) : prob.prob
229+
new_prob = prob.prob_func(_prob,i,iter)
224230
x = prob.output_func(solve(new_prob,alg;kwargs...),i)
225231
if !(typeof(x) <: Tuple)
226232
@warn("output_func should return (out,rerun). See docs for updated details")

src/ensemble/ensemble_problems.jl

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,23 @@ struct EnsembleProblem{T,T2,T3,T4,T5} <: AbstractEnsembleProblem
77
output_func::T3
88
reduction::T4
99
u_init::T5
10+
safetycopy::Bool
1011
end
1112

13+
DEFAULT_PROB_FUNC(prob,i,repeat) = prob
14+
DEFAULT_OUTPUT_FUNC(sol,i) = (sol,false)
15+
DEFAULT_REDUCTION(u,data,I) = (append!(u,data),false)
1216
EnsembleProblem(prob;
13-
output_func = (sol,i)-> (sol,false),
14-
prob_func= (prob,i,repeat)->prob,
15-
reduction = (u,data,I)->(append!(u,data),false),
16-
u_init = []) =
17-
EnsembleProblem(prob,prob_func,output_func,reduction,u_init)
17+
output_func = DEFAULT_OUTPUT_FUNC,
18+
prob_func= DEFAULT_PROB_FUNC,
19+
reduction = DEFAULT_REDUCTION,
20+
u_init = [],
21+
safetycopy = prob_func !== DEFAULT_PROB_FUNC) =
22+
EnsembleProblem(prob,prob_func,output_func,reduction,u_init,safetycopy)
1823

1924
EnsembleProblem(;prob,
20-
output_func = (sol,i)-> (sol,false),
21-
prob_func= (prob,i,repeat)->prob,
22-
reduction = (u,data,I)->(append!(u,data),false),
23-
u_init = [], p = nothing) =
24-
EnsembleProblem(prob,prob_func,output_func,reduction,u_init)
25+
output_func = DEFAULT_OUTPUT_FUNC,
26+
prob_func= DEFAULT_PROB_FUNC,
27+
reduction = DEFAULT_REDUCTION,
28+
u_init = [], p = nothing, safetycopy = prob_func !== DEFAULT_PROB_FUNC) =
29+
EnsembleProblem(prob,prob_func,output_func,reduction,u_init,safetycopy)

test/downstream/ensemble.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,9 @@ reduction = function (u,batch,I)
7373
u,((var(u)/sqrt(last(I)))/mean(u)<0.5) ? true : false
7474
end
7575

76-
prob2 = EnsembleProblem(prob,prob_func=prob_func,output_func=output_func,reduction=reduction,u_init=Vector{Float64}())
76+
prob2 = EnsembleProblem(prob,prob_func=prob_func,output_func=output_func,
77+
reduction=reduction,u_init=Vector{Float64}(),
78+
safetycopy=false)
7779
sim = solve(prob2,Tsit5(),trajectories=10000,batch_size=20)
7880
@test sim.converged == true
7981

0 commit comments

Comments
 (0)