Skip to content

Commit b262ea9

Browse files
committed
update mh code
1 parent 280eaf1 commit b262ea9

File tree

3 files changed

+37
-115
lines changed

3 files changed

+37
-115
lines changed

docs/src/gibbs.md

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -162,14 +162,14 @@ These are the functions that was used in the `recompute_logprob!!` function abov
162162
It is very simple to implement the samplers according to the `AbstractMCMC` interface, where we can use `get_logprob` to easily read the log probability of the current state.
163163

164164
```julia
165-
struct RWMH <: AbstractMCMC.AbstractSampler
165+
struct RandomWalkMH <: AbstractMCMC.AbstractSampler
166166
σ::Float64
167167
end
168168

169169
function AbstractMCMC.step(
170170
rng::AbstractRNG,
171171
logdensity_model::AbstractMCMC.LogDensityModel,
172-
sampler::RWMH,
172+
sampler::RandomWalkMH,
173173
args...;
174174
initial_params,
175175
kwargs...,
@@ -184,7 +184,7 @@ end
184184
function AbstractMCMC.step(
185185
rng::AbstractRNG,
186186
logdensity_model::AbstractMCMC.LogDensityModel,
187-
sampler::RWMH,
187+
sampler::RandomWalkMH,
188188
state::MHState,
189189
args...;
190190
kwargs...,
@@ -207,14 +207,14 @@ end
207207
```
208208

209209
```julia
210-
struct PriorMH <: AbstractMCMC.AbstractSampler
210+
struct IndependentMH <: AbstractMCMC.AbstractSampler
211211
prior_dist::Distribution
212212
end
213213

214214
function AbstractMCMC.step(
215215
rng::AbstractRNG,
216216
logdensity_model::AbstractMCMC.LogDensityModel,
217-
sampler::PriorMH,
217+
sampler::IndependentMH,
218218
args...;
219219
initial_params,
220220
kwargs...,
@@ -229,7 +229,7 @@ end
229229
function AbstractMCMC.step(
230230
rng::AbstractRNG,
231231
logdensity_model::AbstractMCMC.LogDensityModel,
232-
sampler::PriorMH,
232+
sampler::IndependentMH,
233233
state::MHState,
234234
args...;
235235
kwargs...,
@@ -384,8 +384,8 @@ samples = sample(
384384
hn,
385385
Gibbs(
386386
OrderedDict(
387-
(:mu,) => RWMH(1),
388-
(:tau2,) => PriorMH(product_distribution([InverseGamma(1, 1)])),
387+
(:mu,) => RandomWalkMH(1),
388+
(:tau2,) => IndependentMH(product_distribution([InverseGamma(1, 1)])),
389389
),
390390
),
391391
100000;

test/gibbs_example/gibbs.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ include("hier_normal.jl")
1515
samples = sample(
1616
hn,
1717
AbstractMCMC.Gibbs((
18-
mu=RWMH(1), tau2=PriorMH(product_distribution([InverseGamma(1, 1)]))
18+
mu=RandomWalkMH(1), tau2=IndependentMH(product_distribution([InverseGamma(1, 1)]))
1919
)),
2020
200000;
2121
initial_params=(mu=[0.0], tau2=[1.0]),
@@ -49,9 +49,9 @@ end
4949
# gmm,
5050
# Gibbs(
5151
# (
52-
# z = PriorMH(product_distribution([Categorical([0.3, 0.7]) for _ in 1:60])),
53-
# w = PriorMH(Dirichlet(2, 1.0)),
54-
# μ = RWMH(1),
52+
# z = IndependentMH(product_distribution([Categorical([0.3, 0.7]) for _ in 1:60])),
53+
# w = IndependentMH(Dirichlet(2, 1.0)),
54+
# μ = RandomWalkMH(1),
5555
# ),
5656
# ),
5757
# 100000;

test/gibbs_example/mh.jl

Lines changed: 25 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,31 @@
11
using Distributions
22

3-
struct MHTransition{T}
3+
struct MHState{T}
44
params::Vector{T}
5+
logp::Float64
56
end
67

7-
struct MHState{T}
8+
struct MHTransition{T}
89
params::Vector{T}
9-
logp::Float64
1010
end
1111

1212
AbstractMCMC.get_params(state::MHState) = state.params
1313
AbstractMCMC.set_params!!(state::MHState, params) = MHState(params, state.logp)
1414
AbstractMCMC.get_logprob(state::MHState) = state.logp
1515
AbstractMCMC.set_logprob!!(state::MHState, logp) = MHState(state.params, logp)
1616

17-
struct RWMH <: AbstractMCMC.AbstractSampler
17+
struct RandomWalkMH <: AbstractMCMC.AbstractSampler
1818
σ::Float64
1919
end
2020

21+
struct IndependentMH <: AbstractMCMC.AbstractSampler
22+
proposal_dist::Distributions.Distribution
23+
end
24+
2125
function AbstractMCMC.step(
2226
rng::AbstractRNG,
2327
logdensity_model::AbstractMCMC.LogDensityModel,
24-
sampler::RWMH,
28+
sampler::Union{RandomWalkMH,IndependentMH},
2529
args...;
2630
initial_params,
2731
kwargs...,
@@ -36,121 +40,39 @@ end
3640
function AbstractMCMC.step(
3741
rng::AbstractRNG,
3842
logdensity_model::AbstractMCMC.LogDensityModel,
39-
sampler::RWMH,
43+
sampler::Union{RandomWalkMH,IndependentMH},
4044
state::MHState,
4145
args...;
4246
kwargs...,
4347
)
4448
params = state.params
45-
proposal_dist = MvNormal(zeros(length(params)), sampler.σ)
46-
proposal = params .+ rand(rng, proposal_dist)
49+
proposal_dist =
50+
sampler isa RandomWalkMH ? MvNormal(state.params, sampler.σ) : sampler.proposal_dist
51+
proposal = rand(rng, proposal_dist)
4752
logp_proposal = only(
4853
LogDensityProblems.logdensity(logdensity_model.logdensity, proposal)
4954
)
5055

51-
log_acceptance_ratio = min(0, logp_proposal - AbstractMCMC.get_logprob(state))
52-
53-
if log(rand(rng)) < log_acceptance_ratio
56+
if log(rand(rng)) <
57+
compute_log_acceptance_ratio(sampler, state, proposal, logp_proposal)
5458
return MHTransition(proposal), MHState(proposal, logp_proposal)
5559
else
56-
return MHTransition(params), MHState(params, AbstractMCMC.get_logprob(state))
60+
return MHTransition(params), MHState(params, state.logp)
5761
end
5862
end
5963

60-
struct PriorMH <: AbstractMCMC.AbstractSampler
61-
prior_dist::Distributions.Distribution
62-
end
63-
64-
function AbstractMCMC.step(
65-
rng::AbstractRNG,
66-
logdensity_model::AbstractMCMC.LogDensityModel,
67-
sampler::PriorMH,
68-
args...;
69-
initial_params,
70-
kwargs...,
64+
function compute_log_acceptance_ratio(
65+
::RandomWalkMH, state::MHState, ::Vector{Float64}, logp_proposal::Float64
7166
)
72-
return MHTransition(initial_params),
73-
MHState(
74-
initial_params,
75-
only(LogDensityProblems.logdensity(logdensity_model.logdensity, initial_params)),
76-
)
67+
return min(0, logp_proposal - AbstractMCMC.get_logprob(state))
7768
end
7869

79-
function AbstractMCMC.step(
80-
rng::AbstractRNG,
81-
logdensity_model::AbstractMCMC.LogDensityModel,
82-
sampler::PriorMH,
83-
state::MHState,
84-
args...;
85-
kwargs...,
86-
)
87-
params = AbstractMCMC.get_params(state)
88-
proposal_dist = sampler.prior_dist
89-
proposal = rand(rng, proposal_dist)
90-
logp_proposal = only(
91-
LogDensityProblems.logdensity(logdensity_model.logdensity, proposal)
92-
)
93-
94-
log_acceptance_ratio = min(
70+
function compute_log_acceptance_ratio(
71+
sampler::IndependentMH, state::MHState, proposal::Vector{T}, logp_proposal::Float64
72+
) where {T}
73+
return min(
9574
0,
96-
logp_proposal - AbstractMCMC.get_logprob(state) + logpdf(proposal_dist, params) -
97-
logpdf(proposal_dist, proposal),
75+
logp_proposal - state.logp + logpdf(sampler.proposal_dist, state.params) -
76+
logpdf(sampler.proposal_dist, proposal),
9877
)
99-
100-
if log(rand(rng)) < log_acceptance_ratio
101-
return MHTransition(proposal), MHState(proposal, logp_proposal)
102-
else
103-
return MHTransition(params), MHState(params, AbstractMCMC.get_logprob(state))
104-
end
10578
end
106-
107-
## tests
108-
109-
# # for RWMH
110-
# # sample from Normal(10, 1)
111-
# struct NormalLogDensity end
112-
# LogDensityProblems.logdensity(l::NormalLogDensity, x) = logpdf(Normal(10, 1), only(x))
113-
# LogDensityProblems.dimension(l::NormalLogDensity) = 1
114-
# function LogDensityProblems.capabilities(::NormalLogDensity)
115-
# return LogDensityProblems.LogDensityOrder{1}()
116-
# end
117-
118-
# # for PriorMH
119-
# # sample from Categorical([0.2, 0.5, 0.3])
120-
# struct CategoricalLogDensity end
121-
# function LogDensityProblems.logdensity(l::CategoricalLogDensity, x)
122-
# return logpdf(Categorical([0.2, 0.6, 0.2]), only(x))
123-
# end
124-
# LogDensityProblems.dimension(l::CategoricalLogDensity) = 1
125-
# function LogDensityProblems.capabilities(::CategoricalLogDensity)
126-
# return LogDensityProblems.LogDensityOrder{0}()
127-
# end
128-
129-
# ##
130-
131-
# using StatsPlots
132-
133-
# samples = AbstractMCMC.sample(
134-
# Random.default_rng(), NormalLogDensity(), RWMH(1), 100000; initial_params=[0.0]
135-
# )
136-
# _samples = map(t -> only(t.params), samples)
137-
138-
# histogram(_samples; normalize=:pdf, label="Samples", title="RWMH Sampling of Normal(10, 1)")
139-
# plot!(Normal(10, 1); linewidth=2, label="Ground Truth")
140-
141-
# samples = AbstractMCMC.sample(
142-
# Random.default_rng(),
143-
# CategoricalLogDensity(),
144-
# PriorMH(product_distribution([Categorical([0.3, 0.3, 0.4])])),
145-
# 100000;
146-
# initial_params=[1],
147-
# )
148-
# _samples = map(t -> only(t.params), samples)
149-
150-
# histogram(
151-
# _samples;
152-
# normalize=:probability,
153-
# label="Samples",
154-
# title="MH From Prior Sampling of Categorical([0.3, 0.3, 0.4])",
155-
# )
156-
# plot!(Categorical([0.2, 0.6, 0.2]); linewidth=2, label="Ground Truth")

0 commit comments

Comments
 (0)