
Commit 6fde198

update code and doc
1 parent 3ed5cb3 commit 6fde198

File tree

4 files changed: +80 −161 lines

docs/src/gibbs.md

Lines changed: 55 additions & 75 deletions
@@ -8,7 +8,7 @@ LogDensityProblems.logdensity(logdensity_model::AbstractMCMC.LogDensityModel, st

 This function takes the logdensity model and the state, and returns the log probability of the state.
 If `recompute_logp` is `true`, it should recompute the log probability of the state.
-Otherwise, it should use the log probability stored in the state.
+Otherwise, it could use the log probability stored in the state.

 ```julia
 Base.vec(state)
@@ -20,9 +20,11 @@ This function takes the state and returns a vector of the parameter values stored in the state.
 (state::StateType)(logp::Float64)
 ```

-This function takes the state and a log probability value, and updates the state with the new log probability.
+This function takes the state and a log probability value, and returns a new state with the updated log probability.

-These function will provide a minimum interface to interact with the `state` datatype, which a sampler package doesn't have to expose.
+These functions provide a minimal interface to interact with the `state` datatype, which a sampler package can optionally implement.
+The interface facilitates the implementation of "meta-algorithms" that combine different samplers.
+We will demonstrate how it can be used to implement Gibbs sampling in the following sections.

 ## Using the `state` Interface for block sampling within Gibbs

@@ -122,7 +124,7 @@ function LogDensityProblems.capabilities(::ConditionedHierNormal)
 end
 ```

-## Sampler Packages
+### Implementing a Sampler with the `AbstractMCMC` Interface

 To illustrate the `AbstractMCMC` interface, we will first implement two very simple Metropolis-Hastings samplers: random walk and static proposal.

@@ -258,15 +260,11 @@ function compute_log_acceptance_ratio(
 end
 ```

-At last, we can proceed to implement the Gibbs sampler.
+At last, we can proceed to implement a very simple Gibbs sampler.

 ```julia
-"""
-    Gibbs(sampler_map::NamedTuple)
-
-A Gibbs sampler that allows for block sampling using different inference algorithms for each parameter.
-"""
 struct Gibbs{T<:NamedTuple} <: AbstractMCMC.AbstractSampler
+    "Maps variables to their samplers."
     sampler_map::T
 end

@@ -291,16 +289,18 @@ end

 Update the trace with the values from the MCMC states of the sub-problems.
 """
-function update_trace(trace::NamedTuple, gibbs_state::GibbsState)
-    for parameter_variable in keys(gibbs_state.mcmc_states)
+function update_trace(
+    trace::NamedTuple{trace_names}, gibbs_state::GibbsState{TraceNT,StateNT,SizeNT}
+) where {trace_names,TraceNT,StateNT,SizeNT}
+    for parameter_variable in fieldnames(StateNT)
         sub_state = gibbs_state.mcmc_states[parameter_variable]
-        sub_state_params = Base.vec(sub_state)
-        unflattened_sub_state_params = unflatten(
-            sub_state_params,
-            NamedTuple{(parameter_variable,)}((
-                gibbs_state.variable_sizes[parameter_variable],
-            )),
+        sub_state_params_values = Base.vec(sub_state)
+        reshaped_sub_state_params_values = reshape(
+            sub_state_params_values, gibbs_state.variable_sizes[parameter_variable]
        )
+        unflattened_sub_state_params = NamedTuple{(parameter_variable,)}((
+            reshaped_sub_state_params_values,
+        ))
         trace = merge(trace, unflattened_sub_state_params)
     end
     return trace
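The `vec`/`reshape` pair replaces the removed `unflatten` helper: each variable is reshaped independently against its stored size, so no offset bookkeeping across variables is needed. A small round-trip sketch in plain Julia:

```julia
x = [1.0 3.0; 2.0 4.0]            # a 2×2 parameter block
v = vec(x)                        # flattened, column-major: [1.0, 2.0, 3.0, 4.0]
x_again = reshape(v, (2, 2))      # recover the original shape from the stored size
trace = merge((;), (a=x_again,))  # merge into a NamedTuple trace, as update_trace does
```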
@@ -321,8 +321,7 @@ end
 function AbstractMCMC.step(
     rng::Random.AbstractRNG,
     logdensity_model::AbstractMCMC.LogDensityModel,
-    sampler::Gibbs{Tsamplingmap},
-    args...;
+    sampler::Gibbs{Tsamplingmap};
     initial_params::NamedTuple,
     kwargs...,
 ) where {Tsamplingmap}
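With `args...` dropped, the initial step is driven purely by keyword arguments. A hedged sketch of the call, assuming the `spl` sampler and a `logdensity_model` wrapping the full model from the surrounding examples:

```julia
using Random

rng = Random.default_rng()
# Returns the first transition and the GibbsState used by subsequent steps.
transition, gibbs_state = AbstractMCMC.step(
    rng, logdensity_model, spl; initial_params=(mu=[0.0], tau2=[1.0])
)
```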
@@ -338,30 +337,27 @@ function AbstractMCMC.step(
         conditioning_variables_values = NamedTuple{Tuple(variables_to_be_conditioned_on)}(
             Tuple([initial_params[g] for g in variables_to_be_conditioned_on])
         )
-        sub_problem_parameters_values = NamedTuple{(parameter_variable,)}((
-            initial_params[parameter_variable],
-        ))

         # LogDensityProblems' `logdensity` function expects a single vector of real numbers
         # `Gibbs` stores the parameters as a named tuple, thus we need to flatten the sub_problem_parameters_values
         # and unflatten after the sampling step
-        flattened_sub_problem_parameters_values = flatten(sub_problem_parameters_values)
+        flattened_sub_problem_parameters_values = vec(initial_params[parameter_variable])

+        sub_logdensity_model = AbstractMCMC.LogDensityModel(
+            AbstractPPL.condition(
+                logdensity_model.logdensity, conditioning_variables_values
+            ),
+        )
         sub_state = last(
             AbstractMCMC.step(
                 rng,
-                AbstractMCMC.LogDensityModel(
-                    AbstractPPL.condition(
-                        logdensity_model.logdensity, conditioning_variables_values
-                    ),
-                ),
-                sub_sampler,
-                args...;
+                sub_logdensity_model,
+                sub_sampler;
                 initial_params=flattened_sub_problem_parameters_values,
                 kwargs...,
             ),
         )
-        (sub_state, Tuple(size(initial_params[parameter_variable])))
+        (sub_state, size(initial_params[parameter_variable]))
     end

     mcmc_states_tuple = first.(results)
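The refactor names the conditioned model before stepping. The wrapping pattern itself, assuming the `AbstractPPL.condition` method defined for the example model earlier (`full_problem` is a placeholder for that model):

```julia
# Fix every variable except `mu`, then wrap the conditioned problem so the
# sub-sampler sees an ordinary LogDensityModel.
cond_problem = AbstractPPL.condition(full_problem, (tau2=[1.0],))
sub_logdensity_model = AbstractMCMC.LogDensityModel(cond_problem)
```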
@@ -382,11 +378,12 @@ function AbstractMCMC.step(
     rng::Random.AbstractRNG,
     logdensity_model::AbstractMCMC.LogDensityModel,
     sampler::Gibbs{Tsamplingmap},
-    gibbs_state::GibbsState,
-    args...;
+    gibbs_state::GibbsState;
     kwargs...,
 ) where {Tsamplingmap}
-    (; trace, mcmc_states, variable_sizes) = gibbs_state
+    trace = gibbs_state.trace
+    mcmc_states = gibbs_state.mcmc_states
+    variable_sizes = gibbs_state.variable_sizes

     model_parameter_names = fieldnames(Tsamplingmap)
     mcmc_states = map(model_parameter_names) do parameter_variable
@@ -407,7 +404,7 @@ function AbstractMCMC.step(
         sub_state = (sub_state)(logp)
         sub_state = last(
             AbstractMCMC.step(
-                rng, cond_logdensity_model, sub_sampler, sub_state, args...; kwargs...
+                rng, cond_logdensity_model, sub_sampler, sub_state; kwargs...
             ),
         )
         trace = update_trace(trace, gibbs_state)
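This is where the `state` interface earns its keep: the sweep body touches the sub-state only through the interface functions. Isolated from the surrounding loop (which supplies `rng`, `cond_logdensity_model`, `sub_sampler`, and `sub_state`), the pattern is:

```julia
# Other variables changed since the last sweep, so refresh the cached logp first.
logp = LogDensityProblems.logdensity(cond_logdensity_model, sub_state)
sub_state = sub_state(logp)  # interface call: new state with the updated logp
sub_state = last(
    AbstractMCMC.step(rng, cond_logdensity_model, sub_sampler, sub_state)
)
```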
@@ -419,53 +416,36 @@ end
 end
 ```

-where we use two utility functions `flatten` and `unflatten` to convert between the single vector of real numbers and the named tuple of parameters.
+We are using `NamedTuple` to store the mapping between variables and samplers. The order will determine the order of the Gibbs sweeps. A limitation is that exactly one sampler for each variable is required, which means it is less flexible than Gibbs in `Turing.jl`.

-```julia
-"""
-    flatten(trace::NamedTuple)
-
-Flatten all the values in the trace into a single vector. Variable names information is discarded.
-"""
-function flatten(trace::NamedTuple)
-    return reduce(vcat, vec.(values(trace)))
-end
+We use `AbstractPPL.condition` to divide the full model into smaller conditional probability problems.
+Each conditional probability problem corresponds to a sampler and its state.

-"""
-    unflatten(vec::AbstractVector, variable_names::Vector{Symbol}, variable_sizes::Vector{Tuple})
+The `Gibbs` sampler has the same interface as other samplers in `AbstractMCMC` (we don't implement the above state interface for `GibbsState` to keep it simple, but it can be implemented similarly).

-Reverse operation of flatten. Reshape the vector into the original arrays using size information.
-"""
-function unflatten(
-    vec::AbstractVector, variable_names_and_sizes::NamedTuple{variable_names}
-) where {variable_names}
-    result = Dict{Symbol,Array}()
-    start_idx = 1
-    for name in variable_names
-        size = variable_names_and_sizes[name]
-        end_idx = start_idx + prod(size) - 1
-        result[name] = reshape(vec[start_idx:end_idx], size...)
-        start_idx = end_idx + 1
-    end
+The Gibbs sampler operates in two main phases:

-    return NamedTuple{variable_names}(Tuple([result[name] for name in variable_names]))
-end
-```
+1. Initialization:
+   - Set up initial states for each conditional probability problem.

-Some points worth noting:
+2. Iterative Sampling:
+   For each iteration, the sampler performs a sweep over all conditional probability problems:

-1. We are using `NamedTuple` to store the mapping between variables and samplers. The order will determine the order of the Gibbs sweeps. A limitation is that exactly one sampler for each variable is required, which means it is less flexible than Gibbs in `Turing.jl`.
-2. For each conditional probability problem, we need to store the sampler states for each variable group and also the values of all the variables from last iteration.
-3. The first step of the Gibbs sampler is to setup the states for each conditional probability problem.
-4. In the following steps of the Gibbs sampler, it will do a sweep over all the conditional probability problems, and update the sampler states for each problem. In each step of the sweep, it will do the following:
-   - condition on the values of all variables that are not in the current group
-   - recompute the log probability of the current state, because the values of the variables that are not in the current group may have changed
-   - perform a step of the sampler for the conditional probability problem, and update the sampler state
-   - update the `vi` with the new values from the sampler state
+   a. Condition on other variables:
+      - Fix the values of all variables except the current one.
+   b. Update current variable:
+      - Recompute the log probability of the current state, as other variables may have changed:
+        - Use `LogDensityProblems.logdensity(cond_logdensity_model, sub_state)` to get the new log probability.
+        - Update the state with `sub_state = sub_state(logp)` to incorporate the new log probability.
+      - Perform a sampling step for the current conditional probability problem:
+        - Use `AbstractMCMC.step(rng, cond_logdensity_model, sub_sampler, sub_state; kwargs...)` to generate a new state.
+      - Update the global trace:
+        - Extract parameter values from the new state using `Base.vec(new_sub_state)`.
+        - Incorporate these values into the overall Gibbs state trace.

-The `state` interface in AbstractMCMC allows the Gibbs sampler to be agnostic of the details of the sampler state, and acquire the values of the parameters from individual sampler states.
+This process allows the Gibbs sampler to iteratively update each variable while conditioning on the others, gradually exploring the joint distribution of all variables.

-Now we can use the Gibbs sampler to sample from the hierarchical normal model.
+Now we can use the Gibbs sampler to sample from the hierarchical Normal model.

 First we generate some data,

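Putting the pieces together, an end-to-end usage sketch; the model constructor, data, and sampler names are placeholders standing in for the definitions in this document:

```julia
using Random

rng = Random.default_rng()
# Assumed: `HierNormal` from the model section, `RandomWalkMH` from the sampler section.
full_model = AbstractMCMC.LogDensityModel(HierNormal((x=randn(rng, 100),)))
spl = Gibbs((mu=RandomWalkMH(0.3), tau2=RandomWalkMH(0.2)))
samples = AbstractMCMC.sample(
    rng, full_model, spl, 2_000; initial_params=(mu=[0.0], tau2=[1.0])
)
```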
test/gibbs_example/gibbs.jl

Lines changed: 23 additions & 84 deletions
@@ -1,14 +1,9 @@
 using AbstractMCMC: AbstractMCMC
 using AbstractPPL: AbstractPPL
-using MCMCChains: Chains
 using Random

-"""
-    Gibbs(sampler_map::NamedTuple)
-
-A Gibbs sampler that allows for block sampling using different inference algorithms for each parameter.
-"""
 struct Gibbs{T<:NamedTuple} <: AbstractMCMC.AbstractSampler
+    "Maps variables to their samplers."
     sampler_map::T
 end

@@ -28,74 +23,23 @@ struct GibbsTransition{ValuesNT<:NamedTuple}
     values::ValuesNT
 end

-"""
-    flatten(trace::NamedTuple)
-
-Flatten all the values in the trace into a single vector. Variable names information is discarded.
-
-# Examples
-
-```jldoctest; setup = :(using AbstractMCMC: flatten)
-julia> flatten((a=ones(2), b=ones(2, 2)))
-6-element Vector{Float64}:
- 1.0
- 1.0
- 1.0
- 1.0
- 1.0
- 1.0
-
-```
-"""
-function flatten(trace::NamedTuple)
-    return reduce(vcat, vec.(values(trace)))
-end
-
-"""
-    unflatten(vec::AbstractVector, variable_names::Vector{Symbol}, variable_sizes::Vector{Tuple})
-
-Reverse operation of flatten. Reshape the vector into the original arrays using size information.
-
-# Examples
-
-```jldoctest; setup = :(using AbstractMCMC: unflatten)
-julia> unflatten([1,2,3,4,5], (a=(2,), b=(3,)))
-(a = [1, 2], b = [3, 4, 5])
-
-julia> unflatten([1.0,2.0,3.0,4.0,5.0,6.0], (x=(2,2), y=(2,)))
-(x = [1.0 3.0; 2.0 4.0], y = [5.0, 6.0])
-```
-"""
-function unflatten(
-    vec::AbstractVector, variable_names_and_sizes::NamedTuple{variable_names}
-) where {variable_names}
-    result = Dict{Symbol,Array}()
-    start_idx = 1
-    for name in variable_names
-        size = variable_names_and_sizes[name]
-        end_idx = start_idx + prod(size) - 1
-        result[name] = reshape(vec[start_idx:end_idx], size...)
-        start_idx = end_idx + 1
-    end
-
-    return NamedTuple{variable_names}(Tuple([result[name] for name in variable_names]))
-end
-
 """
     update_trace(trace::NamedTuple, gibbs_state::GibbsState)

 Update the trace with the values from the MCMC states of the sub-problems.
 """
-function update_trace(trace::NamedTuple, gibbs_state::GibbsState)
-    for parameter_variable in keys(gibbs_state.mcmc_states)
+function update_trace(
+    trace::NamedTuple{trace_names}, gibbs_state::GibbsState{TraceNT,StateNT,SizeNT}
+) where {trace_names,TraceNT,StateNT,SizeNT}
+    for parameter_variable in fieldnames(StateNT)
         sub_state = gibbs_state.mcmc_states[parameter_variable]
-        sub_state_params = Base.vec(sub_state)
-        unflattened_sub_state_params = unflatten(
-            sub_state_params,
-            NamedTuple{(parameter_variable,)}((
-                gibbs_state.variable_sizes[parameter_variable],
-            )),
+        sub_state_params_values = Base.vec(sub_state)
+        reshaped_sub_state_params_values = reshape(
+            sub_state_params_values, gibbs_state.variable_sizes[parameter_variable]
         )
+        unflattened_sub_state_params = NamedTuple{(parameter_variable,)}((
+            reshaped_sub_state_params_values,
+        ))
         trace = merge(trace, unflattened_sub_state_params)
     end
     return trace
@@ -116,8 +60,7 @@ end
 function AbstractMCMC.step(
     rng::Random.AbstractRNG,
     logdensity_model::AbstractMCMC.LogDensityModel,
-    sampler::Gibbs{Tsamplingmap},
-    args...;
+    sampler::Gibbs{Tsamplingmap};
     initial_params::NamedTuple,
     kwargs...,
 ) where {Tsamplingmap}
@@ -133,30 +76,27 @@ function AbstractMCMC.step(
         conditioning_variables_values = NamedTuple{Tuple(variables_to_be_conditioned_on)}(
             Tuple([initial_params[g] for g in variables_to_be_conditioned_on])
         )
-        sub_problem_parameters_values = NamedTuple{(parameter_variable,)}((
-            initial_params[parameter_variable],
-        ))

         # LogDensityProblems' `logdensity` function expects a single vector of real numbers
         # `Gibbs` stores the parameters as a named tuple, thus we need to flatten the sub_problem_parameters_values
         # and unflatten after the sampling step
-        flattened_sub_problem_parameters_values = flatten(sub_problem_parameters_values)
+        flattened_sub_problem_parameters_values = vec(initial_params[parameter_variable])

+        sub_logdensity_model = AbstractMCMC.LogDensityModel(
+            AbstractPPL.condition(
+                logdensity_model.logdensity, conditioning_variables_values
+            ),
+        )
         sub_state = last(
             AbstractMCMC.step(
                 rng,
-                AbstractMCMC.LogDensityModel(
-                    AbstractPPL.condition(
-                        logdensity_model.logdensity, conditioning_variables_values
-                    ),
-                ),
-                sub_sampler,
-                args...;
+                sub_logdensity_model,
+                sub_sampler;
                 initial_params=flattened_sub_problem_parameters_values,
                 kwargs...,
             ),
         )
-        (sub_state, Tuple(size(initial_params[parameter_variable])))
+        (sub_state, size(initial_params[parameter_variable]))
     end

     mcmc_states_tuple = first.(results)
@@ -177,8 +117,7 @@ function AbstractMCMC.step(
     rng::Random.AbstractRNG,
     logdensity_model::AbstractMCMC.LogDensityModel,
     sampler::Gibbs{Tsamplingmap},
-    gibbs_state::GibbsState,
-    args...;
+    gibbs_state::GibbsState;
     kwargs...,
 ) where {Tsamplingmap}
     trace = gibbs_state.trace
@@ -204,7 +143,7 @@ function AbstractMCMC.step(
         sub_state = (sub_state)(logp)
         sub_state = last(
             AbstractMCMC.step(
-                rng, cond_logdensity_model, sub_sampler, sub_state, args...; kwargs...
+                rng, cond_logdensity_model, sub_sampler, sub_state; kwargs...
            ),
        )
        trace = update_trace(trace, gibbs_state)
