Skip to content

Commit 941c046

Browse files
authored
Add logratio_proposal_density and remove is_symmetric_proposal (#54)
* Add `logratio_proposal_density` and remove `is_symmetric_proposal` * Bump version * Add `issymmetric` for `StaticProposal` and add aliases * Fix type inference problems * Remove accidentally included code
1 parent 7a02255 commit 941c046

File tree

7 files changed

+198
-101
lines changed

7 files changed

+198
-101
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "AdvancedMH"
22
uuid = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
3-
version = "0.5.9"
3+
version = "0.6.0"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/AdvancedMH.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,18 @@ using Requires
88
import Random
99

1010
# Exports
11-
export MetropolisHastings, DensityModel, RWMH, StaticMH, StaticProposal,
12-
RandomWalkProposal, Ensemble, StretchProposal, MALA
11+
export
12+
MetropolisHastings,
13+
DensityModel,
14+
RWMH,
15+
StaticMH,
16+
StaticProposal,
17+
SymmetricStaticProposal,
18+
RandomWalkProposal,
19+
SymmetricRandomWalkProposal,
20+
Ensemble,
21+
StretchProposal,
22+
MALA
1323

1424
# Reexports
1525
export sample, MCMCThreads, MCMCDistributed

src/MALA.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ function q(
4545
return q(spl.proposal(-t_cond.gradient), t.params, t_cond.params)
4646
end
4747

48+
function logratio_proposal_density(
49+
sampler::MALA{<:Proposal}, state::GradientTransition, candidate::GradientTransition
50+
)
51+
return q(sampler, state, candidate) - q(sampler, candidate, state)
52+
end
4853

4954
"""
5055
logdensity_and_gradient(model::DensityModel, params)

src/mh-core.jl

Lines changed: 8 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -138,37 +138,6 @@ end
138138
return expr
139139
end
140140

141-
# Evaluate the likelihood of t conditional on t_cond.
142-
function q(
143-
spl::MetropolisHastings{<:AbstractArray},
144-
t::Transition,
145-
t_cond::Transition
146-
)
147-
# mapreduce with multiple iterators requires Julia 1.2 or later
148-
return mapreduce(+, 1:length(spl.proposal)) do i
149-
q(spl.proposal[i], t.params[i], t_cond.params[i])
150-
end
151-
end
152-
153-
function q(
154-
spl::MetropolisHastings{<:Proposal},
155-
t::Transition,
156-
t_cond::Transition
157-
)
158-
return q(spl.proposal, t.params, t_cond.params)
159-
end
160-
161-
function q(
162-
spl::MetropolisHastings{<:NamedTuple},
163-
t::Transition,
164-
t_cond::Transition
165-
)
166-
# mapreduce with multiple iterators requires Julia 1.2 or later
167-
return mapreduce(+, keys(t.params)) do k
168-
q(spl.proposal[k], t.params[k], t_cond.params[k])
169-
end
170-
end
171-
172141
transition(sampler, model, params) = transition(model, params)
173142
transition(model, params) = Transition(model, params)
174143

@@ -191,31 +160,6 @@ function AbstractMCMC.step(
191160
return transition, transition
192161
end
193162

194-
"""
195-
is_symmetric_proposal(proposal)::Bool
196-
197-
Implementing this for a custom proposal will allow `AbstractMCMC.step` to avoid
198-
computing the "Hastings" part of the Metropolis-Hasting log acceptance
199-
probability (if the proposal is indeed symmetric). By default,
200-
`is_symmetric_proposal(proposal)` returns `false`. The user is responsible for
201-
determining whether a custom proposal distribution is indeed symmetric. As
202-
noted in `MetropolisHastings`, `proposal` is a `Proposal`, `NamedTuple` of
203-
`Proposal`, or `Array{Proposal}` in the shape of your data.
204-
"""
205-
is_symmetric_proposal(proposal) = false
206-
207-
# The following univariate random walk proposals are symmetric.
208-
is_symmetric_proposal(::RandomWalkProposal{<:Normal}) = true
209-
is_symmetric_proposal(::RandomWalkProposal{<:MvNormal}) = true
210-
is_symmetric_proposal(::RandomWalkProposal{<:TDist}) = true
211-
is_symmetric_proposal(::RandomWalkProposal{<:Cauchy}) = true
212-
213-
# The following multivariate random walk proposals are symmetric.
214-
is_symmetric_proposal(::RandomWalkProposal{<:AbstractArray{<:Normal}}) = true
215-
is_symmetric_proposal(::RandomWalkProposal{<:AbstractArray{<:MvNormal}}) = true
216-
is_symmetric_proposal(::RandomWalkProposal{<:AbstractArray{<:TDist}}) = true
217-
is_symmetric_proposal(::RandomWalkProposal{<:AbstractArray{<:Cauchy}}) = true
218-
219163
# Define the other sampling steps.
220164
# Return a 2-tuple consisting of the next sample and the the next state.
221165
# In this case they are identical, and either a new proposal (if accepted)
@@ -231,12 +175,8 @@ function AbstractMCMC.step(
231175
params = propose(rng, spl, model, params_prev)
232176

233177
# Calculate the log acceptance probability.
234-
logα = logdensity(model, params) - logdensity(model, params_prev)
235-
236-
# Compute Hastings portion of ratio if proposal is not symmetric.
237-
if !is_symmetric_proposal(spl.proposal)
238-
logα += q(spl, params_prev, params) - q(spl, params, params_prev)
239-
end
178+
logα = logdensity(model, params) - logdensity(model, params_prev) +
179+
logratio_proposal_density(spl, params_prev, params)
240180

241181
# Decide whether to return the previous params or the new one.
242182
if -Random.randexp(rng) < logα
@@ -245,3 +185,9 @@ function AbstractMCMC.step(
245185
return params_prev, params_prev
246186
end
247187
end
188+
189+
function logratio_proposal_density(
190+
sampler::MetropolisHastings, params_prev::Transition, params::Transition
191+
)
192+
return logratio_proposal_density(sampler.proposal, params_prev.params, params.params)
193+
end

src/proposal.jl

Lines changed: 102 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,24 @@
11
abstract type Proposal{P} end
22

3-
struct StaticProposal{P} <: Proposal{P}
3+
struct StaticProposal{issymmetric,P} <: Proposal{P}
44
proposal::P
55
end
6+
const SymmetricStaticProposal{P} = StaticProposal{true,P}
67

7-
struct RandomWalkProposal{P} <: Proposal{P}
8+
StaticProposal(proposal) = StaticProposal{false}(proposal)
9+
function StaticProposal{issymmetric}(proposal) where {issymmetric}
10+
return StaticProposal{issymmetric,typeof(proposal)}(proposal)
11+
end
12+
13+
struct RandomWalkProposal{issymmetric,P} <: Proposal{P}
814
proposal::P
915
end
16+
const SymmetricRandomWalkProposal{P} = RandomWalkProposal{true,P}
17+
18+
RandomWalkProposal(proposal) = RandomWalkProposal{false}(proposal)
19+
function RandomWalkProposal{issymmetric}(proposal) where {issymmetric}
20+
return RandomWalkProposal{issymmetric,typeof(proposal)}(proposal)
21+
end
1022

1123
# Random draws
1224
Base.rand(p::Proposal, args...) = rand(Random.GLOBAL_RNG, p, args...)
@@ -26,24 +38,28 @@ end
2638
# Random Walk #
2739
###############
2840

29-
function propose(rng::Random.AbstractRNG, p::RandomWalkProposal, m::DensityModel)
30-
return propose(rng, StaticProposal(p.proposal), m)
41+
function propose(
42+
rng::Random.AbstractRNG,
43+
proposal::RandomWalkProposal{issymmetric,<:Union{Distribution,AbstractArray}},
44+
::DensityModel
45+
) where {issymmetric}
46+
return rand(rng, proposal)
3147
end
3248

3349
function propose(
3450
rng::Random.AbstractRNG,
35-
proposal::RandomWalkProposal{<:Union{Distribution,AbstractArray}},
36-
model::DensityModel,
51+
proposal::RandomWalkProposal{issymmetric,<:Union{Distribution,AbstractArray}},
52+
model::DensityModel,
3753
t
38-
)
54+
) where {issymmetric}
3955
return t + rand(rng, proposal)
4056
end
4157

4258
function q(
43-
proposal::RandomWalkProposal{<:Union{Distribution,AbstractArray}},
59+
proposal::RandomWalkProposal{issymmetric,<:Union{Distribution,AbstractArray}},
4460
t,
4561
t_cond
46-
)
62+
) where {issymmetric}
4763
return logpdf(proposal, t - t_cond)
4864
end
4965

@@ -53,18 +69,18 @@ end
5369

5470
function propose(
5571
rng::Random.AbstractRNG,
56-
proposal::StaticProposal{<:Union{Distribution,AbstractArray}},
72+
proposal::StaticProposal{issymmetric,<:Union{Distribution,AbstractArray}},
5773
model::DensityModel,
5874
t=nothing
59-
)
75+
) where {issymmetric}
6076
return rand(rng, proposal)
6177
end
6278

6379
function q(
64-
proposal::StaticProposal{<:Union{Distribution,AbstractArray}},
80+
proposal::StaticProposal{issymmetric,<:Union{Distribution,AbstractArray}},
6581
t,
6682
t_cond
67-
)
83+
) where {issymmetric}
6884
return logpdf(proposal, t)
6985
end
7086

@@ -73,10 +89,14 @@ end
7389
############
7490

7591
# function definition with abstract types requires Julia 1.3 or later
76-
for T in (StaticProposal, RandomWalkProposal)
92+
for T in (:StaticProposal, :RandomWalkProposal)
7793
@eval begin
78-
(p::$T{<:Function})() = $T(p.proposal())
79-
(p::$T{<:Function})(t) = $T(p.proposal(t))
94+
function (p::$T{issymmetric,<:Function})() where {issymmetric}
95+
return $T{issymmetric}(p.proposal())
96+
end
97+
function (p::$T{issymmetric,<:Function})(t) where {issymmetric}
98+
return $T{issymmetric}(p.proposal(t))
99+
end
80100
end
81101
end
82102

@@ -103,4 +123,69 @@ function q(
103123
t_cond
104124
)
105125
return q(proposal(t_cond), t, t_cond)
106-
end
126+
end
127+
128+
"""
129+
logratio_proposal_density(proposal, state, candidate)
130+
131+
Compute the log-ratio of the proposal densities in the Metropolis-Hastings algorithm.
132+
133+
The log-ratio of the proposal densities is defined as
134+
```math
135+
\\log \\frac{g(x | x')}{g(x' | x)},
136+
```
137+
where ``x`` is the current state, ``x'`` is the proposed candidate for the next state,
138+
and ``g(y' | y)`` is the conditional probability of proposing state ``y'`` given state
139+
``y`` (proposal density).
140+
"""
141+
function logratio_proposal_density(proposal::Proposal, state, candidate)
142+
return q(proposal, state, candidate) - q(proposal, candidate, state)
143+
end
144+
145+
# ratio is always 0 for symmetric proposals
146+
logratio_proposal_density(::RandomWalkProposal{true}, state, candidate) = 0
147+
logratio_proposal_density(::StaticProposal{true}, state, candidate) = 0
148+
149+
# type stable implementation for `NamedTuple`s
150+
function logratio_proposal_density(
151+
proposals::NamedTuple{names}, states::NamedTuple, candidates::NamedTuple
152+
) where {names}
153+
if @generated
154+
args = map(names) do name
155+
:(logratio_proposal_density(
156+
proposals[$(QuoteNode(name))],
157+
states[$(QuoteNode(name))],
158+
candidates[$(QuoteNode(name))],
159+
))
160+
end
161+
return :(+($(args...)))
162+
else
163+
return sum(names) do name
164+
return logratio_proposal_density(
165+
proposals[name], states[name], candidates[name]
166+
)
167+
end
168+
end
169+
end
170+
171+
# use recursion for `Tuple`s to ensure type stability
172+
logratio_proposal_density(proposals::Tuple{}, states::Tuple, candidates::Tuple) = 0
173+
function logratio_proposal_density(
174+
proposals::Tuple{<:Proposal}, states::Tuple, candidates::Tuple
175+
)
176+
return logratio_proposal_density(first(proposals), first(states), first(candidates))
177+
end
178+
function logratio_proposal_density(proposals::Tuple, states::Tuple, candidates::Tuple)
179+
valfirst = logratio_proposal_density(first(proposals), first(states), first(candidates))
180+
valtail = logratio_proposal_density(
181+
Base.tail(proposals), Base.tail(states), Base.tail(candidates)
182+
)
183+
return valfirst + valtail
184+
end
185+
186+
# fallback for general iterators (arrays etc.) - possibly not type stable!
187+
function logratio_proposal_density(proposals, states, candidates)
188+
return sum(zip(proposals, states, candidates)) do (proposal, state, candidate)
189+
return logratio_proposal_density(proposal, state, candidate)
190+
end
191+
end

0 commit comments

Comments
 (0)