@@ -599,7 +599,7 @@ DynamicPPL.inspace(vn::VarName, spl::Sampler) = inspace(vn, getspace(spl.alg))
599
599
600
600
"""
601
601
602
- predict(model::Model, chain::MCMCChains.Chains; include_all=false)
602
+ predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)
603
603
604
604
Execute `model` conditioned on each sample in `chain`, and return the resulting `Chains`.
605
605
@@ -637,7 +637,7 @@ julia> chain_lin_reg = sample(m_train, NUTS(100, 0.65), 200);
637
637
638
638
julia> m_test = linear_reg(xs_test, Vector{Union{Missing, Float64}}(undef, length(ys_test)), σ);
639
639
640
- julia> predictions = Turing.Inference. predict(m_test, chain_lin_reg)
640
+ julia> predictions = predict(m_test, chain_lin_reg)
641
641
Object of type Chains, with data of type 100×2×1 Array{Float64,3}
642
642
643
643
Iterations = 1:100
@@ -667,20 +667,30 @@ julia> sum(abs2, ys_test - ys_pred) ≤ 0.1
667
667
true
668
668
```
669
669
"""
670
- function predict (model:: Turing.Model , chain:: MCMCChains.Chains ; include_all = false )
670
+ function predict (model:: Model , chain:: MCMCChains.Chains ; kwargs... )
671
+ return predict (Random. GLOBAL_RNG, model, chain; kwargs... )
672
+ end
673
+ function predict (rng:: AbstractRNG , model:: Model , chain:: MCMCChains.Chains ; include_all = false )
671
674
spl = DynamicPPL. SampleFromPrior ()
672
675
673
676
# Sample transitions using `spl` conditioned on values in `chain`
674
- transitions = transitions_from_chain (model, chain; sampler = spl)
677
+ transitions = [
678
+ transitions_from_chain (rng, model, chain[:, :, chn_idx]; sampler = spl)
679
+ for chn_idx = 1 : size (chain, 3 )
680
+ ]
675
681
676
682
# Let the Turing internals handle everything else for you
677
- chain_result = AbstractMCMC. bundle_samples (
678
- Distributions. GLOBAL_RNG,
679
- model,
680
- spl,
681
- length (chain),
682
- transitions,
683
- MCMCChains. Chains
683
+ chain_result = reduce (
684
+ MCMCChains. chainscat, [
685
+ AbstractMCMC. bundle_samples (
686
+ rng,
687
+ model,
688
+ spl,
689
+ length (chain),
690
+ transitions[chn_idx],
691
+ MCMCChains. Chains
692
+ ) for chn_idx = 1 : size (chain, 3 )
693
+ ]
684
694
)
685
695
686
696
parameter_names = if include_all
695
705
"""
696
706
697
707
transitions_from_chain(
708
+ [rng::AbstractRNG,]
698
709
model::Model,
699
710
chain::MCMCChains.Chains;
700
711
sampler = DynamicPPL.SampleFromPrior()
@@ -741,6 +752,14 @@ julia> [first(t.θ.x) for t in transitions] # extract samples for `x`
741
752
```
742
753
"""
743
754
function transitions_from_chain (
755
+ model:: Turing.Model ,
756
+ chain:: MCMCChains.Chains ;
757
+ kwargs...
758
+ )
759
+ return transitions_from_chain (Random. GLOBAL_RNG, model, chain; kwargs... )
760
+ end
761
+ function transitions_from_chain (
762
+ rng:: AbstractRNG ,
744
763
model:: Turing.Model ,
745
764
chain:: MCMCChains.Chains ;
746
765
sampler = DynamicPPL. SampleFromPrior ()
@@ -774,7 +793,7 @@ function transitions_from_chain(
774
793
end
775
794
end
776
795
# Execute `model` on the parameters set in `vi` and sample those with `"del"` flag using `sampler`
777
- model (vi, sampler)
796
+ model (rng, vi, sampler)
778
797
779
798
# Convert `VarInfo` into `NamedTuple` and save
780
799
theta = DynamicPPL. tonamedtuple (vi)
0 commit comments