Skip to content

Commit 0d40f40

Browse files
torfjelde and devmotion authored
predict implemented for multiple chains (#1421)
* made predict work for multiple chains
* export predict from Turing
* added appropriate tests
* fixed tests
* fixed tests and implementation
* added rng
* Update src/inference/Inference.jl

Co-authored-by: David Widmann <[email protected]>

* version bump

Co-authored-by: David Widmann <[email protected]>
1 parent d77181e commit 0d40f40

File tree

4 files changed

+68
-14
lines changed

4 files changed

+68
-14
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Turing"
22
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
3-
version = "0.14.5"
3+
version = "0.14.6"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/Turing.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ end
6868
###########
6969
# `using` statements for stuff to re-export
7070
using DynamicPPL: elementwise_loglikelihoods, generated_quantities, logprior, logjoint
71+
using StatsBase: predict
7172

7273
# Turing essentials - modelling macros and inference algorithms
7374
export @model, # modelling
@@ -120,6 +121,7 @@ export @model, # modelling
120121
filldist,
121122
arraydist,
122123

124+
predict,
123125
elementwise_loglikelihoods,
124126
generated_quantities,
125127
logprior,

src/inference/Inference.jl

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -599,7 +599,7 @@ DynamicPPL.inspace(vn::VarName, spl::Sampler) = inspace(vn, getspace(spl.alg))
599599

600600
"""
601601
602-
predict(model::Model, chain::MCMCChains.Chains; include_all=false)
602+
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)
603603
604604
Execute `model` conditioned on each sample in `chain`, and return the resulting `Chains`.
605605
@@ -637,7 +637,7 @@ julia> chain_lin_reg = sample(m_train, NUTS(100, 0.65), 200);
637637
638638
julia> m_test = linear_reg(xs_test, Vector{Union{Missing, Float64}}(undef, length(ys_test)), σ);
639639
640-
julia> predictions = Turing.Inference.predict(m_test, chain_lin_reg)
640+
julia> predictions = predict(m_test, chain_lin_reg)
641641
Object of type Chains, with data of type 100×2×1 Array{Float64,3}
642642
643643
Iterations = 1:100
@@ -667,20 +667,30 @@ julia> sum(abs2, ys_test - ys_pred) ≤ 0.1
667667
true
668668
```
669669
"""
670-
function predict(model::Turing.Model, chain::MCMCChains.Chains; include_all = false)
670+
function predict(model::Model, chain::MCMCChains.Chains; kwargs...)
671+
return predict(Random.GLOBAL_RNG, model, chain; kwargs...)
672+
end
673+
function predict(rng::AbstractRNG, model::Model, chain::MCMCChains.Chains; include_all = false)
671674
spl = DynamicPPL.SampleFromPrior()
672675

673676
# Sample transitions using `spl` conditioned on values in `chain`
674-
transitions = transitions_from_chain(model, chain; sampler = spl)
677+
transitions = [
678+
transitions_from_chain(rng, model, chain[:, :, chn_idx]; sampler = spl)
679+
for chn_idx = 1:size(chain, 3)
680+
]
675681

676682
# Let the Turing internals handle everything else for you
677-
chain_result = AbstractMCMC.bundle_samples(
678-
Distributions.GLOBAL_RNG,
679-
model,
680-
spl,
681-
length(chain),
682-
transitions,
683-
MCMCChains.Chains
683+
chain_result = reduce(
684+
MCMCChains.chainscat, [
685+
AbstractMCMC.bundle_samples(
686+
rng,
687+
model,
688+
spl,
689+
length(chain),
690+
transitions[chn_idx],
691+
MCMCChains.Chains
692+
) for chn_idx = 1:size(chain, 3)
693+
]
684694
)
685695

686696
parameter_names = if include_all
@@ -695,6 +705,7 @@ end
695705
"""
696706
697707
transitions_from_chain(
708+
[rng::AbstractRNG,]
698709
model::Model,
699710
chain::MCMCChains.Chains;
700711
sampler = DynamicPPL.SampleFromPrior()
@@ -741,6 +752,14 @@ julia> [first(t.θ.x) for t in transitions] # extract samples for `x`
741752
```
742753
"""
743754
function transitions_from_chain(
755+
model::Turing.Model,
756+
chain::MCMCChains.Chains;
757+
kwargs...
758+
)
759+
return transitions_from_chain(Random.GLOBAL_RNG, model, chain; kwargs...)
760+
end
761+
function transitions_from_chain(
762+
rng::AbstractRNG,
744763
model::Turing.Model,
745764
chain::MCMCChains.Chains;
746765
sampler = DynamicPPL.SampleFromPrior()
@@ -774,7 +793,7 @@ function transitions_from_chain(
774793
end
775794
end
776795
# Execute `model` on the parameters set in `vi` and sample those with `"del"` flag using `sampler`
777-
model(vi, sampler)
796+
model(rng, vi, sampler)
778797

779798
# Convert `VarInfo` into `NamedTuple` and save
780799
theta = DynamicPPL.tonamedtuple(vi)

test/inference/utilities.jl

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,50 @@ using Random
2727
chain_lin_reg = sample(m_lin_reg, NUTS(100, 0.65), 200);
2828

2929
# Predict on two last indices
30-
m_lin_reg_test = linear_reg(xs_test, Vector{Union{Missing, Float64}}(undef, length(ys_test)));
30+
m_lin_reg_test = linear_reg(xs_test, fill(missing, length(ys_test)));
3131
predictions = Turing.Inference.predict(m_lin_reg_test, chain_lin_reg)
3232

3333
ys_pred = vec(mean(Array(group(predictions, :y)); dims = 1))
3434

3535
@test sum(abs2, ys_test - ys_pred) ≤ 0.1
3636

37+
# Ensure that `rng` is respected
38+
predictions1 = let rng = MersenneTwister(42)
39+
predict(rng, m_lin_reg_test, chain_lin_reg[1:2])
40+
end
41+
predictions2 = let rng = MersenneTwister(42)
42+
predict(rng, m_lin_reg_test, chain_lin_reg[1:2])
43+
end
44+
@test all(Array(predictions1) .== Array(predictions2))
45+
3746
# Predict on two last indices for vectorized
3847
m_lin_reg_test = linear_reg_vec(xs_test, missing);
3948
predictions_vec = Turing.Inference.predict(m_lin_reg_test, chain_lin_reg)
4049
ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims = 1))
4150

4251
@test sum(abs2, ys_test - ys_pred_vec) ≤ 0.1
52+
53+
# Multiple chains
54+
chain_lin_reg = sample(m_lin_reg, NUTS(100, 0.65), MCMCThreads(), 200, 2);
55+
m_lin_reg_test = linear_reg(xs_test, fill(missing, length(ys_test)));
56+
predictions = Turing.Inference.predict(m_lin_reg_test, chain_lin_reg)
57+
58+
@test size(chain_lin_reg, 3) == size(predictions, 3)
59+
60+
for chain_idx in MCMCChains.chains(chain_lin_reg)
61+
ys_pred = vec(mean(Array(group(predictions[:, :, chain_idx], :y)); dims = 1))
62+
@test sum(abs2, ys_test - ys_pred) ≤ 0.1
63+
end
64+
65+
# Predict on two last indices for vectorized
66+
m_lin_reg_test = linear_reg_vec(xs_test, missing);
67+
predictions_vec = Turing.Inference.predict(m_lin_reg_test, chain_lin_reg)
68+
69+
for chain_idx in MCMCChains.chains(chain_lin_reg)
70+
ys_pred_vec = vec(mean(
71+
Array(group(predictions_vec[:, :, chain_idx], :y));
72+
dims = 1
73+
))
74+
@test sum(abs2, ys_test - ys_pred_vec) ≤ 0.1
75+
end
4376
end

0 commit comments

Comments
 (0)