|
1 | 1 | using Distributions
|
2 | 2 |
|
3 |
| -struct MHTransition{T} |
| 3 | +struct MHState{T} |
4 | 4 | params::Vector{T}
|
| 5 | + logp::Float64 |
5 | 6 | end
|
6 | 7 |
|
7 |
| -struct MHState{T} |
| 8 | +struct MHTransition{T} |
8 | 9 | params::Vector{T}
|
9 |
| - logp::Float64 |
10 | 10 | end
|
11 | 11 |
|
12 | 12 | AbstractMCMC.get_params(state::MHState) = state.params
|
13 | 13 | AbstractMCMC.set_params!!(state::MHState, params) = MHState(params, state.logp)
|
14 | 14 | AbstractMCMC.get_logprob(state::MHState) = state.logp
|
15 | 15 | AbstractMCMC.set_logprob!!(state::MHState, logp) = MHState(state.params, logp)
|
16 | 16 |
|
17 |
| -struct RWMH <: AbstractMCMC.AbstractSampler |
| 17 | +struct RandomWalkMH <: AbstractMCMC.AbstractSampler |
18 | 18 | σ::Float64
|
19 | 19 | end
|
20 | 20 |
|
| 21 | +struct IndependentMH <: AbstractMCMC.AbstractSampler |
| 22 | + proposal_dist::Distributions.Distribution |
| 23 | +end |
| 24 | + |
21 | 25 | function AbstractMCMC.step(
|
22 | 26 | rng::AbstractRNG,
|
23 | 27 | logdensity_model::AbstractMCMC.LogDensityModel,
|
24 |
| - sampler::RWMH, |
| 28 | + sampler::Union{RandomWalkMH,IndependentMH}, |
25 | 29 | args...;
|
26 | 30 | initial_params,
|
27 | 31 | kwargs...,
|
|
36 | 40 | function AbstractMCMC.step(
|
37 | 41 | rng::AbstractRNG,
|
38 | 42 | logdensity_model::AbstractMCMC.LogDensityModel,
|
39 |
| - sampler::RWMH, |
| 43 | + sampler::Union{RandomWalkMH,IndependentMH}, |
40 | 44 | state::MHState,
|
41 | 45 | args...;
|
42 | 46 | kwargs...,
|
43 | 47 | )
|
44 | 48 | params = state.params
|
45 |
| - proposal_dist = MvNormal(zeros(length(params)), sampler.σ) |
46 |
| - proposal = params .+ rand(rng, proposal_dist) |
| 49 | + proposal_dist = |
| 50 | + sampler isa RandomWalkMH ? MvNormal(state.params, sampler.σ) : sampler.proposal_dist |
| 51 | + proposal = rand(rng, proposal_dist) |
47 | 52 | logp_proposal = only(
|
48 | 53 | LogDensityProblems.logdensity(logdensity_model.logdensity, proposal)
|
49 | 54 | )
|
50 | 55 |
|
51 |
| - log_acceptance_ratio = min(0, logp_proposal - AbstractMCMC.get_logprob(state)) |
52 |
| - |
53 |
| - if log(rand(rng)) < log_acceptance_ratio |
| 56 | + if log(rand(rng)) < |
| 57 | + compute_log_acceptance_ratio(sampler, state, proposal, logp_proposal) |
54 | 58 | return MHTransition(proposal), MHState(proposal, logp_proposal)
|
55 | 59 | else
|
56 |
| - return MHTransition(params), MHState(params, AbstractMCMC.get_logprob(state)) |
| 60 | + return MHTransition(params), MHState(params, state.logp) |
57 | 61 | end
|
58 | 62 | end
|
59 | 63 |
|
60 |
| -struct PriorMH <: AbstractMCMC.AbstractSampler |
61 |
| - prior_dist::Distributions.Distribution |
62 |
| -end |
63 |
| - |
64 |
| -function AbstractMCMC.step( |
65 |
| - rng::AbstractRNG, |
66 |
| - logdensity_model::AbstractMCMC.LogDensityModel, |
67 |
| - sampler::PriorMH, |
68 |
| - args...; |
69 |
| - initial_params, |
70 |
| - kwargs..., |
| 64 | +function compute_log_acceptance_ratio( |
| 65 | + ::RandomWalkMH, state::MHState, ::Vector{Float64}, logp_proposal::Float64 |
71 | 66 | )
|
72 |
| - return MHTransition(initial_params), |
73 |
| - MHState( |
74 |
| - initial_params, |
75 |
| - only(LogDensityProblems.logdensity(logdensity_model.logdensity, initial_params)), |
76 |
| - ) |
| 67 | + return min(0, logp_proposal - AbstractMCMC.get_logprob(state)) |
77 | 68 | end
|
78 | 69 |
|
79 |
| -function AbstractMCMC.step( |
80 |
| - rng::AbstractRNG, |
81 |
| - logdensity_model::AbstractMCMC.LogDensityModel, |
82 |
| - sampler::PriorMH, |
83 |
| - state::MHState, |
84 |
| - args...; |
85 |
| - kwargs..., |
86 |
| -) |
87 |
| - params = AbstractMCMC.get_params(state) |
88 |
| - proposal_dist = sampler.prior_dist |
89 |
| - proposal = rand(rng, proposal_dist) |
90 |
| - logp_proposal = only( |
91 |
| - LogDensityProblems.logdensity(logdensity_model.logdensity, proposal) |
92 |
| - ) |
93 |
| - |
94 |
| - log_acceptance_ratio = min( |
| 70 | +function compute_log_acceptance_ratio( |
| 71 | + sampler::IndependentMH, state::MHState, proposal::Vector{T}, logp_proposal::Float64 |
| 72 | +) where {T} |
| 73 | + return min( |
95 | 74 | 0,
|
96 |
| - logp_proposal - AbstractMCMC.get_logprob(state) + logpdf(proposal_dist, params) - |
97 |
| - logpdf(proposal_dist, proposal), |
| 75 | + logp_proposal - state.logp + logpdf(sampler.proposal_dist, state.params) - |
| 76 | + logpdf(sampler.proposal_dist, proposal), |
98 | 77 | )
|
99 |
| - |
100 |
| - if log(rand(rng)) < log_acceptance_ratio |
101 |
| - return MHTransition(proposal), MHState(proposal, logp_proposal) |
102 |
| - else |
103 |
| - return MHTransition(params), MHState(params, AbstractMCMC.get_logprob(state)) |
104 |
| - end |
105 | 78 | end
|
106 |
| - |
107 |
| -## tests |
108 |
| - |
109 |
| -# # for RWMH |
110 |
| -# # sample from Normal(10, 1) |
111 |
| -# struct NormalLogDensity end |
112 |
| -# LogDensityProblems.logdensity(l::NormalLogDensity, x) = logpdf(Normal(10, 1), only(x)) |
113 |
| -# LogDensityProblems.dimension(l::NormalLogDensity) = 1 |
114 |
| -# function LogDensityProblems.capabilities(::NormalLogDensity) |
115 |
| -# return LogDensityProblems.LogDensityOrder{1}() |
116 |
| -# end |
117 |
| - |
118 |
| -# # for PriorMH |
119 |
| -# # sample from Categorical([0.2, 0.5, 0.3]) |
120 |
| -# struct CategoricalLogDensity end |
121 |
| -# function LogDensityProblems.logdensity(l::CategoricalLogDensity, x) |
122 |
| -# return logpdf(Categorical([0.2, 0.6, 0.2]), only(x)) |
123 |
| -# end |
124 |
| -# LogDensityProblems.dimension(l::CategoricalLogDensity) = 1 |
125 |
| -# function LogDensityProblems.capabilities(::CategoricalLogDensity) |
126 |
| -# return LogDensityProblems.LogDensityOrder{0}() |
127 |
| -# end |
128 |
| - |
129 |
| -# ## |
130 |
| - |
131 |
| -# using StatsPlots |
132 |
| - |
133 |
| -# samples = AbstractMCMC.sample( |
134 |
| -# Random.default_rng(), NormalLogDensity(), RWMH(1), 100000; initial_params=[0.0] |
135 |
| -# ) |
136 |
| -# _samples = map(t -> only(t.params), samples) |
137 |
| - |
138 |
| -# histogram(_samples; normalize=:pdf, label="Samples", title="RWMH Sampling of Normal(10, 1)") |
139 |
| -# plot!(Normal(10, 1); linewidth=2, label="Ground Truth") |
140 |
| - |
141 |
| -# samples = AbstractMCMC.sample( |
142 |
| -# Random.default_rng(), |
143 |
| -# CategoricalLogDensity(), |
144 |
| -# PriorMH(product_distribution([Categorical([0.3, 0.3, 0.4])])), |
145 |
| -# 100000; |
146 |
| -# initial_params=[1], |
147 |
| -# ) |
148 |
| -# _samples = map(t -> only(t.params), samples) |
149 |
| - |
150 |
| -# histogram( |
151 |
| -# _samples; |
152 |
| -# normalize=:probability, |
153 |
| -# label="Samples", |
154 |
| -# title="MH From Prior Sampling of Categorical([0.3, 0.3, 0.4])", |
155 |
| -# ) |
156 |
| -# plot!(Categorical([0.2, 0.6, 0.2]); linewidth=2, label="Ground Truth") |
0 commit comments