# `state` Interface

We encourage sampler packages to implement the following interface functions for the `state` type(s) they maintain:

```@docs
get_logprob
set_logprob!!
get_params
set_params!!
```

These functions provide a minimal interface for interacting with the `state` datatype, which a sampler package doesn't otherwise have to expose.

## Using the `state` Interface for block sampling within Gibbs

In this section, we demonstrate how a model package can use this `state` interface to support a Gibbs sampler that updates blocks of variables with different inference algorithms.

We consider a simple hierarchical model with a normal likelihood and unknown mean and variance parameters.

```math
\begin{align}
\mu &\sim \text{Normal}(0, 1) \\
\tau^2 &\sim \text{InverseGamma}(1, 1) \\
x_i &\sim \text{Normal}(\mu, \sqrt{\tau^2})
\end{align}
```

We can write the log joint probability function as follows. For simplicity in the steps below, we assume that the parameters `mu` and `tau2` are one-element vectors and that `x` is the data.

```julia
using Distributions

function log_joint(; mu::Vector{Float64}, tau2::Vector{Float64}, x::Vector{Float64})
    # mu is the mean
    # tau2 is the variance
    # x is data

    # μ ~ Normal(0, 1)
    # τ² ~ InverseGamma(1, 1)
    # xᵢ ~ Normal(μ, √τ²)

    logp = 0.0
    mu = only(mu)
    tau2 = only(tau2)

    # a non-positive variance has zero prior density, and `sqrt(tau2)`
    # below would error on it, so return -Inf early
    if tau2 <= 0
        return -Inf
    end

    mu_prior = Normal(0, 1)
    logp += logpdf(mu_prior, mu)

    tau2_prior = InverseGamma(1, 1)
    logp += logpdf(tau2_prior, tau2)

    obs_prior = Normal(mu, sqrt(tau2))
    logp += sum(logpdf(obs_prior, xi) for xi in x)

    return logp
end
```
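
For example, we can evaluate the joint at some hypothetical parameter values (the data here is simulated for illustration):

```julia
x_data = randn(10)  # hypothetical simulated observations
log_joint(; mu=[0.5], tau2=[1.0], x=x_data)  # returns a scalar log density
```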

To use the `LogDensityProblems` interface, we create a simple type for this model.

```julia
abstract type AbstractHierNormal end

struct HierNormal <: AbstractHierNormal
    data::NamedTuple
end

struct ConditionedHierNormal{conditioned_vars} <: AbstractHierNormal
    data::NamedTuple
    conditioned_values::NamedTuple{conditioned_vars}
end
```

where `ConditionedHierNormal` is a type that represents the model conditioned on some variables, and

```julia
using AbstractPPL

function AbstractPPL.condition(hn::HierNormal, conditioned_values::NamedTuple)
    return ConditionedHierNormal(hn.data, conditioned_values)
end
```
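
For example (hypothetical data and values, assuming `AbstractPPL` exports `condition` as used unqualified below):

```julia
hn = HierNormal((x=randn(100),))         # hypothetical simulated data
cond_model = condition(hn, (mu=[0.5],))  # fix mu; tau2 remains free
```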

Then we can write down the `LogDensityProblems` interface for this model.

```julia
using LogDensityProblems

function LogDensityProblems.logdensity(
    hn::ConditionedHierNormal{names}, params::AbstractVector
) where {names}
    if Set(names) == Set([:mu]) # conditioned on mu, so params are tau2
        return log_joint(; mu=hn.conditioned_values.mu, tau2=params, x=hn.data.x)
    elseif Set(names) == Set([:tau2]) # conditioned on tau2, so params are mu
        return log_joint(; mu=params, tau2=hn.conditioned_values.tau2, x=hn.data.x)
    else
        error("Unsupported conditioning configuration.")
    end
end

function LogDensityProblems.capabilities(::HierNormal)
    return LogDensityProblems.LogDensityOrder{0}()
end

function LogDensityProblems.capabilities(::ConditionedHierNormal)
    return LogDensityProblems.LogDensityOrder{0}()
end
```
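
With `mu` conditioned on as above, the remaining parameter vector is `tau2`, and we can evaluate the conditional log density directly (continuing the hypothetical `cond_model` from before):

```julia
LogDensityProblems.logdensity(cond_model, [1.0])  # log p(mu=0.5, tau2=1.0, x)
```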

The model should also define a function that recomputes the log probability given a sampler state.
The reason is that, when we break the joint probability down into conditional probabilities, each conditional problem is conditioned on the values of the other variables.
Between Gibbs sweeps those values may change, so the log probability of the current state needs to be recomputed.

A recomputation function could use the `state` interface to return a new state with the updated log probability.
For example:

```julia
function recompute_logprob!!(hn::ConditionedHierNormal, vals, state)
    return AbstractMCMC.set_logprob!!(state, LogDensityProblems.logdensity(hn, vals))
end
```
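
A Gibbs implementation might refresh a sub-sampler's state like this (hypothetical variable names, assuming the definitions above):

```julia
# `cond_model` conditions on the latest values of the other variables;
# the returned state carries the refreshed log probability.
state = recompute_logprob!!(cond_model, AbstractMCMC.get_params(state), state)
```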

Here the model package doesn't need to know the details of the `state` type, as long as the state implements `set_logprob!!`.

## Sampler Packages

To illustrate the `AbstractMCMC` interface, we will first implement two very simple Metropolis-Hastings samplers: a random-walk sampler and an independent sampler with a static proposal distribution.

Although the interface doesn't force samplers to implement `Transition` and `State` types, in practice it has been the convention to do so.

Here we define some bare-minimum types to represent the transitions and states.

```julia
struct MHTransition{T}
    params::Vector{T}
end

struct MHState{T}
    params::Vector{T}
    logp::Float64
end
```

Next we define the four `state` interface functions.

```julia
using AbstractMCMC

AbstractMCMC.get_params(state::MHState) = state.params
AbstractMCMC.set_params!!(state::MHState, params) = MHState(params, state.logp)
AbstractMCMC.get_logprob(state::MHState) = state.logp
AbstractMCMC.set_logprob!!(state::MHState, logp) = MHState(state.params, logp)
```
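
For example (values are illustrative):

```julia
s = MHState([0.0], -1.2)
AbstractMCMC.get_logprob(s)              # returns -1.2
s = AbstractMCMC.set_params!!(s, [1.0])  # new MHState; the cached logp is now stale
```

Note that after `set_params!!` the stored log probability no longer matches the new parameters, which is exactly the situation `recompute_logprob!!` handles.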

These are the functions used by `recompute_logprob!!` above; the Gibbs sampler below relies on them as well.

With them in place, implementing samplers against the `AbstractMCMC` interface is straightforward, and `get_logprob` gives us direct access to the log probability of the current state.

```julia
using Random: AbstractRNG
using LinearAlgebra: I

struct RWMH <: AbstractMCMC.AbstractSampler
    σ::Float64
end

function AbstractMCMC.step(
    rng::AbstractRNG,
    logdensity_model::AbstractMCMC.LogDensityModel,
    sampler::RWMH,
    args...;
    initial_params,
    kwargs...,
)
    return MHTransition(initial_params),
    MHState(
        initial_params,
        only(LogDensityProblems.logdensity(logdensity_model.logdensity, initial_params)),
    )
end

function AbstractMCMC.step(
    rng::AbstractRNG,
    logdensity_model::AbstractMCMC.LogDensityModel,
    sampler::RWMH,
    state::MHState,
    args...;
    kwargs...,
)
    params = get_params(state)
    # isotropic Gaussian random-walk proposal with standard deviation σ
    proposal_dist = MvNormal(zeros(length(params)), sampler.σ^2 * I)
    proposal = params .+ rand(rng, proposal_dist)
    logp_proposal = only(
        LogDensityProblems.logdensity(logdensity_model.logdensity, proposal)
    )

    log_acceptance_ratio = min(0, logp_proposal - get_logprob(state))

    if log(rand(rng)) < log_acceptance_ratio
        return MHTransition(proposal), MHState(proposal, logp_proposal)
    else
        return MHTransition(params), MHState(params, get_logprob(state))
    end
end
```
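
The random-walk sampler can already be used on its own, e.g. on the conditioned model from earlier. A sketch, assuming the definitions above (`sample` is re-exported by `AbstractMCMC`):

```julia
chain = sample(
    AbstractMCMC.LogDensityModel(cond_model),
    RWMH(0.3),            # hypothetical step size
    1000;
    initial_params=[1.0], # starting value for tau2
)
```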

```julia
struct PriorMH <: AbstractMCMC.AbstractSampler
    prior_dist::Distribution
end

function AbstractMCMC.step(
    rng::AbstractRNG,
    logdensity_model::AbstractMCMC.LogDensityModel,
    sampler::PriorMH,
    args...;
    initial_params,
    kwargs...,
)
    return MHTransition(initial_params),
    MHState(
        initial_params,
        only(LogDensityProblems.logdensity(logdensity_model.logdensity, initial_params)),
    )
end

function AbstractMCMC.step(
    rng::AbstractRNG,
    logdensity_model::AbstractMCMC.LogDensityModel,
    sampler::PriorMH,
    state::MHState,
    args...;
    kwargs...,
)
    params = get_params(state)
    proposal_dist = sampler.prior_dist
    proposal = rand(rng, proposal_dist)
    logp_proposal = only(
        LogDensityProblems.logdensity(logdensity_model.logdensity, proposal)
    )

    log_acceptance_ratio = min(
        0,
        logp_proposal - get_logprob(state) + logpdf(proposal_dist, params) -
        logpdf(proposal_dist, proposal),
    )

    if log(rand(rng)) < log_acceptance_ratio
        return MHTransition(proposal), MHState(proposal, logp_proposal)
    else
        return MHTransition(params), MHState(params, get_logprob(state))
    end
end
```
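
Unlike the random-walk sampler, the proposal here does not depend on the current state, so the Metropolis-Hastings correction keeps the proposal density on both sides of the ratio, matching the `log_acceptance_ratio` computation above:

```math
\log \alpha = \min\left(0,\; \log p(x') - \log p(x) + \log q(x) - \log q(x')\right)
```

where ``p`` is the target density and ``q`` is the (prior) proposal density.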

At last, we can proceed to implement the Gibbs sampler.

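The implementation below relies on two small helpers, `flatten` and `unflatten`, for converting between the `NamedTuple` representation of a variable group and the flat parameter vector a sub-sampler works with. They are not part of `AbstractMCMC`; here is a minimal sketch, assuming each group contains exactly one vector-valued variable (as in our model):

```julia
# Hypothetical helpers, assuming each Gibbs group holds a single
# vector-valued variable, e.g. (mu=[0.0],) with group (:mu,).
flatten(nt::NamedTuple) = only(values(nt))                 # (mu=[0.0],) -> [0.0]
unflatten(vec::AbstractVector, group) = NamedTuple{Tuple(group)}((vec,))  # [0.0], (:mu,) -> (mu=[0.0],)
```
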
```julia
using OrderedCollections: OrderedDict

struct Gibbs <: AbstractMCMC.AbstractSampler
    sampler_map::OrderedDict
end

struct GibbsState
    vi::NamedTuple
    states::OrderedDict
end

struct GibbsTransition
    values::NamedTuple
end

function AbstractMCMC.step(
    rng::AbstractRNG,
    logdensity_model::AbstractMCMC.LogDensityModel,
    spl::Gibbs,
    args...;
    initial_params::NamedTuple,
    kwargs...,
)
    states = OrderedDict()
    for group in keys(spl.sampler_map)
        sub_spl = spl.sampler_map[group]

        vars_to_be_conditioned_on = setdiff(keys(initial_params), group)
        cond_val = NamedTuple{Tuple(vars_to_be_conditioned_on)}(
            Tuple([initial_params[g] for g in vars_to_be_conditioned_on])
        )
        params_val = NamedTuple{Tuple(group)}(Tuple([initial_params[g] for g in group]))
        sub_state = last(
            AbstractMCMC.step(
                rng,
                AbstractMCMC.LogDensityModel(
                    condition(logdensity_model.logdensity, cond_val)
                ),
                sub_spl,
                args...;
                initial_params=flatten(params_val),
                kwargs...,
            ),
        )
        states[group] = sub_state
    end
    return GibbsTransition(initial_params), GibbsState(initial_params, states)
end

function AbstractMCMC.step(
    rng::AbstractRNG,
    logdensity_model::AbstractMCMC.LogDensityModel,
    spl::Gibbs,
    state::GibbsState,
    args...;
    kwargs...,
)
    vi = state.vi
    for group in keys(spl.sampler_map)
        # refresh `vi` with the latest values from every sub-sampler state
        for (g, s) in state.states
            vi = merge(vi, unflatten(get_params(s), g))
        end
        sub_spl = spl.sampler_map[group]
        sub_state = state.states[group]
        group_complement = setdiff(keys(vi), group)
        cond_val = NamedTuple{Tuple(group_complement)}(
            Tuple([vi[g] for g in group_complement])
        )
        cond_logdensity = condition(logdensity_model.logdensity, cond_val)
        # other variables may have moved since this sub-state was last updated,
        # so its cached log probability must be recomputed
        sub_state = recompute_logprob!!(cond_logdensity, get_params(sub_state), sub_state)
        sub_state = last(
            AbstractMCMC.step(
                rng,
                AbstractMCMC.LogDensityModel(cond_logdensity),
                sub_spl,
                sub_state,
                args...;
                kwargs...,
            ),
        )
        state.states[group] = sub_state
    end
    for (g, s) in state.states
        vi = merge(vi, unflatten(get_params(s), g))
    end
    return GibbsTransition(vi), GibbsState(vi, state.states)
end
```

Some points worth noting:

1. We use an `OrderedDict` to store the mapping between variables and samplers; its order determines the order of the Gibbs sweeps.
2. For each conditional probability problem, we need to store both the sampler state for each variable group and the values of all variables from the last iteration.
3. The first step of the Gibbs sampler sets up the states for each conditional probability problem.
4. Each subsequent step sweeps over all the conditional probability problems and updates the sampler state for each. A step of the sweep does the following:
   - first merge the values from the last step of the sweep into `vi`, which stores the values of all variables at this point of the Gibbs sweep;
   - condition on the values of all variables that are not in the current group;
   - recompute the log probability of the current state, because the variables outside the current group may have changed;
   - perform a step of the sampler for the conditional probability problem, and update the sampler state;
   - update `vi` with the new values from the sampler state.

Again, the `state` interface in `AbstractMCMC` allows the Gibbs sampler to remain agnostic of the details of each sampler state while still acquiring the parameter values from the individual states.
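
To put everything together, here is a hypothetical end-to-end run, assuming all of the definitions above (the grouping, step size, and chain length are illustrative):

```julia
hn = HierNormal((x=randn(100) .+ 1,))  # simulated data with true mean around 1
spl = Gibbs(
    OrderedDict(
        (:mu,) => RWMH(0.3),
        (:tau2,) => PriorMH(product_distribution([InverseGamma(1, 1)])),
    ),
)
chain = sample(
    AbstractMCMC.LogDensityModel(hn),
    spl,
    2000;
    initial_params=(mu=[0.0], tau2=[1.0]),
)
```

Each element of the returned vector is a `GibbsTransition` whose `values` field holds the `mu` and `tau2` draws for that iteration.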