Commit bce436d

added example for why updatestate!! is useful
1 parent d86499f commit bce436d

File tree

1 file changed

+125
-0
lines changed

docs/src/api.md

Lines changed: 125 additions & 0 deletions
@@ -89,3 +89,128 @@ and optionally
AbstractMCMC.updatestate!!(state, transition, state_prev)
```
These methods can also be useful for implementing samplers which wrap some inner samplers, e.g. a mixture of samplers.

### Example: `MixtureSampler`

In a `MixtureSampler` we need two things:
- `components`: collection of samplers.
- `weights`: collection of weights representing the probability of choosing the corresponding sampler.

```julia
struct MixtureSampler{W,C} <: AbstractMCMC.AbstractSampler
    components::C
    weights::W
end
```

To implement the state, we need to keep track of three things:
- `index`: the index of the sampler used in this `step`.
- `transition`: the transition resulting from this `step`.
- `states`: the current states of _all_ the components.

Two aspects of this might seem a bit strange:
1. We need to keep track of the states of _all_ components rather than just the state for the sampler we used previously.
2. We need to put the `transition` from the `step` into the state.

The reason for (1) is that many samplers keep track of more than just the previous realizations of the variables, e.g. in `AdvancedHMC.jl` we keep track of the momentum used, the metric used, etc.
For (2) the reason is similar: some samplers might keep track of the variables _in the state_ differently, e.g. maybe the sampler is working in a transformed space but returns the samples in the original space, or maybe the sampler is even independent of the current realizations and the state is simply `nothing`. Hence, we need the `transition`, which should always contain the realizations, to make sure we can resume from the same point in the space in the next `step`.
```julia
struct MixtureState{T,S}
    index::Int
    transition::T
    states::S
end
```
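
To make point (2) more concrete, here is a minimal sketch of what a state update could look like for two toy samplers. Everything in it (`ToyTransition`, `LogSpaceState`, `IndependentState`, `toy_updatestate`) is a hypothetical stand-in and not part of AbstractMCMC; only the argument order mirrors `updatestate!!(state, transition, state_prev)` from above.

```julia
# Hypothetical stand-ins for illustration; none of these types or functions
# are part of AbstractMCMC.

# The transition always carries the realization in the original space.
struct ToyTransition
    x::Float64
end

# A sampler state that stores the point in log-space plus a tuning parameter.
struct LogSpaceState
    y::Float64         # log(x)
    stepsize::Float64  # sampler-specific information that should be preserved
end

# A sampler state that carries no information about the current point.
struct IndependentState end

# Same argument order as `updatestate!!(state, transition, state_prev)`.
function toy_updatestate(state::LogSpaceState, transition::ToyTransition, state_prev)
    # Re-derive the internal representation from the transition and
    # keep everything else in `state` untouched.
    return LogSpaceState(log(transition.x), state.stepsize)
end
toy_updatestate(state::IndependentState, transition::ToyTransition, state_prev) = state

# The previous step was taken by the independent sampler, so its state alone
# cannot tell us where the chain currently is; the transition can.
state = toy_updatestate(LogSpaceState(0.0, 0.3), ToyTransition(2.0), IndependentState())
```

Here the log-space state picks up the new position from the transition while its `stepsize` survives, which is the behaviour the two points above are asking for.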
122+
The `step` for a `MixtureSampler` is defined by the following generative process
123+
```math
124+
\begin{aligned}
125+
i &\sim \mathrm{Categorical}(w_1, \dots, w_k) \\
126+
X_t &\sim \mathcal{K}_i(\cdot \mid X_{t - 1})
127+
\end{aligned}
128+
```
129+
where ``\mathcal{K}_i`` denotes the i-th kernel/sampler, and `w_i` denotes the weight/probability of choosing the i-th sampler.
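
As a rough illustration of this generative process outside of AbstractMCMC, the sketch below draws the index by walking the cumulative sum of the weights and then applies the chosen kernel. The names `sample_component`, `mixture_step`, and the two random-walk kernels are assumptions made up for this example.

```julia
using Random

# Draw `i` with probability proportional to `weights[i]` by walking the cumulative sum.
function sample_component(rng::AbstractRNG, weights)
    u = rand(rng) * sum(weights)
    acc = 0.0
    for (i, w) in enumerate(weights)
        acc += w
        u < acc && return i
    end
    return length(weights)  # guard against floating-point round-off
end

# Hypothetical toy kernels on the real line: Gaussian random walks with different scales.
kernels = (
    (rng, x) -> x + 0.1 * randn(rng),
    (rng, x) -> x + 2.0 * randn(rng),
)
weights = (0.1, 0.9)

# One draw from the generative process: first `i`, then `X_t` given `X_{t-1}`.
function mixture_step(rng::AbstractRNG, kernels, weights, x_prev)
    i = sample_component(rng, weights)
    return i, kernels[i](rng, x_prev)
end

rng = Random.default_rng()
i, x = mixture_step(rng, kernels, weights, 0.0)
```

In this toy setting the kernels only need the previous point, so no state update is required; the complication discussed next arises once each kernel carries its own internal state.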

[`AbstractMCMC.updatestate!!`](@ref) comes into play in defining/computing ``\mathcal{K}_i(\cdot \mid X_{t - 1})`` since ``X_{t - 1}`` could be coming from a different sampler. If we let `state` be the current `MixtureState`, `i` the current component, and `i_prev` the previous component we sampled from, then this translates into the following piece of code:

```julia
# Update the corresponding state, i.e. `state.states[i]`, using
# the state and transition from the previous iteration.
state_current = AbstractMCMC.updatestate!!(
    state.states[i], state.transition, state.states[i_prev]
)

# Take a `step` for this sampler using the updated state.
transition, state_current = AbstractMCMC.step(
    rng, model, sampler_current, state_current;
    kwargs...
)
```

The full [`AbstractMCMC.step`](@ref) implementation would then be something like:

```julia
function AbstractMCMC.step(rng, model::AbstractMCMC.AbstractModel, sampler::MixtureSampler, state; kwargs...)
    # Sample the component to use in this `step`.
    # NOTE: `Categorical` is provided by Distributions.jl.
    i = rand(rng, Categorical(collect(sampler.weights)))
    sampler_current = sampler.components[i]

    # Update the corresponding state, i.e. `state.states[i]`, using
    # the state and transition from the previous iteration.
    i_prev = state.index
    state_current = AbstractMCMC.updatestate!!(
        state.states[i], state.transition, state.states[i_prev]
    )

    # Take a `step` for this sampler using the updated state.
    transition, state_current = AbstractMCMC.step(
        rng, model, sampler_current, state_current;
        kwargs...
    )

    # Create the new states.
    # NOTE: A better approach would be to use `Setfield.@set state.states[i] = ...`
    # but to keep this demo self-contained, we don't.
    states_new = ntuple(length(state.states)) do j
        if j != i
            state.states[j]
        else
            state_current
        end
    end

    # Create the new `MixtureState`.
    state_new = MixtureState(i, transition, states_new)

    return transition, state_new
end
```
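
If the component states are stored in a tuple, as in this example, the "replace one entry" step can also be written without extra dependencies using the non-mutating `Base.setindex`. The following is only a sketch with placeholder symbols standing in for real sampler states:

```julia
# Toy stand-ins for the per-component states.
states = (:state_a, :state_b, :state_c)
i = 2

# Rebuild the tuple entry by entry, replacing only position `i`.
states_ntuple = ntuple(length(states)) do j
    j != i ? states[j] : :state_b_updated
end

# Same result using the non-mutating `Base.setindex` for tuples.
states_setindex = Base.setindex(states, :state_b_updated, i)
```

Both variants return a new tuple and leave `states` unchanged, which fits the immutable `MixtureState` used here.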

And for the initial [`AbstractMCMC.step`](@ref) we have:

```julia
function AbstractMCMC.step(rng, model::AbstractMCMC.AbstractModel, sampler::MixtureSampler; kwargs...)
    # Initialize every state.
    transitions_and_states = map(sampler.components) do spl
        AbstractMCMC.step(rng, model, spl; kwargs...)
    end

    # Sample the component to use in this `step`.
    # NOTE: `Categorical` is provided by Distributions.jl.
    i = rand(rng, Categorical(collect(sampler.weights)))
    # Extract the corresponding transition.
    transition = first(transitions_and_states[i])
    # Extract states.
    states = map(last, transitions_and_states)
    # Create new `MixtureState`.
    state = MixtureState(i, transition, states)

    return transition, state
end
```

To use `MixtureSampler`, one could then do something like

```julia
# Arguments follow the field order of the struct: `components`, then `weights`.
sampler = MixtureSampler((sampler1, sampler2), (0.1, 0.9))
transition, state = AbstractMCMC.step(rng, model, sampler)
while ...
    transition, state = AbstractMCMC.step(rng, model, sampler, state)
end
```
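
The stopping condition of the loop is left open above. As one possible way to fill it in, the sketch below runs for a fixed number of iterations and collects the transitions. The helper `collect_transitions` and the toy random-walk functions are hypothetical and not part of AbstractMCMC; they only mimic the `(transition, state)` return convention of `step`.

```julia
using Random

# Hypothetical helper: drives any pair of `initial_step(rng)` and
# `next_step(rng, state)` functions, each returning `(transition, state)`,
# for a fixed number of iterations `n`.
function collect_transitions(rng, initial_step, next_step, n::Integer)
    transition, state = initial_step(rng)
    transitions = [transition]
    for _ in 2:n
        # Thread the state from one iteration into the next.
        transition, state = next_step(rng, state)
        push!(transitions, transition)
    end
    return transitions
end

# Toy stand-in for a sampler: a Gaussian random walk whose state is the current point.
function rw_initial(rng)
    x = randn(rng)
    return x, x
end
function rw_next(rng, state)
    x = state + randn(rng)
    return x, x
end

chain = collect_transitions(Random.default_rng(), rw_initial, rw_next, 100)
```

With the `MixtureSampler` from this example, `initial_step` and `next_step` would be thin closures around the two `AbstractMCMC.step` methods defined above.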