|
48 | 48 | StaticMH(d) = MetropolisHastings(StaticProposal(d))
|
49 | 49 | RWMH(d) = MetropolisHastings(RandomWalkProposal(d))
|
50 | 50 |
|
51 |
| -# default function without RNG |
52 |
| -propose(spl::MetropolisHastings, args...) = propose(Random.GLOBAL_RNG, spl, args...) |
53 |
| - |
54 |
| -# Propose from a vector of proposals |
55 |
| -function propose( |
56 |
| - rng::Random.AbstractRNG, |
57 |
| - spl::MetropolisHastings{<:AbstractArray}, |
58 |
| - model::DensityModel |
59 |
| -) |
60 |
| - proposal = map(p -> propose(rng, p, model), spl.proposal) |
61 |
| - return Transition(model, proposal) |
62 |
| -end |
63 |
| - |
64 |
| -function propose( |
65 |
| - rng::Random.AbstractRNG, |
66 |
| - spl::MetropolisHastings{<:AbstractArray}, |
67 |
| - model::DensityModel, |
68 |
| - params_prev::Transition |
69 |
| -) |
70 |
| - proposal = map(spl.proposal, params_prev.params) do p, params |
71 |
| - propose(rng, p, model, params) |
72 |
| - end |
73 |
| - return Transition(model, proposal) |
74 |
| -end |
75 |
| - |
76 |
| -# Make a proposal from one Proposal struct. |
77 |
| -function propose( |
78 |
| - rng::Random.AbstractRNG, |
79 |
| - spl::MetropolisHastings{<:Proposal}, |
80 |
| - model::DensityModel |
81 |
| -) |
82 |
| - proposal = propose(rng, spl.proposal, model) |
83 |
| - return Transition(model, proposal) |
84 |
| -end |
85 |
| - |
86 |
| -function propose( |
87 |
| - rng::Random.AbstractRNG, |
88 |
| - spl::MetropolisHastings{<:Proposal}, |
89 |
| - model::DensityModel, |
90 |
| - params_prev::Transition |
91 |
| -) |
92 |
| - proposal = propose(rng, spl.proposal, model, params_prev.params) |
93 |
| - return Transition(model, proposal) |
94 |
| -end |
95 |
| - |
96 |
| -# Make a proposal from a NamedTuple of Proposal. |
97 |
| -function propose( |
98 |
| - rng::Random.AbstractRNG, |
99 |
| - spl::MetropolisHastings{<:NamedTuple}, |
100 |
| - model::DensityModel |
101 |
| -) |
102 |
| - proposal = _propose(rng, spl.proposal, model) |
103 |
| - return Transition(model, proposal) |
| 51 | +function propose(rng::Random.AbstractRNG, sampler::MHSampler, model::DensityModel) |
| 52 | + return propose(rng, sampler.proposal, model) |
104 | 53 | end
|
105 |
| - |
106 | 54 | function propose(
|
107 | 55 | rng::Random.AbstractRNG,
|
108 |
| - spl::MetropolisHastings{<:NamedTuple}, |
| 56 | + sampler::MHSampler, |
109 | 57 | model::DensityModel,
|
110 |
| - params_prev::Transition |
| 58 | + transition_prev::Transition, |
111 | 59 | )
|
112 |
| - proposal = _propose(rng, spl.proposal, model, params_prev.params) |
113 |
| - return Transition(model, proposal) |
| 60 | + return propose(rng, sampler.proposal, model, transition_prev.params) |
114 | 61 | end
|
115 | 62 |
|
116 |
| -@generated function _propose( |
117 |
| - rng::Random.AbstractRNG, |
118 |
| - proposal::NamedTuple{names}, |
119 |
| - model::DensityModel |
120 |
| -) where {names} |
121 |
| - isempty(names) && return :(NamedTuple()) |
122 |
| - expr = Expr(:tuple) |
123 |
| - expr.args = Any[:($name = propose(rng, proposal.$name, model)) for name in names] |
124 |
| - return expr |
| 63 | +function transition(sampler::MHSampler, model::DensityModel, params) |
| 64 | + logdensity = AdvancedMH.logdensity(model, params) |
| 65 | + return transition(sampler, model, params, logdensity) |
125 | 66 | end
|
126 |
| - |
127 |
| -@generated function _propose( |
128 |
| - rng::Random.AbstractRNG, |
129 |
| - proposal::NamedTuple{names}, |
130 |
| - model::DensityModel, |
131 |
| - params_prev::NamedTuple |
132 |
| -) where {names} |
133 |
| - isempty(names) && return :(NamedTuple()) |
134 |
| - expr = Expr(:tuple) |
135 |
| - expr.args = Any[ |
136 |
| - :($name = propose(rng, proposal.$name, model, params_prev.$name)) for name in names |
137 |
| - ] |
138 |
| - return expr |
| 67 | +function transition(sampler::MHSampler, model::DensityModel, params, logdensity::Real) |
| 68 | + return Transition(params, logdensity) |
139 | 69 | end
|
140 | 70 |
|
141 |
| -transition(sampler, model, params) = transition(model, params) |
142 |
| -transition(model, params) = Transition(model, params) |
143 |
| - |
144 | 71 | # Define the first sampling step.
|
145 | 72 | # Return a 2-tuple consisting of the initial sample and the initial state.
|
146 | 73 | # In this case they are identical.
|
147 | 74 | function AbstractMCMC.step(
|
148 | 75 | rng::Random.AbstractRNG,
|
149 | 76 | model::DensityModel,
|
150 |
| - spl::MHSampler; |
| 77 | + sampler::MHSampler; |
151 | 78 | init_params=nothing,
|
152 | 79 | kwargs...
|
153 | 80 | )
|
154 |
| - if init_params === nothing |
155 |
| - transition = propose(rng, spl, model) |
156 |
| - else |
157 |
| - transition = AdvancedMH.transition(spl, model, init_params) |
158 |
| - end |
159 |
| - |
| 81 | + params = init_params === nothing ? propose(rng, sampler, model) : init_params |
| 82 | + transition = AdvancedMH.transition(sampler, model, params) |
160 | 83 | return transition, transition
|
161 | 84 | end
|
162 | 85 |
|
|
167 | 90 | function AbstractMCMC.step(
|
168 | 91 | rng::Random.AbstractRNG,
|
169 | 92 | model::DensityModel,
|
170 |
| - spl::MHSampler, |
171 |
| - params_prev::AbstractTransition; |
| 93 | + sampler::MHSampler, |
| 94 | + transition_prev::AbstractTransition; |
172 | 95 | kwargs...
|
173 | 96 | )
|
174 | 97 | # Generate a new proposal.
|
175 |
| - params = propose(rng, spl, model, params_prev) |
| 98 | + candidate = propose(rng, sampler, model, transition_prev) |
176 | 99 |
|
177 |
| - # Calculate the log acceptance probability. |
178 |
| - logα = logdensity(model, params) - logdensity(model, params_prev) + |
179 |
| - logratio_proposal_density(spl, params_prev, params) |
| 100 | + # Calculate the log acceptance probability and the log density of the candidate. |
| 101 | + logdensity_candidate = logdensity(model, candidate) |
| 102 | + logα = logdensity_candidate - logdensity(model, transition_prev) + |
| 103 | + logratio_proposal_density(sampler, transition_prev, candidate) |
180 | 104 |
|
181 | 105 | # Decide whether to return the previous params or the new one.
|
182 |
| - if -Random.randexp(rng) < logα |
183 |
| - return params, params |
| 106 | + transition = if -Random.randexp(rng) < logα |
| 107 | + AdvancedMH.transition(sampler, model, candidate, logdensity_candidate) |
184 | 108 | else
|
185 |
| - return params_prev, params_prev |
| 109 | + transition_prev |
186 | 110 | end
|
| 111 | + |
| 112 | + return transition, transition |
187 | 113 | end
|
188 | 114 |
|
189 | 115 | function logratio_proposal_density(
|
190 |
| - sampler::MetropolisHastings, params_prev::Transition, params::Transition |
| 116 | + sampler::MetropolisHastings, transition_prev::AbstractTransition, candidate |
191 | 117 | )
|
192 |
| - return logratio_proposal_density(sampler.proposal, params_prev.params, params.params) |
| 118 | + return logratio_proposal_density(sampler.proposal, transition_prev.params, candidate) |
193 | 119 | end
|
0 commit comments