Skip to content

Commit b1c553c

Browse files
Merge pull request #511 from SciML/myb/inference
Inference fix
2 parents 03be321 + 4820c8a commit b1c553c

File tree

6 files changed

+180
-107
lines changed

6 files changed

+180
-107
lines changed

src/ensemble/basic_ensemble_solve.jl

Lines changed: 108 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ if (kwargs[:parallel_type] == != :none && kwargs[:parallel_type] == != :threads)
2929
end
3030
(batch_size != trajectories) && warn("batch_size and reductions are ignored when !collect_result")
3131
32-
elapsed_time = @elapsed u = DArray((trajectories,)) do I
33-
solve_batch(prob,alg,kwargs[:parallel_type] ==,I[1],pmap_batch_size,kwargs...)
32+
elapsed_time = @elapsed u = DArray((trajectories,)) do II
33+
solve_batch(prob,alg,kwargs[:parallel_type] ==,II[1],pmap_batch_size,kwargs...)
3434
end
3535
return EnsembleSolution(u,elapsed_time,false)
3636
=#
@@ -75,36 +75,55 @@ function __solve(prob::AbstractEnsembleProblem,
7575
__solve(prob,alg,ensemblealg;trajectories=trajectories,kwargs...)
7676
end
7777

78+
tighten_container_eltype(u::Vector{Any}) = map(identity, u)
79+
tighten_container_eltype(u) = u
80+
7881
function __solve(prob::AbstractEnsembleProblem,
7982
alg::Union{DEAlgorithm,Nothing},
8083
ensemblealg::BasicEnsembleAlgorithm;
8184
trajectories, batch_size = trajectories,
8285
pmap_batch_size = batch_size÷100 > 0 ? batch_size÷100 : 1, kwargs...)
8386

8487
num_batches = trajectories ÷ batch_size
88+
num_batches < 1 && error("trajectories ÷ batch_size cannot be less than 1, got $num_batches")
8589
num_batches * batch_size != trajectories && (num_batches += 1)
8690

87-
u = deepcopy(prob.u_init)
88-
converged = false
89-
elapsed_time = @elapsed for i in 1:num_batches
91+
function batch_function(II)
92+
batch_data = solve_batch(prob,alg,ensemblealg,II,pmap_batch_size;kwargs...)
93+
end
94+
95+
if num_batches == 1 && prob.reduction === DEFAULT_REDUCTION
96+
elapsed_time = @elapsed u = batch_function(1:trajectories)
97+
_u = tighten_container_eltype(u)
98+
return EnsembleSolution(_u,elapsed_time,true)
99+
end
100+
101+
converged::Bool = false
102+
i = 1
103+
II = (batch_size*(i-1)+1):batch_size*i
104+
105+
batch_data = batch_function(II)
106+
u = prob.u_init === nothing ? similar(batch_data, 0) : prob.u_init
107+
u,converged = prob.reduction(u,batch_data,II)
108+
elapsed_time = @elapsed for i in 2:num_batches
109+
converged && break
90110
if i == num_batches
91-
I = (batch_size*(i-1)+1):trajectories
111+
II = (batch_size*(i-1)+1):trajectories
92112
else
93-
I = (batch_size*(i-1)+1):batch_size*i
113+
II = (batch_size*(i-1)+1):batch_size*i
94114
end
95-
batch_data = solve_batch(prob,alg,ensemblealg,I,pmap_batch_size,kwargs...)
96-
u,converged = prob.reduction(u,batch_data,I)
97-
converged && break
98-
end
99-
if typeof(u) <: Vector{Any}
100-
_u = map(i->u[i],1:length(u))
101-
else
102-
_u = u
115+
batch_data = batch_function(II)
116+
u,converged = prob.reduction(u,batch_data,II)
103117
end
118+
119+
u = reduce(vcat, u)
120+
_u = tighten_container_eltype(u)
121+
104122
return EnsembleSolution(_u,elapsed_time,converged)
123+
105124
end
106125

107-
function batch_func(i,prob,alg,I,kwargs...)
126+
function batch_func(i,prob,alg;kwargs...)
108127
iter = 1
109128
_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
110129
new_prob = prob.prob_func(_prob,i,iter)
@@ -133,31 +152,38 @@ function batch_func(i,prob,alg,I,kwargs...)
133152
_x[1]
134153
end
135154

136-
function solve_batch(prob,alg,ensemblealg::EnsembleDistributed,I,pmap_batch_size,kwargs...)
155+
function solve_batch(prob,alg,ensemblealg::EnsembleDistributed,II,pmap_batch_size;kwargs...)
137156
wp=CachingPool(workers())
138-
batch_data = let
139-
pmap(wp,I,batch_size=pmap_batch_size) do i
140-
batch_func(i,prob,alg,I,kwargs...)
141-
end
157+
batch_data = pmap(wp,II,batch_size=pmap_batch_size) do i
158+
batch_func(i,prob,alg;kwargs...)
142159
end
143-
map(i->batch_data[i],1:length(batch_data))
160+
map(identity,batch_data)
144161
end
145162

146-
function solve_batch(prob,alg,::EnsembleSerial,I,pmap_batch_size,kwargs...)
147-
batch_data = let
148-
map(I) do i
149-
batch_func(i,prob,alg,I,kwargs...)
150-
end
163+
function solve_batch(prob,alg,::EnsembleSerial,II,pmap_batch_size;kwargs...)
164+
batch_data = map(II) do i
165+
batch_func(i,prob,alg;kwargs...)
151166
end
152-
map(i->batch_data[i],1:length(batch_data))
167+
batch_data
153168
end
154169

155-
function solve_batch(prob,alg,ensemblealg::EnsembleThreads,I,pmap_batch_size,kwargs...)
156-
batch_data = Vector{Any}(undef,length(I))
157-
let
158-
Threads.@threads for batch_idx in axes(batch_data, 1)
159-
i = I[batch_idx]
160-
iter = 1
170+
function solve_batch(prob,alg,ensemblealg::EnsembleThreads,II,pmap_batch_size;kwargs...)
171+
function multithreaded_batch(batch_idx)
172+
i = II[batch_idx]
173+
iter = 1
174+
_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
175+
new_prob = prob.prob_func(_prob,i,iter)
176+
x = prob.output_func(solve(new_prob,alg;kwargs...),i)
177+
if !(typeof(x) <: Tuple)
178+
@warn("output_func should return (out,rerun). See docs for updated details")
179+
_x = (x,false)
180+
else
181+
_x = x
182+
end
183+
rerun = _x[2]
184+
185+
while rerun
186+
iter += 1
161187
_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
162188
new_prob = prob.prob_func(_prob,i,iter)
163189
x = prob.output_func(solve(new_prob,alg;kwargs...),i)
@@ -168,87 +194,73 @@ function solve_batch(prob,alg,ensemblealg::EnsembleThreads,I,pmap_batch_size,kwa
168194
_x = x
169195
end
170196
rerun = _x[2]
197+
end
198+
_x[1]
199+
end
171200

172-
while rerun
173-
iter += 1
174-
_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
175-
new_prob = prob.prob_func(_prob,i,iter)
176-
x = prob.output_func(solve(new_prob,alg;kwargs...),i)
177-
if !(typeof(x) <: Tuple)
178-
@warn("output_func should return (out,rerun). See docs for updated details")
179-
_x = (x,false)
180-
else
181-
_x = x
182-
end
183-
rerun = _x[2]
184-
end
185-
batch_data[batch_idx] = _x[1]
201+
batch_data = Vector{Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(II))})}(undef,length(II))
202+
let
203+
Threads.@threads for batch_idx in axes(batch_data, 1)
204+
batch_data[batch_idx] = multithreaded_batch(batch_idx)
186205
end
187206
end
188-
map(i->batch_data[i],1:length(batch_data))
207+
batch_data
189208
end
190209

191-
function solve_batch(prob,alg,::EnsembleSplitThreads,I,pmap_batch_size,kwargs...)
210+
function solve_batch(prob,alg,::EnsembleSplitThreads,II,pmap_batch_size;kwargs...)
192211
wp=CachingPool(workers())
193212
N = nworkers()
194-
batch_size = length(I)÷N
213+
batch_size = length(II)÷N
195214
batch_data = let
196215
pmap(wp,1:N,batch_size=pmap_batch_size) do i
197216
if i == N
198-
I_local = I[(batch_size*(i-1)+1):end]
217+
I_local = II[(batch_size*(i-1)+1):end]
199218
else
200-
I_local = I[(batch_size*(i-1)+1):(batch_size*i)]
219+
I_local = II[(batch_size*(i-1)+1):(batch_size*i)]
201220
end
202-
thread_monte(prob,I_local,alg,i,kwargs...)
221+
thread_monte(prob,I_local,alg,i;kwargs...)
203222
end
204223
end
205-
_batch_data = vector_batch_data_to_arr(batch_data)
224+
reduce(vcat,batch_data)
206225
end
207226

208-
function thread_monte(prob,I,alg,procid,kwargs...)
209-
batch_data = Vector{Any}(undef,length(I))
210-
let
211-
Threads.@threads for j in 1:length(I)
212-
i = I[j]
213-
iter = 1
214-
_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
215-
new_prob = prob.prob_func(_prob,i,iter)
216-
rerun = true
217-
x = prob.output_func(solve(new_prob,alg;kwargs...),i)
218-
if !(typeof(x) <: Tuple)
219-
@warn("output_func should return (out,rerun). See docs for updated details")
220-
_x = (x,false)
221-
else
222-
_x = x
223-
end
224-
rerun = _x[2]
225-
while rerun
226-
iter += 1
227-
_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
228-
new_prob = prob.prob_func(_prob,i,iter)
229-
x = prob.output_func(solve(new_prob,alg;kwargs...),i)
230-
if !(typeof(x) <: Tuple)
231-
@warn("output_func should return (out,rerun). See docs for updated details")
232-
_x = (x,false)
233-
else
234-
_x = x
235-
end
236-
rerun = _x[2]
237-
end
238-
batch_data[j] = _x[1]
227+
function thread_monte(prob,II,alg,procid;kwargs...)
228+
function multithreaded_batch(j)
229+
i = II[j]
230+
iter = 1
231+
_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
232+
new_prob = prob.prob_func(_prob,i,iter)
233+
rerun = true
234+
x = prob.output_func(solve(new_prob,alg;kwargs...),i)
235+
if !(typeof(x) <: Tuple)
236+
@warn("output_func should return (out,rerun). See docs for updated details")
237+
_x = (x,false)
238+
else
239+
_x = x
239240
end
241+
rerun = _x[2]
242+
while rerun
243+
iter += 1
244+
_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
245+
new_prob = prob.prob_func(_prob,i,iter)
246+
x = prob.output_func(solve(new_prob,alg;kwargs...),i)
247+
if !(typeof(x) <: Tuple)
248+
@warn("output_func should return (out,rerun). See docs for updated details")
249+
_x = (x,false)
250+
else
251+
_x = x
252+
end
253+
rerun = _x[2]
254+
end
255+
_x[1]
240256
end
241-
batch_data
242-
end
243257

244-
function vector_batch_data_to_arr(batch_data)
245-
_batch_data = Vector{Any}(undef,sum((length(x) for x in batch_data)))
246-
idx = 0
247-
@inbounds for a in batch_data
248-
for x in a
249-
idx += 1
250-
_batch_data[idx] = x
258+
batch_data = Vector{Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(II))})}(undef,length(II))
259+
260+
let
261+
Threads.@threads for j in 1:length(II)
262+
batch_data[j] = multithreaded_batch(j)
251263
end
252264
end
253-
map(i->_batch_data[i],1:length(_batch_data))
265+
batch_data
254266
end

src/ensemble/ensemble_problems.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,18 @@ end
1212

1313
DEFAULT_PROB_FUNC(prob,i,repeat) = prob
1414
DEFAULT_OUTPUT_FUNC(sol,i) = (sol,false)
15-
DEFAULT_REDUCTION(u,data,I) = (append!(u,data),false)
15+
DEFAULT_REDUCTION(u,data,I) = push!(u, data), false
1616
EnsembleProblem(prob;
1717
output_func = DEFAULT_OUTPUT_FUNC,
1818
prob_func= DEFAULT_PROB_FUNC,
1919
reduction = DEFAULT_REDUCTION,
20-
u_init = [],
20+
u_init = nothing,
2121
safetycopy = prob_func !== DEFAULT_PROB_FUNC) =
2222
EnsembleProblem(prob,prob_func,output_func,reduction,u_init,safetycopy)
2323

2424
EnsembleProblem(;prob,
2525
output_func = DEFAULT_OUTPUT_FUNC,
2626
prob_func= DEFAULT_PROB_FUNC,
2727
reduction = DEFAULT_REDUCTION,
28-
u_init = [], p = nothing, safetycopy = prob_func !== DEFAULT_PROB_FUNC) =
28+
u_init = nothing, p = nothing, safetycopy = prob_func !== DEFAULT_PROB_FUNC) =
2929
EnsembleProblem(prob,prob_func,output_func,reduction,u_init,safetycopy)

src/solve.jl

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@ NO_TSPAN_PROBS = Union{AbstractLinearProblem, AbstractNonlinearProblem,
66
AbstractQuadratureProblem,
77
AbstractSteadyStateProblem,AbstractJumpProblem}
88

9+
has_kwargs(_prob::DEProblem) = has_kwargs(typeof(_prob))
10+
Base.@pure has_kwargs(::Type{T}) where T = :kwargs fieldnames(T)
11+
912
function init_call(_prob,args...;kwargs...)
10-
if :kwargs propertynames(_prob)
13+
if has_kwargs(_prob)
1114
__init(_prob,args...;_prob.kwargs...,kwargs...)
1215
else
1316
__init(_prob,args...;kwargs...)
@@ -34,7 +37,7 @@ function init(prob::DEProblem,args...;kwargs...)
3437
end
3538

3639
function solve_call(_prob,args...;merge_callbacks = true, kwargs...)
37-
if :kwargs propertynames(_prob)
40+
if has_kwargs(_prob)
3841
if merge_callbacks && haskey(_prob.kwargs,:callback) && haskey(kwargs, :callback)
3942
kwargs_temp = NamedTuple{Base.diff_names(Base._nt_names(
4043
values(kwargs)), (:callback,))}(values(kwargs))
@@ -44,10 +47,19 @@ function solve_call(_prob,args...;merge_callbacks = true, kwargs...)
4447
kwargs = merge(values(_prob.kwargs), kwargs)
4548
end
4649

47-
logger = get(kwargs, :progress, false) ? default_logger(Logging.current_logger()) : nothing
48-
maybe_with_logger(logger) do
49-
__solve(_prob,args...; kwargs...)
50+
T = Core.Compiler.return_type(__solve,Tuple{typeof(_prob),map(typeof, args)...})
51+
52+
progress = get(kwargs, :progress, false)
53+
if progress
54+
logger = default_logger(Logging.current_logger())
55+
x = maybe_with_logger(logger) do
56+
__solve(_prob,args...; kwargs...)
57+
end
58+
return x::T
59+
else
60+
__solve(_prob,args...; kwargs...)::T
5061
end
62+
5163
end
5264

5365
function solve(prob::DEProblem,args...;kwargs...)

test/downstream/ensemble.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,10 @@ sim = solve(prob2,Tsit5(),trajectories=10000,batch_size=20)
8080
@test sim.converged == true
8181

8282

83-
Random.seed!(100)
83+
prob_func = function (prob,i,repeat)
84+
ODEProblem(prob.f,(1 + i/100)*prob.u0,prob.tspan,1.01)
85+
end
86+
8487
reduction = function (u,batch,I)
8588
u = append!(u,batch)
8689
u,false
@@ -90,7 +93,6 @@ prob2 = EnsembleProblem(prob,prob_func=prob_func,output_func=output_func,reducti
9093
sim = solve(prob2,Tsit5(),trajectories=100,batch_size=20)
9194
@test sim.converged == false
9295

93-
Random.seed!(100)
9496
reduction = function (u,batch,I)
9597
u+sum(batch),false
9698
end
@@ -105,4 +107,4 @@ output_func = function (sol,i)
105107
end
106108
prob2 = EnsembleProblem(prob,prob_func=prob_func,output_func=output_func)
107109
sim2 = solve(prob2,Tsit5(),trajectories=2)
108-
@test !sim2.converged && typeof(sim2.u) == Vector{SomeUserType}
110+
@test sim2.converged && typeof(sim2.u) == Vector{SomeUserType}

0 commit comments

Comments
 (0)