Skip to content

Commit bceb510

Browse files
committed
fix doc example
1 parent 8d74889 commit bceb510

File tree

1 file changed

+20
-11
lines changed

1 file changed

+20
-11
lines changed

docs/src/state_interface.md

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ Base.vec(state)
1717
This function takes the state and returns a vector of the parameter values stored in the state.
1818

1919
```julia
20-
state = StateType(state, logp)
20+
state = StateType(state::StateType, logp)
2121
```
2222

2323
This function takes an existing `state` and a log probability value `logp`, and returns a new state of the same type with the updated log probability.
@@ -42,7 +42,17 @@ x_i &\sim \text{Normal}(\mu, \sqrt{\tau^2})
4242

4343
We can write the log joint probability function as follows, where for the sake of simplicity for the following steps, we will assume that the `mu` and `tau2` parameters are one-element vectors. And `x` is the data.
4444

45-
```julia
45+
```@example gibbs_example
46+
using AbstractMCMC: AbstractMCMC, LogDensityProblems # hide
47+
using Distributions # hide
48+
using Random # hide
49+
using AbstractMCMC: AbstractMCMC # hide
50+
using AbstractPPL: AbstractPPL # hide
51+
using BangBang: constructorof # hide
52+
using AbstractPPL: AbstractPPL
53+
```
54+
55+
```@example gibbs_example
4656
function log_joint(; mu::Vector{Float64}, tau2::Vector{Float64}, x::Vector{Float64})
4757
# mu is the mean
4858
# tau2 is the variance
@@ -71,7 +81,7 @@ end
7181

7282
To make using `LogDensityProblems` interface, we create a simple type for this model.
7383

74-
```julia
84+
```@example gibbs_example
7585
abstract type AbstractHierNormal end
7686
7787
struct HierNormal{Tdata<:NamedTuple} <: AbstractHierNormal
@@ -89,15 +99,15 @@ end
8999

90100
where `ConditionedHierNormal` is a type that represents the model conditioned on some variables, and
91101

92-
```julia
102+
```@example gibbs_example
93103
function AbstractPPL.condition(hn::HierNormal, conditioned_values::NamedTuple)
94104
return ConditionedHierNormal(hn.data, conditioned_values)
95105
end
96106
```
97107

98108
then we can simply write down the `LogDensityProblems` interface for this model.
99109

100-
```julia
110+
```@example gibbs_example
101111
function LogDensityProblems.logdensity(
102112
hier_normal_model::ConditionedHierNormal{Tdata,Tconditioned_vars},
103113
params::AbstractVector,
@@ -132,7 +142,7 @@ Although the interface doesn't force the sampler to implement `Transition` and `
132142

133143
Here we define some bare minimum types to represent the transitions and states.
134144

135-
```julia
145+
```@example gibbs_example
136146
struct MHTransition{T}
137147
params::Vector{T}
138148
end
@@ -145,7 +155,7 @@ end
145155

146156
Next we define the `state` interface functions mentioned at the beginning of this section.
147157

148-
```julia
158+
```@example gibbs_example
149159
# Interface 1: LogDensityProblems.logdensity
150160
# This function takes the logdensity function and the state (state is defined by the sampler package)
151161
# and returns the logdensity. It allows for optional recomputation of the log probability.
@@ -168,15 +178,14 @@ Base.vec(state::MHState) = state.params
168178
169179
# Interface 3: constructorof and MHState(state::MHState, logp::Float64)
170180
# This function allows the state to be updated with a new log probability.
171-
BangBang.constructorof(state::MHState{T}) where {T} = MHState
172181
function MHState(state::MHState, logp::Float64)
173182
return MHState(state.params, logp)
174183
end
175184
```
176185

177186
It is very simple to implement the samplers according to the `AbstractMCMC` interface, where we can use `LogDensityProblems.logdensity` to easily read the log probability of the current state.
178187

179-
```julia
188+
```@example gibbs_example
180189
"""
181190
RandomWalkMH{T} <: AbstractMCMC.AbstractSampler
182191
@@ -264,7 +273,7 @@ end
264273

265274
At last, we can proceed to implement a very simple Gibbs sampler.
266275

267-
```julia
276+
```@example gibbs_example
268277
struct Gibbs{T<:NamedTuple} <: AbstractMCMC.AbstractSampler
269278
"Maps variables to their samplers."
270279
sampler_map::T
@@ -440,7 +449,7 @@ The Gibbs sampler operates in two main phases:
440449
b. Update current variable:
441450
- Recompute the log probability of the current state, as other variables may have changed:
442451
- Use `LogDensityProblems.logdensity(cond_logdensity_model, sub_state)` to get the new log probability.
443-
- Update the state with `sub_state = sub_state(logp)` to incorporate the new log probability.
452+
- Update the state with `sub_state = constructorof(typeof(sub_state))(sub_state, logp)` to incorporate the new log probability.
444453
- Perform a sampling step for the current conditional probability problem:
445454
- Use `AbstractMCMC.step(rng, cond_logdensity_model, sub_sampler, sub_state; kwargs...)` to generate a new state.
446455
- Update the global trace:

0 commit comments

Comments
 (0)