@@ -30,6 +30,7 @@ import DynamicPPL: get_matching_type,
30
30
import EllipticalSliceSampling
31
31
import Random
32
32
import MCMCChains
33
+ import StatsBase: predict
33
34
34
35
export InferenceAlgorithm,
35
36
Hamiltonian,
@@ -40,6 +41,7 @@ export InferenceAlgorithm,
40
41
SampleFromPrior,
41
42
MH,
42
43
ESS,
44
+ Emcee,
43
45
Gibbs, # classic sampling
44
46
HMC,
45
47
SGLD,
@@ -56,7 +58,9 @@ export InferenceAlgorithm,
56
58
dot_assume,
57
59
observe,
58
60
dot_observe,
59
- resume
61
+ resume,
62
+ predict,
63
+ isgibbscomponent
60
64
61
65
# ######################
62
66
# Sampler abstraction #
@@ -135,7 +139,7 @@ const TURING_INTERNAL_VARS = (internals = [
135
139
" step_size" ,
136
140
" nom_step_size" ,
137
141
" tree_depth" ,
138
- " is_adapt" ,
142
+ " is_adapt"
139
143
],)
140
144
141
145
# ########################################
@@ -305,19 +309,21 @@ Return a named tuple of parameters.
305
309
# Extract the parameters of a transition as a named tuple.
# Generic fallback: assumes the transition stores its parameters in a `θ` field.
getparams(t) = t.θ
# For a raw `VarInfo`, convert to a `TypedVarInfo` first so the stored values
# can be represented as a named tuple.
getparams(t::VarInfo) = tonamedtuple(TypedVarInfo(t))
307
311
308
"""
    _params_to_array(ts::Vector)

Collect parameter names and values from the transitions `ts`.

Return a tuple `(names, vals)` where `names` is a `Vector{Symbol}` of all
parameter names in first-seen order, and `vals` is a `length(ts) × length(names)`
matrix whose entries are `missing` wherever a transition does not contain the
corresponding parameter.
"""
function _params_to_array(ts::Vector)
    names = Symbol[]
    # `seen` gives O(1) membership tests; `names` preserves first-seen order.
    # (The original scanned the `Vector` linearly for every name.)
    seen = Set{Symbol}()

    # Extract the parameter names and values from each transition.
    dicts = map(ts) do t
        nms, vs = flatten_namedtuple(getparams(t))
        for nm in nms
            if !(nm in seen)
                push!(seen, nm)
                push!(names, nm)
            end
        end
        # Convert the names and values to a single dictionary.
        return Dict(nms[j] => vs[j] for j in 1:length(vs))
    end

    # 2-D comprehension: rows index transitions, columns index parameter names.
    vals = [get(dicts[i], key, missing) for i in eachindex(dicts), key in names]

    return names, vals
end
@@ -327,7 +333,7 @@ function flatten_namedtuple(nt::NamedTuple)
327
333
names_vals = mapreduce (vcat, keys (nt)) do k
328
334
v = nt[k]
329
335
if length (v) == 1
330
- return [(string (k), v)]
336
+ return [(Symbol (k), v)]
331
337
else
332
338
return mapreduce (vcat, zip (v[1 ], v[2 ])) do (vnval, vn)
333
339
return collect (FlattenIterator (vn, vnval))
339
345
340
346
# For plain `VarInfo` transitions the only extra quantity is the log density,
# returned as a one-column matrix under the name `:lp`.
function get_transition_extras(ts::AbstractVector{<:VarInfo})
    logps = map(getlogp, ts)
    return [:lp], reshape(logps, :, 1)
end
344
350
345
351
function get_transition_extras (ts:: AbstractVector )
@@ -353,7 +359,7 @@ function get_transition_extras(ts::AbstractVector)
353
359
354
360
# Iterate through each transition.
355
361
for t in ts
356
- extra_names = String []
362
+ extra_names = Symbol []
357
363
vals = []
358
364
359
365
# Iterate through each of the additional field names
@@ -365,11 +371,11 @@ function get_transition_extras(ts::AbstractVector)
365
371
prop = getproperty (t, p)
366
372
if prop isa NamedTuple
367
373
for (k, v) in pairs (prop)
368
- push! (extra_names, string (k))
374
+ push! (extra_names, Symbol (k))
369
375
push! (vals, v)
370
376
end
371
377
else
372
- push! (extra_names, string (p))
378
+ push! (extra_names, Symbol (p))
373
379
push! (vals, prop)
374
380
end
375
381
end
@@ -432,12 +438,11 @@ function AbstractMCMC.bundle_samples(
432
438
# Chain construction.
433
439
return MCMCChains. Chains (
434
440
parray,
435
- string .( nms) ,
441
+ nms,
436
442
deepcopy (TURING_INTERNAL_VARS);
437
443
evidence= le,
438
444
info= info,
439
- sorted= true
440
- )
445
+ ) |> sort
441
446
end
442
447
443
448
# This is type piracy (for SampleFromPrior).
@@ -535,12 +540,13 @@ include("is.jl")
535
540
include (" AdvancedSMC.jl" )
536
541
include (" gibbs.jl" )
537
542
include (" ../contrib/inference/sghmc.jl" )
543
+ include (" emcee.jl" )
538
544
539
545
# ###############
540
546
# Typing tools #
541
547
# ###############
542
548
543
# Define `DynamicPPL.getspace` for each of these sampler types: each algorithm
# carries the variables it targets as its `space` type parameter, and this
# simply returns that parameter.
for alg in (:SMC, :PG, :MH, :IS, :ESS, :Gibbs, :Emcee)
    @eval DynamicPPL.getspace(::$alg{space}) where {space} = space
end
546
552
for alg in (:HMC , :HMCDA , :NUTS , :SGLD , :SGHMC )
@@ -571,13 +577,12 @@ function get_matching_type(
571
577
)
572
578
return floatof (eltype (vi, spl))
573
579
end
574
- function get_matching_type (
575
- spl:: AbstractSampler ,
576
- vi,
577
- :: Type{TV} ,
578
- ) where {T, N, TV <: Array{T, N} }
580
# Recurse into the element type of an `Array` type so the array is rebuilt with
# a sampler-appropriate element type, keeping the dimensionality `N`.
function get_matching_type(spl::AbstractSampler, vi, ::Type{<:Array{T,N}}) where {T,N}
    ET = get_matching_type(spl, vi, T)
    return Array{ET,N}
end
# Same as above, for `Array` types whose dimensionality is left unspecified.
function get_matching_type(spl::AbstractSampler, vi, ::Type{<:Array{T}}) where {T}
    ET = get_matching_type(spl, vi, T)
    return Array{ET}
end
581
586
function get_matching_type (
582
587
spl:: Sampler{<:Union{PG, SMC}} ,
583
588
vi,
@@ -593,4 +598,182 @@ end
593
598
DynamicPPL. getspace (spl:: Sampler ) = getspace (spl. alg)
594
599
DynamicPPL. inspace (vn:: VarName , spl:: Sampler ) = inspace (vn, getspace (spl. alg))
595
600
601
+ """
602
+
603
+ predict(model::Model, chain::MCMCChains.Chains; include_all=false)
604
+
605
+ Execute `model` conditioned on each sample in `chain`, and return the resulting `Chains`.
606
+
607
+ If `include_all` is `false`, the returned `Chains` will contain only those variables
608
+ sampled/not present in `chain`.
609
+
610
+ # Details
611
+ Internally calls `Turing.Inference.transitions_from_chain` to obtain the samples
612
+ and then converts these into a `Chains` object using `AbstractMCMC.bundle_samples`.
613
+
614
+ # Example
615
+ ```jldoctest
616
+ julia> using Turing; Turing.turnprogress(false);
617
+ [ Info: [Turing]: progress logging is disabled globally
618
+
619
+ julia> @model function linear_reg(x, y, σ = 0.1)
620
+ β ~ Normal(0, 1)
621
+
622
+ for i ∈ eachindex(y)
623
+ y[i] ~ Normal(β * x[i], σ)
624
+ end
625
+ end;
626
+
627
+ julia> σ = 0.1; f(x) = 2 * x + 0.1 * randn();
628
+
629
+ julia> Δ = 0.1; xs_train = 0:Δ:10; ys_train = f.(xs_train);
630
+
631
+ julia> xs_test = [10 + Δ, 10 + 2 * Δ]; ys_test = f.(xs_test);
632
+
633
+ julia> m_train = linear_reg(xs_train, ys_train, σ);
634
+
635
+ julia> chain_lin_reg = sample(m_train, NUTS(100, 0.65), 200);
636
+ ┌ Info: Found initial step size
637
+ └ ϵ = 0.003125
638
+
639
+ julia> m_test = linear_reg(xs_test, Vector{Union{Missing, Float64}}(undef, length(ys_test)), σ);
640
+
641
+ julia> predictions = Turing.Inference.predict(m_test, chain_lin_reg)
642
+ Object of type Chains, with data of type 100×2×1 Array{Float64,3}
643
+
644
+ Iterations = 1:100
645
+ Thinning interval = 1
646
+ Chains = 1
647
+ Samples per chain = 100
648
+ parameters = y[1], y[2]
649
+
650
+ 2-element Array{ChainDataFrame,1}
651
+
652
+ Summary Statistics
653
+ parameters mean std naive_se mcse ess r_hat
654
+ ────────── ─────── ────── ──────── ─────── ──────── ──────
655
+ y[1] 20.1974 0.1007 0.0101 missing 101.0711 0.9922
656
+ y[2] 20.3867 0.1062 0.0106 missing 101.4889 0.9903
657
+
658
+ Quantiles
659
+ parameters 2.5% 25.0% 50.0% 75.0% 97.5%
660
+ ────────── ─────── ─────── ─────── ─────── ───────
661
+ y[1] 20.0342 20.1188 20.2135 20.2588 20.4188
662
+ y[2] 20.1870 20.3178 20.3839 20.4466 20.5895
663
+
664
+
665
+ julia> ys_pred = vec(mean(Array(group(predictions, :y)); dims = 1));
666
+
667
+ julia> sum(abs2, ys_test - ys_pred) ≤ 0.1
668
+ true
669
+ ```
670
+ """
671
function predict(model::Turing.Model, chain::MCMCChains.Chains; include_all = false)
    spl = DynamicPPL.SampleFromPrior()

    # Sample transitions using `spl` conditioned on the values in `chain`.
    transitions = transitions_from_chain(model, chain; sampler = spl)

    # Let the Turing internals handle conversion of the transitions to `Chains`.
    chain_result = AbstractMCMC.bundle_samples(
        Distributions.GLOBAL_RNG,
        model,
        spl,
        length(chain),
        transitions,
        MCMCChains.Chains
    )

    parameter_names = if include_all
        names(chain_result, :parameters)
    else
        # Hoist the conditioning-variable names out of the predicate so the
        # `names(chain, :parameters)` lookup is not recomputed per candidate.
        conditioned = names(chain, :parameters)
        filter(k -> k ∉ conditioned, names(chain_result, :parameters))
    end

    # Subset the resulting chain to the selected parameter names.
    return chain_result[parameter_names]
end
695
+
696
+ """
697
+
698
+ transitions_from_chain(
699
+ model::Model,
700
+ chain::MCMCChains.Chains;
701
+ sampler = DynamicPPL.SampleFromPrior()
702
+ )
703
+
704
+ Execute `model` conditioned on each sample in `chain`, and return resulting transitions.
705
+
706
+ The returned transitions are represented in a `Vector{<:Turing.Inference.Transition}`.
707
+
708
+ # Details
709
+
710
+ In a bit more detail, the process is as follows:
711
+ 1. For every `sample` in `chain`
712
+ 1. For every `variable` in `sample`
713
+ 1. Set `variable` in `model` to its value in `sample`
714
+ 2. Execute `model` with variables fixed as above, sampling variables NOT present
715
+ in `chain` using `SampleFromPrior`
716
+ 3. Return sampled variables and log-joint
717
+
718
+ # Example
719
+ ```julia-repl
720
+ julia> using Turing
721
+
722
+ julia> @model function demo()
723
+ m ~ Normal(0, 1)
724
+ x ~ Normal(m, 1)
725
+ end;
726
+
727
+ julia> m = demo();
728
+
729
+ julia> chain = Chains(randn(2, 1, 1), ["m"]); # 2 samples of `m`
730
+
731
+ julia> transitions = Turing.Inference.transitions_from_chain(m, chain);
732
+
733
+ julia> [Turing.Inference.getlogp(t) for t in transitions] # extract the logjoints
734
+ 2-element Array{Float64,1}:
735
+ -3.6294991938628374
736
+ -2.5697948166987845
737
+
738
+ julia> [first(t.θ.x) for t in transitions] # extract samples for `x`
739
+ 2-element Array{Array{Float64,1},1}:
740
+ [-2.0844148956440796]
741
+ [-1.704630494695469]
742
+ ```
743
+ """
744
function transitions_from_chain(
    model::Turing.Model,
    chain::MCMCChains.Chains;
    sampler = DynamicPPL.SampleFromPrior()
)
    # A single `VarInfo` holding all variables of `model`; note that it is
    # reused and mutated in place for every sample of `chain` below.
    vi = Turing.VarInfo(model)

    transitions = map(1:length(chain)) do i
        # The i-th sample of the chain.
        c = chain[i]
        md = vi.metadata
        for v in keys(md)
            for vn in md[v].vns
                vn_symbol = Symbol(vn)
                if vn_symbol ∈ c.name_map.parameters
                    # Variable is present in `chain`: condition on its sampled
                    # value and clear the transform flag — assumes `chain`
                    # stores values in the original (constrained) space; TODO confirm.
                    val = c[vn_symbol]
                    DynamicPPL.setval!(vi, val, vn)
                    DynamicPPL.settrans!(vi, false, vn)
                else
                    # delete so we can sample from prior
                    DynamicPPL.set_flag!(vi, vn, "del")
                end
            end
        end
        # Execute `model` on the parameters set in `vi` and sample those with `"del"` flag using `sampler`
        model(vi, sampler)

        # Convert `VarInfo` into `NamedTuple` and save the log-joint.
        theta = DynamicPPL.tonamedtuple(vi)
        lp = Turing.getlogp(vi)
        Transition(theta, lp)
    end

    return transitions
end
778
+
596
779
end # module
0 commit comments