Skip to content

Commit 1ab6dd9

Browse files
committed
some updates; add doc
1 parent 7d0ba7c commit 1ab6dd9

File tree

5 files changed

+362
-15
lines changed

5 files changed

+362
-15
lines changed

docs/src/gibbs.md

Lines changed: 347 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,347 @@
1+
# `state` Interface
2+
3+
We encourage sampler packages to implement the following interface functions for the `state` type(s) they maintain:
4+
5+
```@doc
6+
get_logprob
7+
set_logprob!!
8+
get_params
9+
set_params!!
10+
```
11+
12+
These function will provide a minimum interface to interact with the `state` datatype, which a sampler package doesn't have to expose.
13+
14+
## Using the `state` Interface for block sampling within Gibbs
15+
16+
In this sections, we will demonstrate how a `model` package may use this `state` interface to support a Gibbs sampler that can support blocking sampling using different inference algorithms.
17+
18+
We consider a simple hierarchical model with a normal likelihood, with unknown mean and variance parameters.
19+
20+
```math
21+
\begin{align}
22+
\mu &\sim \text{Normal}(0, 1) \\
23+
\tau^2 &\sim \text{InverseGamma}(1, 1) \\
24+
x_i &\sim \text{Normal}(\mu, \sqrt{\tau^2})
25+
\end{align}
26+
```
27+
28+
We can write the log joint probability function as follows, where for the sake of simplicity for the following steps, we will assume that the `mu` and `tau2` parameters are one-element vectors. And `x` is the data.
29+
30+
```julia
31+
function log_joint(; mu::Vector{Float64}, tau2::Vector{Float64}, x::Vector{Float64})
32+
# mu is the mean
33+
# tau2 is the variance
34+
# x is data
35+
36+
# μ ~ Normal(0, 1)
37+
# τ² ~ InverseGamma(1, 1)
38+
# xᵢ ~ Normal(μ, √τ²)
39+
40+
logp = 0.0
41+
mu = only(mu)
42+
tau2 = only(tau2)
43+
44+
mu_prior = Normal(0, 1)
45+
logp += logpdf(mu_prior, mu)
46+
47+
tau2_prior = InverseGamma(1, 1)
48+
logp += logpdf(tau2_prior, tau2)
49+
50+
obs_prior = Normal(mu, sqrt(tau2))
51+
logp += sum(logpdf(obs_prior, xi) for xi in x)
52+
53+
return logp
54+
end
55+
```
56+
57+
To make using `LogDensityProblems` interface, we create a simple type for this model.
58+
59+
```julia
60+
abstract type AbstractHierNormal end
61+
62+
struct HierNormal <: AbstractHierNormal
63+
data::NamedTuple
64+
end
65+
66+
struct ConditionedHierNormal{conditioned_vars} <: AbstractHierNormal
67+
data::NamedTuple
68+
conditioned_values::NamedTuple{conditioned_vars}
69+
end
70+
```
71+
72+
where `ConditionedHierNormal` is a type that represents the model conditioned on some variables, and
73+
74+
```julia
75+
function AbstractPPL.condition(hn::HierNormal, conditioned_values::NamedTuple)
76+
return ConditionedHierNormal(hn.data, conditioned_values)
77+
end
78+
```
79+
80+
then we can simply write down the `LogDensityProblems` interface for this model.
81+
82+
```julia
83+
function LogDensityProblems.logdensity(
84+
hn::ConditionedHierNormal{names}, params::AbstractVector
85+
) where {names}
86+
if Set(names) == Set([:mu]) # conditioned on mu, so params are tau2
87+
return log_joint(; mu=hn.conditioned_values.mu, tau2=params, x=hn.data.x)
88+
elseif Set(names) == Set([:tau2]) # conditioned on tau2, so params are mu
89+
return log_joint(; mu=params, tau2=hn.conditioned_values.tau2, x=hn.data.x)
90+
else
91+
error("Unsupported conditioning configuration.")
92+
end
93+
end
94+
95+
function LogDensityProblems.capabilities(::HierNormal)
96+
return LogDensityProblems.LogDensityOrder{0}()
97+
end
98+
99+
function LogDensityProblems.capabilities(::ConditionedHierNormal)
100+
return LogDensityProblems.LogDensityOrder{0}()
101+
end
102+
```
103+
104+
the model should also define a function that allows the recomputation of the log probability given a sampler state.
105+
The reason for this is that, when we break down the joint probability into conditional probabilities, individual conditional probability problems are conditional on the values of the other variables.
106+
Between the Gibbs sampler sweeps, the values of the variables may change, and we need to recompute the log probability of the current state.
107+
108+
A recomputation function could use the `state` interface to return a new state with the updated log probability.
109+
E.g.
110+
111+
```julia
112+
function recompute_logprob!!(hn::ConditionedHierNormal, vals, state)
113+
return AbstractMCMC.set_logprob!!(state, LogDensityProblems.logdensity(hn, vals))
114+
end
115+
```
116+
117+
where the model doesn't need to know the details of the `state` type, as long as it can access the `log_joint` function.
118+
119+
## Sampler Packages
120+
121+
To illustrate the `AbstractMCMC` interface, we will first implement two very simple Metropolis-Hastings samplers: random walk and static proposal.
122+
123+
Although the interface doesn't force the sampler to implement `Transition` and `State` types, in practice, it has been the convention to do so.
124+
125+
Here we define some bare minimum types to represent the transitions and states.
126+
127+
```julia
128+
struct MHTransition{T}
129+
params::Vector{T}
130+
end
131+
132+
struct MHState{T}
133+
params::Vector{T}
134+
logp::Float64
135+
end
136+
```
137+
138+
Next we define the four `state` interface functions.
139+
140+
```julia
141+
AbstractMCMC.get_params(state::MHState) = state.params
142+
AbstractMCMC.set_params!!(state::MHState, params) = MHState(params, state.logp)
143+
AbstractMCMC.get_logprob(state::MHState) = state.logp
144+
AbstractMCMC.set_logprob!!(state::MHState, logp) = MHState(state.params, logp)
145+
```
146+
147+
These are the functions that was used in the `recompute_logprob!!` function above.
148+
149+
It is very simple to implement the samplers according to the `AbstractMCMC` interface, where we can use `get_logprob` to easily read the log probability of the current state.
150+
151+
```julia
152+
struct RWMH <: AbstractMCMC.AbstractSampler
153+
σ::Float64
154+
end
155+
156+
function AbstractMCMC.step(
157+
rng::AbstractRNG,
158+
logdensity_model::AbstractMCMC.LogDensityModel,
159+
sampler::RWMH,
160+
args...;
161+
initial_params,
162+
kwargs...,
163+
)
164+
return MHTransition(initial_params),
165+
MHState(
166+
initial_params,
167+
only(LogDensityProblems.logdensity(logdensity_model.logdensity, initial_params)),
168+
)
169+
end
170+
171+
function AbstractMCMC.step(
172+
rng::AbstractRNG,
173+
logdensity_model::AbstractMCMC.LogDensityModel,
174+
sampler::RWMH,
175+
state::MHState,
176+
args...;
177+
kwargs...,
178+
)
179+
params = state.params
180+
proposal_dist = MvNormal(zeros(length(params)), sampler.σ)
181+
proposal = params .+ rand(rng, proposal_dist)
182+
logp_proposal = only(
183+
LogDensityProblems.logdensity(logdensity_model.logdensity, proposal)
184+
)
185+
186+
log_acceptance_ratio = min(0, logp_proposal - get_logprob(state))
187+
188+
if log(rand(rng)) < log_acceptance_ratio
189+
return MHTransition(proposal), MHState(proposal, logp_proposal)
190+
else
191+
return MHTransition(params), MHState(params, get_logprob(state))
192+
end
193+
end
194+
```
195+
196+
```julia
197+
struct PriorMH <: AbstractMCMC.AbstractSampler
198+
prior_dist::Distribution
199+
end
200+
201+
function AbstractMCMC.step(
202+
rng::AbstractRNG,
203+
logdensity_model::AbstractMCMC.LogDensityModel,
204+
sampler::PriorMH,
205+
args...;
206+
initial_params,
207+
kwargs...,
208+
)
209+
return MHTransition(initial_params),
210+
MHState(
211+
initial_params,
212+
only(LogDensityProblems.logdensity(logdensity_model.logdensity, initial_params)),
213+
)
214+
end
215+
216+
function AbstractMCMC.step(
217+
rng::AbstractRNG,
218+
logdensity_model::AbstractMCMC.LogDensityModel,
219+
sampler::PriorMH,
220+
state::MHState,
221+
args...;
222+
kwargs...,
223+
)
224+
params = get_params(state)
225+
proposal_dist = sampler.prior_dist
226+
proposal = rand(rng, proposal_dist)
227+
logp_proposal = only(
228+
LogDensityProblems.logdensity(logdensity_model.logdensity, proposal)
229+
)
230+
231+
log_acceptance_ratio = min(
232+
0,
233+
logp_proposal - get_logprob(state) + logpdf(proposal_dist, params) -
234+
logpdf(proposal_dist, proposal),
235+
)
236+
237+
if log(rand(rng)) < log_acceptance_ratio
238+
return MHTransition(proposal), MHState(proposal, logp_proposal)
239+
else
240+
return MHTransition(params), MHState(params, get_logprob(state))
241+
end
242+
end
243+
```
244+
245+
At last, we can proceed to implement the Gibbs sampler.
246+
247+
```julia
248+
struct Gibbs <: AbstractMCMC.AbstractSampler
249+
sampler_map::OrderedDict
250+
end
251+
252+
struct GibbsState
253+
vi::NamedTuple
254+
states::OrderedDict
255+
end
256+
257+
struct GibbsTransition
258+
values::NamedTuple
259+
end
260+
261+
function AbstractMCMC.step(
262+
rng::AbstractRNG,
263+
logdensity_model::AbstractMCMC.LogDensityModel,
264+
spl::Gibbs,
265+
args...;
266+
initial_params::NamedTuple,
267+
kwargs...,
268+
)
269+
states = OrderedDict()
270+
for group in keys(spl.sampler_map)
271+
sub_spl = spl.sampler_map[group]
272+
273+
vars_to_be_conditioned_on = setdiff(keys(initial_params), group)
274+
cond_val = NamedTuple{Tuple(vars_to_be_conditioned_on)}(
275+
Tuple([initial_params[g] for g in vars_to_be_conditioned_on])
276+
)
277+
params_val = NamedTuple{Tuple(group)}(Tuple([initial_params[g] for g in group]))
278+
sub_state = last(
279+
AbstractMCMC.step(
280+
rng,
281+
AbstractMCMC.LogDensityModel(
282+
condition(logdensity_model.logdensity, cond_val)
283+
),
284+
sub_spl,
285+
args...;
286+
initial_params=flatten(params_val),
287+
kwargs...,
288+
),
289+
)
290+
states[group] = sub_state
291+
end
292+
return GibbsTransition(initial_params), GibbsState(initial_params, states)
293+
end
294+
295+
function AbstractMCMC.step(
296+
rng::AbstractRNG,
297+
logdensity_model::AbstractMCMC.LogDensityModel,
298+
spl::Gibbs,
299+
state::GibbsState,
300+
args...;
301+
kwargs...,
302+
)
303+
vi = state.vi
304+
for group in keys(spl.sampler_map)
305+
for (group, sub_state) in state.states
306+
vi = merge(vi, unflatten(get_params(sub_state), group))
307+
end
308+
sub_spl = spl.sampler_map[group]
309+
sub_state = state.states[group]
310+
group_complement = setdiff(keys(vi), group)
311+
cond_val = NamedTuple{Tuple(group_complement)}(
312+
Tuple([vi[g] for g in group_complement])
313+
)
314+
cond_logdensity = condition(logdensity_model.logdensity, cond_val)
315+
sub_state = recompute_logprob!!(cond_logdensity, get_params(sub_state), sub_state)
316+
sub_state = last(
317+
AbstractMCMC.step(
318+
rng,
319+
AbstractMCMC.LogDensityModel(cond_logdensity),
320+
sub_spl,
321+
sub_state,
322+
args...;
323+
kwargs...,
324+
),
325+
)
326+
state.states[group] = sub_state
327+
end
328+
for (group, sub_state) in state.states
329+
vi = merge(vi, unflatten(get_params(sub_state), group))
330+
end
331+
return GibbsTransition(vi), GibbsState(vi, state.states)
332+
end
333+
```
334+
335+
Some points worth noting:
336+
337+
1. We are using `OrderedDict` to store the mapping between variables and samplers. The order will determine the order of the Gibbs sweeps.
338+
2. For each conditional probability problem, we need to store the sampler states for each variable group and also the values of all the variables from last iteration.
339+
3. The first step of the Gibbs sampler is to setup the states for each conditional probability problem.
340+
4. In the following steps of the Gibbs sampler, it will do a sweep over all the conditional probability problems, and update the sampler states for each problem. In each step of the sweep, it will do the following:
341+
- first update the values from the last step of the sweep into the `vi`, which stores the values of all variables at the moment of the Gibbs sweep.
342+
- condition on the values of all variables that are not in the current group
343+
- recompute the log probability of the current state, because the values of the variables that are not in the current group may have changed
344+
- perform a step of the sampler for the conditional probability problem, and update the sampler state
345+
- update the `vi` with the new values from the sampler state
346+
347+
Again, the `state` interface in AbstractMCMC allows the Gibbs sampler to be agnostic of the details of the sampler state, and acquire the values of the parameters from individual sampler states.

gibbs_example/gibbs.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
using AbstractMCMC
2-
using LogDensityProblems, Distributions, LinearAlgebra, Random
2+
using Distributions
3+
using LogDensityProblems
34
using OrderedCollections
5+
using Random
46

57
##
68

@@ -62,7 +64,7 @@ function AbstractMCMC.step(
6264
vi = state.vi
6365
for group in keys(spl.sampler_map)
6466
for (group, sub_state) in state.states
65-
vi = merge(vi, unflatten(getparams(sub_state), group))
67+
vi = merge(vi, unflatten(get_params(sub_state), group))
6668
end
6769
sub_spl = spl.sampler_map[group]
6870
sub_state = state.states[group]
@@ -71,7 +73,7 @@ function AbstractMCMC.step(
7173
Tuple([vi[g] for g in group_complement])
7274
)
7375
cond_logdensity = condition(logdensity_model.logdensity, cond_val)
74-
sub_state = recompute_logprob!!(cond_logdensity, getparams(sub_state), sub_state)
76+
sub_state = recompute_logprob!!(cond_logdensity, get_params(sub_state), sub_state)
7577
sub_state = last(
7678
AbstractMCMC.step(
7779
rng,
@@ -85,7 +87,7 @@ function AbstractMCMC.step(
8587
state.states[group] = sub_state
8688
end
8789
for (group, sub_state) in state.states
88-
vi = merge(vi, unflatten(getparams(sub_state), group))
90+
vi = merge(vi, unflatten(get_params(sub_state), group))
8991
end
9092
return GibbsTransition(vi), GibbsState(vi, state.states)
9193
end

gibbs_example/hier_normal.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,5 +69,5 @@ function unflatten(vec::AbstractVector, group::Tuple)
6969
end
7070

7171
function recompute_logprob!!(hn::ConditionedHierNormal, vals, state)
72-
return setlogp!!(state, LogDensityProblems.logdensity(hn, vals))
72+
return AbstractMCMC.set_logprob!!(state, LogDensityProblems.logdensity(hn, vals))
7373
end

0 commit comments

Comments
 (0)