
Commit 2cda98d

Authored by AoifeHughes, mhauru, and penelopeysm

Implement GibbsConditional (#2647)

WRT: #2547

Co-authored-by: AoifeHughes <[email protected]>
Co-authored-by: Markus Hauru <[email protected]>
Co-authored-by: Markus Hauru <[email protected]>
Co-authored-by: Penelope Yong <[email protected]>

1 parent 19bf7d6 · commit 2cda98d

File tree: 9 files changed (+481 / -3 lines)

HISTORY.md

Lines changed: 8 additions & 0 deletions

```diff
@@ -1,3 +1,11 @@
+# 0.41.2
+
+Add `GibbsConditional`, a "sampler" that can be used to provide analytically known conditional posteriors in a Gibbs sampler.
+
+In Gibbs sampling, some variables are sampled with a component sampler while the other variables are held conditioned to their current values. For example, one might take turns sampling one variable with HMC and another with a particle sampler. Sometimes, however, the posterior distribution of one variable is known analytically, given the conditioned values of the other variables. `GibbsConditional` provides a way to implement these analytically known conditional posteriors and use them as component samplers for Gibbs. See the docstring of `GibbsConditional` for details.
+
+Note that `GibbsConditional` used to exist in Turing.jl until v0.36, at which point it was removed when the whole Gibbs sampler was rewritten. This commit reintroduces the same functionality, though with a slightly different interface.
+
 # 0.41.1
 
 The `ModeResult` struct returned by `maximum_a_posteriori` and `maximum_likelihood` can now be wrapped in `InitFromParams()`.
```
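For context, here is a minimal sketch of the interface described in this changelog entry. It is not part of the commit: the model `known_variance`, the conditional function `cond_mu`, and the data are invented for illustration, using the textbook conjugate result for a Normal mean with known variance.

```julia
using Turing, AbstractPPL

# A hypothetical conjugate model: Normal observations with known unit variance
# and a Normal(0, 1) prior on the mean `mu`.
@model function known_variance(x)
    mu ~ Normal(0, 1)
    for i in eachindex(x)
        x[i] ~ Normal(mu, 1)
    end
end

# Closed-form conditional posterior for `mu`: with an N(0, 1) prior and n
# unit-variance observations, the posterior is N(sum(x) / (n + 1), 1 / (n + 1)).
function cond_mu(c)
    x = AbstractPPL.getvalue(c, @varname(x))
    n = length(x)
    return Normal(sum(x) / (n + 1), sqrt(1 / (n + 1)))
end

# `cond_mu` returns a single Distribution, which is one of the supported return
# kinds since only `mu` is sampled by this component.
chain = sample(known_variance(randn(10) .+ 2), Gibbs(:mu => GibbsConditional(cond_mu)), 1000)
```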

Project.toml

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,6 +1,6 @@
 name = "Turing"
 uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
-version = "0.41.1"
+version = "0.41.2"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
```

docs/src/api.md

Lines changed: 1 addition & 0 deletions

```diff
@@ -63,6 +63,7 @@ even though [`Prior()`](@ref) is actually defined in the `Turing.Inference` modu
 | `Emcee` | [`Turing.Inference.Emcee`](@ref) | Affine-invariant ensemble sampler |
 | `ESS` | [`Turing.Inference.ESS`](@ref) | Elliptical slice sampling |
 | `Gibbs` | [`Turing.Inference.Gibbs`](@ref) | Gibbs sampling |
+| `GibbsConditional` | [`Turing.Inference.GibbsConditional`](@ref) | Gibbs sampling with analytical conditional posterior distributions |
 | `HMC` | [`Turing.Inference.HMC`](@ref) | Hamiltonian Monte Carlo |
 | `SGLD` | [`Turing.Inference.SGLD`](@ref) | Stochastic gradient Langevin dynamics |
 | `SGHMC` | [`Turing.Inference.SGHMC`](@ref) | Stochastic gradient Hamiltonian Monte Carlo |
```

src/Turing.jl

Lines changed: 1 addition & 0 deletions

```diff
@@ -102,6 +102,7 @@ export
     Emcee,
     ESS,
     Gibbs,
+    GibbsConditional,
     HMC,
     SGLD,
     SGHMC,
```

src/mcmc/Inference.jl

Lines changed: 2 additions & 0 deletions

```diff
@@ -56,6 +56,7 @@ export Hamiltonian,
     ESS,
     Emcee,
     Gibbs, # classic sampling
+    GibbsConditional, # conditional sampling
     HMC,
     SGLD,
     PolynomialStepsize,
@@ -430,6 +431,7 @@ include("mh.jl")
 include("is.jl")
 include("particle_mcmc.jl")
 include("gibbs.jl")
+include("gibbs_conditional.jl")
 include("sghmc.jl")
 include("emcee.jl")
 include("prior.jl")
```

src/mcmc/gibbs_conditional.jl

Lines changed: 171 additions & 0 deletions (new file)

````julia
"""
    GibbsConditional(get_cond_dists)

A Gibbs component sampler that samples variables according to user-provided analytical
conditional posterior distributions.

When using Gibbs sampling, sometimes one may know the analytical form of the posterior for
a given variable, given the conditioned values of the other variables. In such cases one
can use `GibbsConditional` as a component sampler to sample from these known conditionals
directly, avoiding any MCMC methods. One does so with

```julia
sampler = Gibbs(
    (@varname(var1), @varname(var2)) => GibbsConditional(get_cond_dists),
    # other samplers go here...
)
```

Here `get_cond_dists(c::Dict{<:VarName})` should be a function that takes a `Dict` mapping
the conditioned variables (anything other than `var1` and `var2`) to their values, and
returns the conditional posterior distributions for `var1` and `var2`. You may, of course,
have any number of variables being sampled as a block in this manner; we only use two as an
example. The return value of `get_cond_dists` should be one of the following:

- A single `Distribution`, if only one variable is being sampled.
- An `AbstractDict{<:VarName,<:Distribution}` that maps the variables being sampled to
  their conditional posteriors, e.g. `Dict(@varname(var1) => dist1, @varname(var2) => dist2)`.
- A `NamedTuple` of `Distribution`s, which is like the `AbstractDict` case but can be used
  if all the variable names are single `Symbol`s, and may be more performant, e.g.
  `(; var1=dist1, var2=dist2)`.

# Examples

```julia
# Define a model
@model function inverse_gdemo(x)
    precision ~ Gamma(2, inv(3))
    std = sqrt(1 / precision)
    m ~ Normal(0, std)
    for i in eachindex(x)
        x[i] ~ Normal(m, std)
    end
end

# Define analytical conditionals. See
# https://en.wikipedia.org/wiki/Conjugate_prior#When_likelihood_function_is_a_continuous_distribution
function cond_precision(c)
    a = 2.0
    b = 3.0
    # We use AbstractPPL.getvalue instead of indexing into `c` directly to guard against
    # issues where e.g. you try to get `c[@varname(x[1])]` but only `@varname(x)` is present
    # in `c`. `getvalue` handles that gracefully, `getindex` doesn't. In this case
    # `getindex` would suffice, but `getvalue` is good practice.
    m = AbstractPPL.getvalue(c, @varname(m))
    x = AbstractPPL.getvalue(c, @varname(x))
    n = length(x)
    a_new = a + (n + 1) / 2
    b_new = b + sum(abs2, x .- m) / 2 + m^2 / 2
    return Gamma(a_new, 1 / b_new)
end

function cond_m(c)
    precision = AbstractPPL.getvalue(c, @varname(precision))
    x = AbstractPPL.getvalue(c, @varname(x))
    n = length(x)
    m_mean = sum(x) / (n + 1)
    m_var = 1 / (precision * (n + 1))
    return Normal(m_mean, sqrt(m_var))
end

# Sample using GibbsConditional
model = inverse_gdemo([1.0, 2.0, 3.0])
chain = sample(model, Gibbs(
    :precision => GibbsConditional(cond_precision),
    :m => GibbsConditional(cond_m)
), 1000)
```
"""
struct GibbsConditional{C} <: AbstractSampler
    get_cond_dists::C
end

isgibbscomponent(::GibbsConditional) = true

"""
    build_variable_dict(model::DynamicPPL.Model)

Traverse the context stack of `model` and build a `Dict` of all the variable values that
are set in `GibbsContext`, `ConditionContext`, or `FixedContext`.
"""
function build_variable_dict(model::DynamicPPL.Model)
    context = model.context
    cond_vals = DynamicPPL.conditioned(context)
    fixed_vals = DynamicPPL.fixed(context)
    # TODO(mhauru) Can we avoid invlinking all the time?
    global_vi = DynamicPPL.invlink(get_gibbs_global_varinfo(context), model)
    # TODO(mhauru) This creates a lot of Dicts, which are then immediately merged into one.
    # Also, DynamicPPL.to_varname_dict is known to be inefficient. Make a more efficient
    # implementation.
    return merge(
        DynamicPPL.values_as(global_vi, Dict),
        DynamicPPL.to_varname_dict(cond_vals),
        DynamicPPL.to_varname_dict(fixed_vals),
        DynamicPPL.to_varname_dict(model.args),
    )
end

function get_gibbs_global_varinfo(context::DynamicPPL.AbstractContext)
    return if context isa GibbsContext
        get_global_varinfo(context)
    elseif DynamicPPL.NodeTrait(context) isa DynamicPPL.IsParent
        get_gibbs_global_varinfo(DynamicPPL.childcontext(context))
    else
        msg = """No GibbsContext found in context stack. Are you trying to use \
            GibbsConditional outside of Gibbs?
            """
        throw(ArgumentError(msg))
    end
end

function initialstep(
    ::Random.AbstractRNG,
    model::DynamicPPL.Model,
    ::GibbsConditional,
    vi::DynamicPPL.AbstractVarInfo;
    kwargs...,
)
    state = DynamicPPL.is_transformed(vi) ? DynamicPPL.invlink(vi, model) : vi
    # Since GibbsConditional is only used within Gibbs, it does not need to return a
    # transition.
    return nothing, state
end

function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::DynamicPPL.Model,
    sampler::GibbsConditional,
    state::DynamicPPL.AbstractVarInfo;
    kwargs...,
)
    # Get all the conditioned variable values from the model context. This is assumed to
    # include a GibbsContext as part of the context stack.
    condvals = build_variable_dict(model)
    conddists = sampler.get_cond_dists(condvals)

    # We support three different kinds of return values for `sampler.get_cond_dists`, to
    # make life easier for the user.
    if conddists isa AbstractDict
        for (vn, dist) in conddists
            state = setindex!!(state, rand(rng, dist), vn)
        end
    elseif conddists isa NamedTuple
        for (vn_sym, dist) in pairs(conddists)
            vn = VarName{vn_sym}()
            state = setindex!!(state, rand(rng, dist), vn)
        end
    else
        # Single variable case
        vn = only(keys(state))
        state = setindex!!(state, rand(rng, conddists), vn)
    end

    # Since GibbsConditional is only used within Gibbs, it does not need to return a
    # transition.
    return nothing, state
end

function setparams_varinfo!!(
    ::DynamicPPL.Model, ::GibbsConditional, ::Any, params::DynamicPPL.AbstractVarInfo
)
    return params
end
````
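As a companion to the docstring example above (which returns a single `Distribution` from each conditional), here is a sketch of the `NamedTuple` return form handled by the `AbstractMCMC.step` branch in this file. The model `two_means`, the conditional `cond_means`, and the data are invented for illustration and are not part of this commit:

```julia
using Turing, AbstractPPL

# Hypothetical model with two independent means, each with its own dataset.
@model function two_means(x, y)
    mu_x ~ Normal(0, 1)
    mu_y ~ Normal(0, 1)
    for i in eachindex(x)
        x[i] ~ Normal(mu_x, 1)
        y[i] ~ Normal(mu_y, 1)
    end
end

# Both means are sampled as one block, so `c` holds only the data `x` and `y`.
# Returning a NamedTuple exercises the branch in `AbstractMCMC.step` that turns
# each Symbol key into a `VarName` before drawing from its conditional.
function cond_means(c)
    x = AbstractPPL.getvalue(c, @varname(x))
    y = AbstractPPL.getvalue(c, @varname(y))
    # Conjugate Normal-mean posterior with N(0, 1) prior and unit-variance data.
    post(data) = Normal(sum(data) / (length(data) + 1), sqrt(1 / (length(data) + 1)))
    return (; mu_x=post(x), mu_y=post(y))
end

chain = sample(
    two_means(randn(5), randn(5) .+ 1),
    Gibbs((@varname(mu_x), @varname(mu_y)) => GibbsConditional(cond_means)),
    500,
)
```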

test/mcmc/gibbs.jl

Lines changed: 2 additions & 2 deletions (the changed lines appear to differ only in trailing whitespace, which the rendered diff does not show, so the removed and added lines below look identical)

```diff
@@ -496,7 +496,7 @@ end
 
 @testset "dynamic model with analytical posterior" begin
     # A dynamic model where b ~ Bernoulli determines the dimensionality
-    # When b=0: single parameter θ₁
+    # When b=0: single parameter θ₁
     # When b=1: two parameters θ₁, θ₂ where we observe their sum
     @model function dynamic_bernoulli_normal(y_obs=2.0)
         b ~ Bernoulli(0.3)
@@ -575,7 +575,7 @@ end
 #     end
 # end
 # sample(f(), Gibbs(:a => PG(10), :x => MH()), 1000)
-#
+#
 # because the number of observations in each particle depends on the value
 # of `a`.
 #
```
