Skip to content

Commit 600d36c

Browse files
torfjeldesunxd3github-actions[bot]
authored
Apply suggestions from code review
Co-authored-by: Xianda Sun <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent d9f8585 commit 600d36c

File tree

2 files changed

+25
-64
lines changed

2 files changed

+25
-64
lines changed

docs/src/api.md

Lines changed: 14 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,8 @@ AbstractMCMC.chainsstack
112112

113113
To make it a bit easier to interact with some arbitrary sampler state, we encourage implementations of `AbstractSampler` to implement the following methods:
114114
```@docs
115-
AbstractMCMC.realize
116-
AbstractMCMC.realize!!
115+
AbstractMCMC.getparams
116+
AbstractMCMC.setparams!!
117117
```
118118
and optionally
119119
```@docs
@@ -125,7 +125,7 @@ These methods can also be useful for implementing samplers which wraps some inne
125125

126126
In a `MixtureSampler` we need two things:
127127
- `components`: collection of samplers.
128-
- `weights`: collection of weights representing the probability of chosing the corresponding sampler.
128+
- `weights`: collection of weights representing the probability of choosing the corresponding sampler.
129129

130130
```julia
131131
struct MixtureSampler{W,C} <: AbstractMCMC.AbstractSampler
@@ -136,7 +136,6 @@ end
136136

137137
To implement the state, we need to keep track of a couple of things:
138138
- `index`: the index of the sampler used in this `step`.
139-
- `transition`: the transition resulting from this `step`.
140139
- `states`: the current states of _all_ the components.
141140
Two aspects of this might seem a bit strange:
142141
1. We need to keep track of the states of _all_ components rather than just the state for the sampler we used previously.
@@ -146,11 +145,9 @@ The reason for (1) is that lots of samplers keep track of more than just the pre
146145

147146
For (2) the reason is similar: some samplers might keep track of the variables _in the state_ differently, e.g. you might have a sampler which is _independent_ of the current realizations and the state is simply `nothing`.
148147

149-
Hence, we need the `transition`, which should always contain the realizations, to make sure we can resume from the same point in the space in the next `step`.
150148
```julia
151-
struct MixtureState{T,S}
149+
struct MixtureState{S}
152150
index::Int
153-
transition::T
154151
states::S
155152
end
156153
```
@@ -162,15 +159,16 @@ X_t &\sim \mathcal{K}_i(\cdot \mid X_{t - 1})
162159
\end{aligned}
163160
```
164161
where ``\mathcal{K}_i`` denotes the i-th kernel/sampler, and ``w_i`` denotes the weight/probability of choosing the i-th sampler.
165-
[`AbstractMCMC.updatestate!!`](@ref) comes into play in defining/computing ``\mathcal{K}_i(\cdot \mid X_{t - 1})`` since ``X_{t - 1}`` could be coming from a different sampler.
162+
[`AbstractMCMC.getparams`](@ref) and [`AbstractMCMC.setparams!!`](@ref) comes into play in defining/computing ``\mathcal{K}_i(\cdot \mid X_{t - 1})`` since ``X_{t - 1}`` could be coming from a different sampler.
166163

167164
If we let `state` be the current `MixtureState`, `i` the current component, and `i_prev` is the previous component we sampled from, then this translates into the following piece of code:
168165

169166
```julia
170167
# Update the corresponding state, i.e. `state.states[i]`, using
171168
# the state and transition from the previous iteration.
172-
state_current = AbstractMCMC.updatestate!!(
173-
state.states[i], state.states[i_prev], state.transition
169+
state_current = AbstractMCMC.setparams!!(
170+
state.states[i],
171+
AbstractMCMC.getparams(state.states[i_prev]),
174172
)
175173

176174
# Take a `step` for this sampler using the updated state.
@@ -191,8 +189,9 @@ function AbstractMCMC.step(rng, model::AbstractMCMC.AbstractModel, sampler::Mixt
191189
# Update the corresponding state, i.e. `state.states[i]`, using
192190
# the state and transition from the previous iteration.
193191
i_prev = state.index
194-
state_current = AbstractMCMC.updatestate!!(
195-
model, state.states[i], state.states[i_prev], state.transition
192+
state_current = AbstractMCMC.setparams!!(
193+
state.states[i],
194+
AbstractMCMC.getparams(state.states[i_prev]),
196195
)
197196

198197
# Take a `step` for this sampler using the updated state.
@@ -217,7 +216,7 @@ function AbstractMCMC.step(rng, model::AbstractMCMC.AbstractModel, sampler::Mixt
217216
end
218217

219218
# Create the new `MixtureState`.
220-
state_new = MixtureState(i, transition, states_new)
219+
state_new = MixtureState(i, states_new)
221220

222221
return transition, state_new
223222
end
@@ -239,20 +238,14 @@ function AbstractMCMC.step(rng, model::AbstractMCMC.AbstractModel, sampler::Mixt
239238
# Extract states.
240239
states = map(last, transitions_and_states)
241240
# Create new `MixtureState`.
242-
state = MixtureState(i, transition, states)
241+
state = MixtureState(i, states)
243242

244243
return transition, state
245244
end
246245
```
247246

248-
Suppose we then wanted to use this with some of the packages which implements AbstractMCMC.jl's interface, e.g. [`AdvancedMH.jl`](https://github.com/TuringLang/AdvancedMH.jl), then we'd simply have to implement `realize` and `realize!!`:
247+
Suppose we then wanted to use this with some of the packages which implements AbstractMCMC.jl's interface, e.g. [`AdvancedMH.jl`](https://github.com/TuringLang/AdvancedMH.jl), then we'd simply have to implement `getparams` and `setparams!!`.
249248

250-
```julia
251-
function AbstractMCMC.updatestate!!(model, ::AdvancedMH.Transition, state_prev::AdvancedMH.Transition)
252-
# Let's `deepcopy` just to be certain.
253-
return deepcopy(state_prev)
254-
end
255-
```
256249

257250
To use `MixtureSampler` with two samplers `sampler1` and `sampler2` from `AdvancedMH.jl` as components, we'd simply do
258251

@@ -263,25 +256,3 @@ while ...
263256
transition, state = AbstractMCMC.step(rng, model, sampler, state)
264257
end
265258
```
266-
267-
As a final note, there is one potential issue we haven't really addressed in the above implementation: a lot of samplers have their own implementations of `AbstractMCMC.AbstractModel` which means that we would also have to ensure that all the different samplers we are using would be compatible with the same model. A very easy way to fix this would be to just add a struct called `ManyModels` supporting `getindex`, e.g. `models[i]` would return the i-th `model`:
268-
269-
```julia
270-
struct ManyModels{M} <: AbstractMCMC.AbstractModel
271-
models::M
272-
end
273-
274-
Base.getindex(model::ManyModels, I...) = model.models[I...]
275-
```
276-
277-
Then the above `step` would just extract the `model` corresponding to the current sampler:
278-
279-
```julia
280-
# Take a `step` for this sampler using the updated state.
281-
transition, state_current = AbstractMCMC.step(
282-
rng, model[i], sampler_current, state_current;
283-
kwargs...
284-
)
285-
```
286-
287-
This issue should eventually disappear as the community moves towards a unified approach to implement `AbstractMCMC.AbstractModel`.

src/AbstractMCMC.jl

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -80,35 +80,25 @@ The `MCMCSerial` algorithm allows users to sample serially, with no thread or pr
8080
struct MCMCSerial <: AbstractMCMCEnsemble end
8181

8282
"""
83-
updatestate!!(model, state, transition_prev[, state_prev])
83+
getparams(state[; kwargs...])
8484
85-
Return new instance of `state` using information from `model`, `transition_prev` and, optionally, `state_prev`.
86-
87-
Defaults to `realize!!(state, realize(transition_prev))`.
85+
Retrieve the values of parameters from the sampler's `state` as a `Vector{<:Real}`.
8886
"""
89-
function updatestate!!(model, state, transition_prev, state_prev)
90-
return updatestate!!(state, transition_prev)
91-
end
92-
updatestate!!(model, state, transition) = realize!!(state, realize(transition))
87+
function getparams end
9388

9489
"""
95-
realize!!(state, realization)
96-
97-
Update the realization of the `state` with `realization` and return it.
90+
setparams!!(state, params)
9891
99-
If `state` can be updated in-place, it is expected that this function returns `state` with updated
100-
realize. Otherwise a new `state` object with the new `realization` is returned.
101-
"""
102-
function realize!! end
92+
Set the values of parameters in the sampler's `state` from a `Vector{<:Real}`.
10393
104-
"""
105-
realize(transition)
94+
This function should follow the `BangBang` interface: mutate `state` in-place if possible and
95+
return the mutated `state`. Otherwise, it should return a new `state` containing the updated parameters.
10696
107-
Return the realization of the random variables present in `transition`.
97+
Although not enforced, it should hold that `setparams!!(state, getparams(state)) == state`. In another
98+
word, the sampler should implement a consistent transformation between its internal representation
99+
and the vector representation of the parameter values.
108100
"""
109-
function realize end
110-
111-
101+
function setparams!! end
112102
include("samplingstats.jl")
113103
include("logging.jl")
114104
include("interface.jl")

0 commit comments

Comments
 (0)