Skip to content

Commit 1911b9d

Browse files
authored
Generally support container of proposals (#58)
1 parent 941c046 commit 1911b9d

File tree

5 files changed

+119
-147
lines changed

5 files changed

+119
-147
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "AdvancedMH"
22
uuid = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
3-
version = "0.6.0"
3+
version = "0.6.1"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/MALA.jl

Lines changed: 40 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -19,48 +19,60 @@ struct GradientTransition{T<:Union{Vector, Real, NamedTuple}, L<:Real, G<:Union{
1919
gradient::G
2020
end
2121

22-
transition(::MALA, model, params) = GradientTransition(model, params)
23-
24-
# Store the new draw, its log density and its gradient
25-
GradientTransition(model::DensityModel, params) = GradientTransition(params, logdensity_and_gradient(model, params)...)
22+
logdensity(model::DensityModel, t::GradientTransition) = t.lp
2623

2724
propose(rng::Random.AbstractRNG, ::MALA, model) = error("please specify initial parameters")
25+
function transition(sampler::MALA, model::DensityModel, params)
26+
return GradientTransition(params, logdensity_and_gradient(model, params)...)
27+
end
2828

29-
function propose(
29+
function AbstractMCMC.step(
3030
rng::Random.AbstractRNG,
31-
spl::MALA{<:Proposal},
3231
model::DensityModel,
33-
params_prev::GradientTransition
34-
)
35-
proposal = propose(rng, spl.proposal(params_prev.gradient), model, params_prev.params)
36-
return GradientTransition(model, proposal)
37-
end
38-
39-
40-
function q(
41-
spl::MALA{<:Proposal},
42-
t::GradientTransition,
43-
t_cond::GradientTransition
44-
)
45-
return q(spl.proposal(-t_cond.gradient), t.params, t_cond.params)
46-
end
47-
48-
function logratio_proposal_density(
49-
sampler::MALA{<:Proposal}, state::GradientTransition, candidate::GradientTransition
32+
sampler::MALA,
33+
transition_prev::GradientTransition;
34+
kwargs...
5035
)
51-
return q(sampler, state, candidate) - q(sampler, candidate, state)
36+
# Extract value and gradient of the log density of the current state.
37+
state = transition_prev.params
38+
logdensity_state = transition_prev.lp
39+
gradient_logdensity_state = transition_prev.gradient
40+
41+
# Generate a new proposal.
42+
proposal = sampler.proposal
43+
candidate = propose(rng, proposal(gradient_logdensity_state), model, state)
44+
45+
# Compute both the value of the log density and its gradient
46+
logdensity_candidate, gradient_logdensity_candidate = logdensity_and_gradient(
47+
model, candidate
48+
)
49+
50+
# Compute the log ratio of proposal densities.
51+
logratio_proposal_density = q(
52+
proposal(-gradient_logdensity_candidate), state, candidate
53+
) - q(proposal(-gradient_logdensity_state), candidate, state)
54+
55+
# Compute the log acceptance probability.
56+
logα = logdensity_candidate - logdensity_state + logratio_proposal_density
57+
58+
# Decide whether to return the previous params or the new one.
59+
transition = if -Random.randexp(rng) < logα
60+
GradientTransition(candidate, logdensity_candidate, gradient_logdensity_candidate)
61+
else
62+
transition_prev
63+
end
64+
65+
return transition, transition
5266
end
5367

5468
"""
5569
logdensity_and_gradient(model::DensityModel, params)
5670
57-
Efficiently returns the value and gradient of the model
71+
Return the value and gradient of the log density of the parameters `params` for the `model`.
5872
"""
5973
function logdensity_and_gradient(model::DensityModel, params)
6074
res = GradientResult(params)
6175
gradient!(res, model.logdensity, params)
62-
return (value(res), gradient(res))
76+
return value(res), gradient(res)
6377
end
6478

65-
66-
logdensity(model::DensityModel, t::GradientTransition) = t.lp

src/emcee.jl

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,8 @@ struct Ensemble{D} <: MHSampler
33
proposal::D
44
end
55

6-
# Define the first sampling step.
7-
# Return a 2-tuple consisting of the initial sample and the initial state.
8-
# In this case they are identical.
9-
function AbstractMCMC.step(
10-
rng::Random.AbstractRNG,
11-
model::DensityModel,
12-
spl::Ensemble;
13-
init_params = nothing,
14-
kwargs...,
15-
)
16-
if init_params === nothing
17-
transitions = propose(rng, spl, model)
18-
else
19-
transitions = [Transition(model, x) for x in init_params]
20-
end
21-
22-
return transitions, transitions
6+
function transition(sampler::Ensemble, model::DensityModel, params)
7+
return [Transition(model, x) for x in params]
238
end
249

2510
# Define the other sampling steps.

src/mh-core.jl

Lines changed: 27 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -48,115 +48,38 @@ end
4848
StaticMH(d) = MetropolisHastings(StaticProposal(d))
4949
RWMH(d) = MetropolisHastings(RandomWalkProposal(d))
5050

51-
# default function without RNG
52-
propose(spl::MetropolisHastings, args...) = propose(Random.GLOBAL_RNG, spl, args...)
53-
54-
# Propose from a vector of proposals
55-
function propose(
56-
rng::Random.AbstractRNG,
57-
spl::MetropolisHastings{<:AbstractArray},
58-
model::DensityModel
59-
)
60-
proposal = map(p -> propose(rng, p, model), spl.proposal)
61-
return Transition(model, proposal)
62-
end
63-
64-
function propose(
65-
rng::Random.AbstractRNG,
66-
spl::MetropolisHastings{<:AbstractArray},
67-
model::DensityModel,
68-
params_prev::Transition
69-
)
70-
proposal = map(spl.proposal, params_prev.params) do p, params
71-
propose(rng, p, model, params)
72-
end
73-
return Transition(model, proposal)
74-
end
75-
76-
# Make a proposal from one Proposal struct.
77-
function propose(
78-
rng::Random.AbstractRNG,
79-
spl::MetropolisHastings{<:Proposal},
80-
model::DensityModel
81-
)
82-
proposal = propose(rng, spl.proposal, model)
83-
return Transition(model, proposal)
84-
end
85-
86-
function propose(
87-
rng::Random.AbstractRNG,
88-
spl::MetropolisHastings{<:Proposal},
89-
model::DensityModel,
90-
params_prev::Transition
91-
)
92-
proposal = propose(rng, spl.proposal, model, params_prev.params)
93-
return Transition(model, proposal)
94-
end
95-
96-
# Make a proposal from a NamedTuple of Proposal.
97-
function propose(
98-
rng::Random.AbstractRNG,
99-
spl::MetropolisHastings{<:NamedTuple},
100-
model::DensityModel
101-
)
102-
proposal = _propose(rng, spl.proposal, model)
103-
return Transition(model, proposal)
51+
function propose(rng::Random.AbstractRNG, sampler::MHSampler, model::DensityModel)
52+
return propose(rng, sampler.proposal, model)
10453
end
105-
10654
function propose(
10755
rng::Random.AbstractRNG,
108-
spl::MetropolisHastings{<:NamedTuple},
56+
sampler::MHSampler,
10957
model::DensityModel,
110-
params_prev::Transition
58+
transition_prev::Transition,
11159
)
112-
proposal = _propose(rng, spl.proposal, model, params_prev.params)
113-
return Transition(model, proposal)
60+
return propose(rng, sampler.proposal, model, transition_prev.params)
11461
end
11562

116-
@generated function _propose(
117-
rng::Random.AbstractRNG,
118-
proposal::NamedTuple{names},
119-
model::DensityModel
120-
) where {names}
121-
isempty(names) && return :(NamedTuple())
122-
expr = Expr(:tuple)
123-
expr.args = Any[:($name = propose(rng, proposal.$name, model)) for name in names]
124-
return expr
63+
function transition(sampler::MHSampler, model::DensityModel, params)
64+
logdensity = AdvancedMH.logdensity(model, params)
65+
return transition(sampler, model, params, logdensity)
12566
end
126-
127-
@generated function _propose(
128-
rng::Random.AbstractRNG,
129-
proposal::NamedTuple{names},
130-
model::DensityModel,
131-
params_prev::NamedTuple
132-
) where {names}
133-
isempty(names) && return :(NamedTuple())
134-
expr = Expr(:tuple)
135-
expr.args = Any[
136-
:($name = propose(rng, proposal.$name, model, params_prev.$name)) for name in names
137-
]
138-
return expr
67+
function transition(sampler::MHSampler, model::DensityModel, params, logdensity::Real)
68+
return Transition(params, logdensity)
13969
end
14070

141-
transition(sampler, model, params) = transition(model, params)
142-
transition(model, params) = Transition(model, params)
143-
14471
# Define the first sampling step.
14572
# Return a 2-tuple consisting of the initial sample and the initial state.
14673
# In this case they are identical.
14774
function AbstractMCMC.step(
14875
rng::Random.AbstractRNG,
14976
model::DensityModel,
150-
spl::MHSampler;
77+
sampler::MHSampler;
15178
init_params=nothing,
15279
kwargs...
15380
)
154-
if init_params === nothing
155-
transition = propose(rng, spl, model)
156-
else
157-
transition = AdvancedMH.transition(spl, model, init_params)
158-
end
159-
81+
params = init_params === nothing ? propose(rng, sampler, model) : init_params
82+
transition = AdvancedMH.transition(sampler, model, params)
16083
return transition, transition
16184
end
16285

@@ -167,27 +90,30 @@ end
16790
function AbstractMCMC.step(
16891
rng::Random.AbstractRNG,
16992
model::DensityModel,
170-
spl::MHSampler,
171-
params_prev::AbstractTransition;
93+
sampler::MHSampler,
94+
transition_prev::AbstractTransition;
17295
kwargs...
17396
)
17497
# Generate a new proposal.
175-
params = propose(rng, spl, model, params_prev)
98+
candidate = propose(rng, sampler, model, transition_prev)
17699

177-
# Calculate the log acceptance probability.
178-
logα = logdensity(model, params) - logdensity(model, params_prev) +
179-
logratio_proposal_density(spl, params_prev, params)
100+
# Calculate the log acceptance probability and the log density of the candidate.
101+
logdensity_candidate = logdensity(model, candidate)
102+
logα = logdensity_candidate - logdensity(model, transition_prev) +
103+
logratio_proposal_density(sampler, transition_prev, candidate)
180104

181105
# Decide whether to return the previous params or the new one.
182-
if -Random.randexp(rng) < logα
183-
return params, params
106+
transition = if -Random.randexp(rng) < logα
107+
AdvancedMH.transition(sampler, model, candidate, logdensity_candidate)
184108
else
185-
return params_prev, params_prev
109+
transition_prev
186110
end
111+
112+
return transition, transition
187113
end
188114

189115
function logratio_proposal_density(
190-
sampler::MetropolisHastings, params_prev::Transition, params::Transition
116+
sampler::MetropolisHastings, transition_prev::AbstractTransition, candidate
191117
)
192-
return logratio_proposal_density(sampler.proposal, params_prev.params, params.params)
118+
return logratio_proposal_density(sampler.proposal, transition_prev.params, candidate)
193119
end

src/proposal.jl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,55 @@ function q(
125125
return q(proposal(t_cond), t, t_cond)
126126
end
127127

128+
####################
129+
# Multiple proposals
130+
####################
131+
132+
function propose(
133+
rng::Random.AbstractRNG,
134+
proposals::AbstractArray{<:Proposal},
135+
model::DensityModel,
136+
)
137+
return map(proposals) do proposal
138+
return propose(rng, proposal, model)
139+
end
140+
end
141+
function propose(
142+
rng::Random.AbstractRNG,
143+
proposals::AbstractArray{<:Proposal},
144+
model::DensityModel,
145+
ts,
146+
)
147+
return map(proposals, ts) do proposal, t
148+
return propose(rng, proposal, model, t)
149+
end
150+
end
151+
152+
@generated function propose(
153+
rng::Random.AbstractRNG,
154+
proposals::NamedTuple{names},
155+
model::DensityModel,
156+
) where {names}
157+
isempty(names) && return :(NamedTuple())
158+
expr = Expr(:tuple)
159+
expr.args = Any[:($name = propose(rng, proposals.$name, model)) for name in names]
160+
return expr
161+
end
162+
163+
@generated function propose(
164+
rng::Random.AbstractRNG,
165+
proposals::NamedTuple{names},
166+
model::DensityModel,
167+
ts,
168+
) where {names}
169+
isempty(names) && return :(NamedTuple())
170+
expr = Expr(:tuple)
171+
expr.args = Any[
172+
:($name = propose(rng, proposals.$name, model, ts.$name)) for name in names
173+
]
174+
return expr
175+
end
176+
128177
"""
129178
logratio_proposal_density(proposal, state, candidate)
130179

0 commit comments

Comments
 (0)