Skip to content

Commit d6404ec

Browse files
committed
Fixes for MCMCChains 4.0.0
1 parent 700ae06 commit d6404ec

File tree

15 files changed

+995
-83
lines changed

15 files changed

+995
-83
lines changed

src/prob_macro.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ _setval!(vi::TypedVarInfo, c::AbstractChains) = _setval!(vi.metadata, vi, c)
233233
return Expr(:block, map(names) do n
234234
quote
235235
for vn in md.$n.vns
236-
val = copy.(vec(c[Symbol(string(vn))].value))
236+
val = vec(c[Symbol(vn)])
237237
setval!(vi, val, vn)
238238
settrans!(vi, false, vn)
239239
end

test/Turing/Turing.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ using Requires, Reexport, ForwardDiff
1212
using DistributionsAD, Bijectors, StatsFuns, SpecialFunctions
1313
using Statistics, LinearAlgebra
1414
using Libtask
15-
@reexport using Distributions, MCMCChains, Libtask, AbstractMCMC
15+
@reexport using Distributions, MCMCChains, Libtask, AbstractMCMC, Bijectors
1616
using Tracker: Tracker
1717

1818
import DynamicPPL: getspace, NoDist, NamedDist
@@ -58,6 +58,11 @@ using .Variational
5858
end
5959
end
6060

61+
@init @require Optim="429524aa-4258-5aef-a3af-852621145aeb" @eval begin
62+
include("modes/ModeEstimation.jl")
63+
export MAP, MLE, optimize
64+
end
65+
6166
###########
6267
# Exports #
6368
###########
@@ -71,6 +76,7 @@ export @model, # modelling
7176

7277
MH, # classic sampling
7378
RWMH,
79+
Emcee,
7480
ESS,
7581
Gibbs,
7682

@@ -87,7 +93,7 @@ export @model, # modelling
8793
CSMC,
8894
PG,
8995

90-
vi, # variational inference
96+
vi, # variational inference
9197
ADVI,
9298

9399
sample, # inference

test/Turing/inference/AdvancedSMC.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,8 @@ struct PG{space,R} <: ParticleInference
170170
resampler::R
171171
end
172172

173+
isgibbscomponent(::PG) = true
174+
173175
"""
174176
PG(n, space...)
175177
PG(n, [resampler = ResampleWithESSThreshold(), space = ()])

test/Turing/inference/Inference.jl

Lines changed: 204 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import DynamicPPL: get_matching_type,
3030
import EllipticalSliceSampling
3131
import Random
3232
import MCMCChains
33+
import StatsBase: predict
3334

3435
export InferenceAlgorithm,
3536
Hamiltonian,
@@ -40,6 +41,7 @@ export InferenceAlgorithm,
4041
SampleFromPrior,
4142
MH,
4243
ESS,
44+
Emcee,
4345
Gibbs, # classic sampling
4446
HMC,
4547
SGLD,
@@ -56,7 +58,9 @@ export InferenceAlgorithm,
5658
dot_assume,
5759
observe,
5860
dot_observe,
59-
resume
61+
resume,
62+
predict,
63+
isgibbscomponent
6064

6165
#######################
6266
# Sampler abstraction #
@@ -135,7 +139,7 @@ const TURING_INTERNAL_VARS = (internals = [
135139
"step_size",
136140
"nom_step_size",
137141
"tree_depth",
138-
"is_adapt",
142+
"is_adapt"
139143
],)
140144

141145
#########################################
@@ -305,19 +309,21 @@ Return a named tuple of parameters.
305309
getparams(t) = t.θ
306310
getparams(t::VarInfo) = tonamedtuple(TypedVarInfo(t))
307311

308-
function _params_to_array(ts)
309-
names_set = Set{String}()
312+
function _params_to_array(ts::Vector)
313+
names = Vector{Symbol}()
310314
# Extract the parameter names and values from each transition.
311315
dicts = map(ts) do t
312316
nms, vs = flatten_namedtuple(getparams(t))
313317
for nm in nms
314-
push!(names_set, nm)
318+
if !(nm in names)
319+
push!(names, nm)
320+
end
315321
end
316322
# Convert the names and values to a single dictionary.
317323
return Dict(nms[j] => vs[j] for j in 1:length(vs))
318324
end
319-
names = collect(names_set)
320-
vals = [get(dicts[i], key, missing) for i in eachindex(dicts),
325+
# names = collect(names_set)
326+
vals = [get(dicts[i], key, missing) for i in eachindex(dicts),
321327
(j, key) in enumerate(names)]
322328

323329
return names, vals
@@ -327,7 +333,7 @@ function flatten_namedtuple(nt::NamedTuple)
327333
names_vals = mapreduce(vcat, keys(nt)) do k
328334
v = nt[k]
329335
if length(v) == 1
330-
return [(string(k), v)]
336+
return [(Symbol(k), v)]
331337
else
332338
return mapreduce(vcat, zip(v[1], v[2])) do (vnval, vn)
333339
return collect(FlattenIterator(vn, vnval))
@@ -339,7 +345,7 @@ end
339345

340346
function get_transition_extras(ts::AbstractVector{<:VarInfo})
341347
valmat = reshape([getlogp(t) for t in ts], :, 1)
342-
return ["lp"], valmat
348+
return [:lp], valmat
343349
end
344350

345351
function get_transition_extras(ts::AbstractVector)
@@ -353,7 +359,7 @@ function get_transition_extras(ts::AbstractVector)
353359

354360
# Iterate through each transition.
355361
for t in ts
356-
extra_names = String[]
362+
extra_names = Symbol[]
357363
vals = []
358364

359365
# Iterate through each of the additional field names
@@ -365,11 +371,11 @@ function get_transition_extras(ts::AbstractVector)
365371
prop = getproperty(t, p)
366372
if prop isa NamedTuple
367373
for (k, v) in pairs(prop)
368-
push!(extra_names, string(k))
374+
push!(extra_names, Symbol(k))
369375
push!(vals, v)
370376
end
371377
else
372-
push!(extra_names, string(p))
378+
push!(extra_names, Symbol(p))
373379
push!(vals, prop)
374380
end
375381
end
@@ -432,12 +438,11 @@ function AbstractMCMC.bundle_samples(
432438
# Chain construction.
433439
return MCMCChains.Chains(
434440
parray,
435-
string.(nms),
441+
nms,
436442
deepcopy(TURING_INTERNAL_VARS);
437443
evidence=le,
438444
info=info,
439-
sorted=true
440-
)
445+
) |> sort
441446
end
442447

443448
# This is type piracy (for SampleFromPrior).
@@ -535,12 +540,13 @@ include("is.jl")
535540
include("AdvancedSMC.jl")
536541
include("gibbs.jl")
537542
include("../contrib/inference/sghmc.jl")
543+
include("emcee.jl")
538544

539545
################
540546
# Typing tools #
541547
################
542548

543-
for alg in (:SMC, :PG, :MH, :IS, :ESS, :Gibbs)
549+
for alg in (:SMC, :PG, :MH, :IS, :ESS, :Gibbs, :Emcee)
544550
@eval DynamicPPL.getspace(::$alg{space}) where {space} = space
545551
end
546552
for alg in (:HMC, :HMCDA, :NUTS, :SGLD, :SGHMC)
@@ -571,13 +577,12 @@ function get_matching_type(
571577
)
572578
return floatof(eltype(vi, spl))
573579
end
574-
function get_matching_type(
575-
spl::AbstractSampler,
576-
vi,
577-
::Type{TV},
578-
) where {T, N, TV <: Array{T, N}}
580+
function get_matching_type(spl::AbstractSampler, vi, ::Type{<:Array{T,N}}) where {T,N}
579581
return Array{get_matching_type(spl, vi, T), N}
580582
end
583+
function get_matching_type(spl::AbstractSampler, vi, ::Type{<:Array{T}}) where T
584+
return Array{get_matching_type(spl, vi, T)}
585+
end
581586
function get_matching_type(
582587
spl::Sampler{<:Union{PG, SMC}},
583588
vi,
@@ -593,4 +598,182 @@ end
593598
DynamicPPL.getspace(spl::Sampler) = getspace(spl.alg)
594599
DynamicPPL.inspace(vn::VarName, spl::Sampler) = inspace(vn, getspace(spl.alg))
595600

601+
"""
602+
603+
predict(model::Model, chain::MCMCChains.Chains; include_all=false)
604+
605+
Execute `model` conditioned on each sample in `chain`, and return the resulting `Chains`.
606+
607+
If `include_all` is `false`, the returned `Chains` will contain only those variables
608+
sampled/not present in `chain`.
609+
610+
# Details
611+
Internally calls `Turing.Inference.transitions_from_chain` to obtained the samples
612+
and then converts these into a `Chains` object using `AbstractMCMC.bundle_samples`.
613+
614+
# Example
615+
```jldoctest
616+
julia> using Turing; Turing.turnprogress(false);
617+
[ Info: [Turing]: progress logging is disabled globally
618+
619+
julia> @model function linear_reg(x, y, σ = 0.1)
620+
β ~ Normal(0, 1)
621+
622+
for i ∈ eachindex(y)
623+
y[i] ~ Normal(β * x[i], σ)
624+
end
625+
end;
626+
627+
julia> σ = 0.1; f(x) = 2 * x + 0.1 * randn();
628+
629+
julia> Δ = 0.1; xs_train = 0:Δ:10; ys_train = f.(xs_train);
630+
631+
julia> xs_test = [10 + Δ, 10 + 2 * Δ]; ys_test = f.(xs_test);
632+
633+
julia> m_train = linear_reg(xs_train, ys_train, σ);
634+
635+
julia> chain_lin_reg = sample(m_train, NUTS(100, 0.65), 200);
636+
┌ Info: Found initial step size
637+
└ ϵ = 0.003125
638+
639+
julia> m_test = linear_reg(xs_test, Vector{Union{Missing, Float64}}(undef, length(ys_test)), σ);
640+
641+
julia> predictions = Turing.Inference.predict(m_test, chain_lin_reg)
642+
Object of type Chains, with data of type 100×2×1 Array{Float64,3}
643+
644+
Iterations = 1:100
645+
Thinning interval = 1
646+
Chains = 1
647+
Samples per chain = 100
648+
parameters = y[1], y[2]
649+
650+
2-element Array{ChainDataFrame,1}
651+
652+
Summary Statistics
653+
parameters mean std naive_se mcse ess r_hat
654+
────────── ─────── ────── ──────── ─────── ──────── ──────
655+
y[1] 20.1974 0.1007 0.0101 missing 101.0711 0.9922
656+
y[2] 20.3867 0.1062 0.0106 missing 101.4889 0.9903
657+
658+
Quantiles
659+
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
660+
────────── ─────── ─────── ─────── ─────── ───────
661+
y[1] 20.0342 20.1188 20.2135 20.2588 20.4188
662+
y[2] 20.1870 20.3178 20.3839 20.4466 20.5895
663+
664+
665+
julia> ys_pred = vec(mean(Array(group(predictions, :y)); dims = 1));
666+
667+
julia> sum(abs2, ys_test - ys_pred) ≤ 0.1
668+
true
669+
```
670+
"""
671+
function predict(model::Turing.Model, chain::MCMCChains.Chains; include_all = false)
672+
spl = DynamicPPL.SampleFromPrior()
673+
674+
# Sample transitions using `spl` conditioned on values in `chain`
675+
transitions = transitions_from_chain(model, chain; sampler = spl)
676+
677+
# Let the Turing internals handle everything else for you
678+
chain_result = AbstractMCMC.bundle_samples(
679+
Distributions.GLOBAL_RNG,
680+
model,
681+
spl,
682+
length(chain),
683+
transitions,
684+
MCMCChains.Chains
685+
)
686+
687+
parameter_names = if include_all
688+
names(chain_result, :parameters)
689+
else
690+
filter(k -> (k, names(chain, :parameters)), names(chain_result, :parameters))
691+
end
692+
693+
return chain_result[parameter_names]
694+
end
695+
696+
"""
697+
698+
transitions_from_chain(
699+
model::Model,
700+
chain::MCMCChains.Chains;
701+
sampler = DynamicPPL.SampleFromPrior()
702+
)
703+
704+
Execute `model` conditioned on each sample in `chain`, and return resulting transitions.
705+
706+
The returned transitions are represented in a `Vector{<:Turing.Inference.Transition}`.
707+
708+
# Details
709+
710+
In a bit more detail, the process is as follows:
711+
1. For every `sample` in `chain`
712+
1. For every `variable` in `sample`
713+
1. Set `variable` in `model` to its value in `sample`
714+
2. Execute `model` with variables fixed as above, sampling variables NOT present
715+
in `chain` using `SampleFromPrior`
716+
3. Return sampled variables and log-joint
717+
718+
# Example
719+
```julia-repl
720+
julia> using Turing
721+
722+
julia> @model function demo()
723+
m ~ Normal(0, 1)
724+
x ~ Normal(m, 1)
725+
end;
726+
727+
julia> m = demo();
728+
729+
julia> chain = Chains(randn(2, 1, 1), ["m"]); # 2 samples of `m`
730+
731+
julia> transitions = Turing.Inference.transitions_from_chain(m, chain);
732+
733+
julia> [Turing.Inference.getlogp(t) for t in transitions] # extract the logjoints
734+
2-element Array{Float64,1}:
735+
-3.6294991938628374
736+
-2.5697948166987845
737+
738+
julia> [first(t.θ.x) for t in transitions] # extract samples for `x`
739+
2-element Array{Array{Float64,1},1}:
740+
[-2.0844148956440796]
741+
[-1.704630494695469]
742+
```
743+
"""
744+
function transitions_from_chain(
745+
model::Turing.Model,
746+
chain::MCMCChains.Chains;
747+
sampler = DynamicPPL.SampleFromPrior()
748+
)
749+
vi = Turing.VarInfo(model)
750+
751+
transitions = map(1:length(chain)) do i
752+
c = chain[i]
753+
md = vi.metadata
754+
for v in keys(md)
755+
for vn in md[v].vns
756+
vn_symbol = Symbol(vn)
757+
if vn_symbol c.name_map.parameters
758+
val = c[vn_symbol]
759+
DynamicPPL.setval!(vi, val, vn)
760+
DynamicPPL.settrans!(vi, false, vn)
761+
else
762+
# delete so we can sample from prior
763+
DynamicPPL.set_flag!(vi, vn, "del")
764+
end
765+
end
766+
end
767+
# Execute `model` on the parameters set in `vi` and sample those with `"del"` flag using `sampler`
768+
model(vi, sampler)
769+
770+
# Convert `VarInfo` into `NamedTuple` and save
771+
theta = DynamicPPL.tonamedtuple(vi)
772+
lp = Turing.getlogp(vi)
773+
Transition(theta, lp)
774+
end
775+
776+
return transitions
777+
end
778+
596779
end # module

0 commit comments

Comments
 (0)