@@ -206,24 +206,24 @@ function bayes_unpack_data(prob, p::AbstractVector{<:Pair})
206206 (pdist, IndexKeyMap(prob, pkeys))
207207end
208208
209- Turing. @model function bayesianODE(prob, t, pdist, pkeys, data, noise_prior)
209+ Turing. @model function bayesianODE(prob, alg, t, pdist, pkeys, data, datamap , noise_prior)
210210 σ ~ noise_prior
211211
212212 pprior ~ product_distribution(pdist)
213213
214214 prob = _remake(prob, (prob. tspan[1 ], t[end ]), pkeys, pprior)
215- sol = solve(prob, saveat = t)
215+ sol = solve(prob, alg, saveat = t)
216216 if ! SciMLBase. successful_retcode(sol)
217217 Turing. DynamicPPL. acclogp!!(__varinfo__, - Inf )
218218 return nothing
219219 end
220220 for i in eachindex(data)
221- data[i]. second ~ MvNormal(sol[data[i] . first] , σ^ 2 * I)
221+ data[i] ~ MvNormal(datamap( sol) , σ^ 2 * I)
222222 end
223223 return nothing
224224end
225225
226- Turing. @model function bayesianODE(prob,
226+ Turing. @model function bayesianODE(prob, alg,
227227 pdist,
228228 pkeys,
229229 ts,
@@ -236,7 +236,7 @@ Turing.@model function bayesianODE(prob,
236236 pprior ~ product_distribution(pdist)
237237
238238 prob = _remake(prob, (prob. tspan[1 ], lastt), pkeys, pprior)
239- sol = solve(prob)
239+ sol = solve(prob, alg )
240240 if ! SciMLBase. successful_retcode(sol)
241241 Turing. DynamicPPL. acclogp!!(__varinfo__, - Inf )
242242 return nothing
@@ -264,18 +264,19 @@ end
264264Base. length(ws:: WeightedSol ) = length(first(ws. sols))
265265Base. size(ws:: WeightedSol ) = (length(first(ws. sols)),)
266266function Base. getindex(ws:: WeightedSol{T} , i:: Int ) where {T}
267- s = zero(T)
268- w = zero(T)
269- for j in eachindex(ws. weights)
267+ s:: T = zero(T)
268+ w:: T = zero(T)
269+ @inbounds for j in eachindex(ws. weights)
270270 w += ws. weights[j]
271271 s += ws. weights[j] * ws. sols[j][i]
272272 end
273273 return s + (one(T) - w) * ws. sols[end ][i]
274274end
275- function WeightedSol(sols, select, weights)
276- T = eltype(weights)
277- s = map(Base. Fix2(getindex, select), sols)
278- WeightedSol{T}(s, weights)
275+ function WeightedSol(sols, select, i:: Int , weights)
276+ s = map(sols, select) do sol, sel
277+ @view(sol[sel. indices[i], :])
278+ end
279+ WeightedSol{eltype(weights)}(s, weights)
279280end
280281function bayes_unpack_data(probs, p:: Tuple{Vararg{<:AbstractVector{<:Pair}}} , data)
281282 pdist, pkeys = bayes_unpack_data(probs, p)
@@ -305,43 +306,46 @@ function flatten(x::Tuple)
305306 reduce(vcat, x), Grouper(map(length, x))
306307end
307308
308- function getsols(probs, probspkeys, ppriors, t:: AbstractArray )
309- map(probs, probspkeys, ppriors) do prob, pkeys, pprior
309+ function getsols(probs, algs, probspkeys, ppriors, t:: AbstractArray )
310+ map(probs, algs, probspkeys, ppriors) do prob, alg , pkeys, pprior
310311 newprob = _remake(prob, (prob. tspan[1 ], t[end ]), pkeys, pprior)
311- solve(newprob, saveat = t)
312+ solve(newprob, alg, saveat = t)
312313 end
313314end
314- function getsols(probs, probspkeys, ppriors, lastt:: Number )
315- map(probs, probspkeys, ppriors) do prob, pkeys, pprior
315+ function getsols(probs, algs, probspkeys, ppriors, lastt:: Number )
316+ map(probs, algs, probspkeys, ppriors) do prob, alg , pkeys, pprior
316317 newprob = _remake(prob, (prob. tspan[1 ], lastt), pkeys, pprior)
317- solve(newprob)
318+ solve(newprob, alg )
318319 end
319320end
320321
321322Turing. @model function ensemblebayesianODE(probs:: Union{Tuple, AbstractVector} ,
323+ algs,
322324 t,
323325 pdist,
324326 grouppriorsfunc,
325327 probspkeys,
326328 data,
329+ datamaps,
327330 noise_prior)
328331 σ ~ noise_prior
329332 ppriors ~ product_distribution(pdist)
330333
331334 Nprobs = length(probs)
332335 Nprobs⁻¹ = inv(Nprobs)
333336 weights ~ MvNormal(Distributions. Fill(Nprobs⁻¹, Nprobs - 1 ), Nprobs⁻¹)
334- sols = getsols(probs, probspkeys, grouppriorsfunc(ppriors), t)
337+ sols = getsols(probs, algs, probspkeys, grouppriorsfunc(ppriors), t)
335338 if ! all(SciMLBase. successful_retcode, sols)
336339 Turing. DynamicPPL. acclogp!!(__varinfo__, - Inf )
337340 return nothing
338341 end
339342 for i in eachindex(data)
340- data[i]. second ~ MvNormal(WeightedSol(sols, data[i] . first , weights), σ^ 2 * I)
343+ data[i] ~ MvNormal(WeightedSol(sols, datamaps, i , weights), σ^ 2 * I)
341344 end
342345 return nothing
343346end
344347Turing. @model function ensemblebayesianODE(probs:: Union{Tuple, AbstractVector} ,
348+ algs,
345349 pdist,
346350 grouppriorsfunc,
347351 probspkeys,
@@ -353,7 +357,7 @@ Turing.@model function ensemblebayesianODE(probs::Union{Tuple, AbstractVector},
353357 σ ~ noise_prior
354358 ppriors ~ product_distribution(pdist)
355359
356- sols = getsols(probs, probspkeys, grouppriorsfunc(ppriors), lastt)
360+ sols = getsols(probs, algs, probspkeys, grouppriorsfunc(ppriors), lastt)
357361
358362 Nprobs = length(probs)
359363 Nprobs⁻¹ = inv(Nprobs)
@@ -411,7 +415,14 @@ function bayesian_datafit(prob,
411415 nchains = 4 ,
412416 niter = 1000 )
413417 (pdist, pkeys) = bayes_unpack_data(prob, p)
414- model = bayesianODE(prob, t, pdist, pkeys, data, noise_prior)
418+ model = bayesianODE(prob,
419+ first(default_algorithm(prob)),
420+ t,
421+ pdist,
422+ pkeys,
423+ last.(data),
424+ IndexKeyMap(prob, data),
425+ noise_prior)
415426 chain = Turing. sample(model,
416427 Turing. NUTS(0.65 ),
417428 mcmcensemble,
@@ -430,7 +441,15 @@ function bayesian_datafit(prob,
430441 nchains = 4 ,
431442 niter = 1_000 )
432443 pdist, pkeys, ts, lastt, timeseries, datakeys = bayes_unpack_data(prob, p, data)
433- model = bayesianODE(prob, pdist, pkeys, ts, lastt, timeseries, datakeys, noise_prior)
444+ model = bayesianODE(prob,
445+ first(default_algorithm(prob)),
446+ pdist,
447+ pkeys,
448+ ts,
449+ lastt,
450+ timeseries,
451+ datakeys,
452+ noise_prior)
434453 chain = Turing. sample(model,
435454 Turing. NUTS(0.65 ),
436455 mcmcensemble,
@@ -451,7 +470,10 @@ function bayesian_datafit(probs::Union{Tuple, AbstractVector},
451470 (pdist_, pkeys) = bayes_unpack_data(p)
452471 pdist, grouppriorsfunc = flatten(pdist_)
453472
454- model = ensemblebayesianODE(probs, t, pdist, grouppriorsfunc, pkeys, data, noise_prior)
473+ model = ensemblebayesianODE(probs,
474+ map(first ∘ default_algorithm, probs),
475+ t, pdist, grouppriorsfunc, pkeys, last.(data),
476+ map(Base. Fix2(IndexKeyMap, data), probs), noise_prior)
455477 chain = Turing. sample(model,
456478 Turing. NUTS(0.65 ),
457479 mcmcensemble,
@@ -472,6 +494,7 @@ function bayesian_datafit(probs::Union{Tuple, AbstractVector},
472494 pdist_, pkeys, ts, lastt, timeseries, datakeys = bayes_unpack_data(p, data)
473495 pdist, grouppriorsfunc = flatten(pdist_)
474496 model = ensemblebayesianODE(probs,
497+ map(first ∘ default_algorithm, probs),
475498 pdist,
476499 grouppriorsfunc,
477500 pkeys,
0 commit comments