Implementation of Robust Adaptive Metropolis #106
Changes from 10 commits
@@ -1,5 +1,10 @@
[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[compat]
Documenter = "1"
@@ -0,0 +1,205 @@
module RobustAdaptiveMetropolis

using Random, LogDensityProblems, LinearAlgebra, AbstractMCMC
using DocStringExtensions: FIELDS

using AdvancedMH: AdvancedMH

export RAM
# TODO: Should we generalise this to arbitrary symmetric proposals?
"""
    RAM

Robust Adaptive Metropolis-Hastings (RAM).

This is a simple implementation of the RAM algorithm described in [^VIH12].

# Fields
$(FIELDS)

# Examples
The following demonstrates how to implement a simple Gaussian model and sample from it using the RAM algorithm.

```jldoctest
julia> using AdvancedMH, Random, Distributions, MCMCChains, LogDensityProblems, LinearAlgebra

julia> # Define a Gaussian with zero mean and some covariance.
       struct Gaussian{A}
           Σ::A
       end

julia> # Implement the LogDensityProblems interface.
       LogDensityProblems.dimension(model::Gaussian) = size(model.Σ, 1)

julia> function LogDensityProblems.logdensity(model::Gaussian, x)
           d = LogDensityProblems.dimension(model)
           return logpdf(MvNormal(zeros(d), model.Σ), x)
       end

julia> LogDensityProblems.capabilities(::Gaussian) = LogDensityProblems.LogDensityOrder{0}()
julia> # Construct the model. We'll use a correlation of 0.5.
       model = Gaussian([1.0 0.5; 0.5 1.0]);

julia> # Number of samples we want in the resulting chain.
       num_samples = 10_000;

julia> # Number of warmup steps, i.e. the number of steps to adapt the covariance of the proposal.
       # Note that these are not included in the resulting chain, as `discard_initial=num_warmup`
       # by default in the `sample` call. To include them, pass `discard_initial=0` to `sample`.
       num_warmup = 10_000;

julia> # Set the seed so we get some consistency.
       Random.seed!(1234);

julia> # Sample!
       chain = sample(model, RAM(), num_samples; chain_type=Chains, num_warmup=num_warmup, progress=false, initial_params=zeros(2));

julia> norm(cov(Array(chain)) - [1.0 0.5; 0.5 1.0]) < 0.2
true
```

# References

[^VIH12]: Vihola, M. (2012). Robust adaptive Metropolis algorithm with coerced acceptance rate. Statistics and Computing.
"""
Base.@kwdef struct RAM{T,A<:Union{Nothing,AbstractMatrix{T}}} <: AdvancedMH.MHSampler
    "target acceptance rate"
    α::T=0.234
    "negative exponent of the adaptation decay rate"
    γ::T=0.6
    "initial lower-triangular Cholesky factor"
    S::A=nothing
    "lower bound on eigenvalues of the adapted covariance matrix"
    eigenvalue_lower_bound::T=0.0
    "upper bound on eigenvalues of the adapted covariance matrix"
    eigenvalue_upper_bound::T=Inf
end
# TODO: Should we record anything like the acceptance rates?
struct RAMState{T1,L,A,T2,T3}
    x::T1
    logprob::L
    S::A
    logα::T2
    η::T3
    iteration::Int
    isaccept::Bool
end

AbstractMCMC.getparams(state::RAMState) = state.x
AbstractMCMC.setparams!!(state::RAMState, x) = RAMState(x, state.logprob, state.S, state.logα, state.η, state.iteration, state.isaccept)
function step_inner(
    rng::Random.AbstractRNG,
    model::AbstractMCMC.LogDensityModel,
    sampler::RAM,
    state::RAMState
)
    # Extract the target log-density and its dimension.
    f = model.logdensity
    d = LogDensityProblems.dimension(f)

    # Sample the proposal.
    x = state.x
    U = randn(rng, d)
    x_new = x + state.S * U

    # Compute the acceptance probability.
    lp = state.logprob
    lp_new = LogDensityProblems.logdensity(f, x_new)
    logα = min(lp_new - lp, zero(lp))  # `min` because we'll use this for updating

    # TODO: use `randexp` instead.
    isaccept = log(rand(rng)) < logα

    return x_new, lp_new, U, logα, isaccept
end
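For reference, `step_inner` implements the proposal and acceptance steps of the RAM recursion from Vihola (2012): a symmetric random-walk proposal shaped by the current Cholesky factor, accepted with the usual Metropolis probability (the proposal terms cancel by symmetry, which is why `logα` is just the log-density difference clipped at zero):

```math
Y_n = X_{n-1} + S_{n-1} U_n, \qquad U_n \sim N(0, I), \qquad
\alpha_n = \min\left\{1, \frac{\pi(Y_n)}{\pi(X_{n-1})}\right\}
```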
function adapt(sampler::RAM, state::RAMState, logα::Real, U::AbstractVector)
    # Update the lower-triangular Cholesky factor `S` of the proposal covariance.
    Δα = exp(logα) - sampler.α
    S = state.S
    # TODO: Make this configurable by defining a more general path.
    η = state.iteration^(-sampler.γ)
    # `lowrankupdate`/`lowrankdowndate` add/subtract `ΔS * ΔS'`, so the step size
    # enters through its square root.
    ΔS = sqrt(η * abs(Δα)) * S * U / norm(U)
    # TODO: Maybe do in-place and then have the user extract it with a callback if they really want it.
    S_new = if sign(Δα) == 1
        # Rank-one update.
        LinearAlgebra.lowrankupdate(Cholesky(S), ΔS).L
    else
        # Rank-one downdate.
        LinearAlgebra.lowrankdowndate(Cholesky(S), ΔS).L
    end
    return S_new, η
end
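In exact arithmetic, the rank-one update/downdate in `adapt` realises the covariance recursion from Vihola (2012):

```math
S_n S_n^\top = S_{n-1}\left(I + \eta_n\,(\alpha_n - \alpha_*)\,\frac{U_n U_n^\top}{\lVert U_n \rVert^2}\right)S_{n-1}^\top
```

Since `lowrankupdate(C, v)` produces the factorisation of `A + v v'`, the vector fed into it must carry the square root of the scalar step, i.e. `v = sqrt(η * abs(Δα)) * S * U / norm(U)`; the sign of `Δα` decides between update and downdate.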
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::AbstractMCMC.LogDensityModel,
    sampler::RAM;
    initial_params=nothing,
    kwargs...
)
    # This is the initial state.
    f = model.logdensity
    d = LogDensityProblems.dimension(f)

    # Initial parameter state.
    x = initial_params === nothing ? rand(rng, d) : initial_params
Suggested change:
`- x = initial_params === nothing ? rand(rng, d) : initial_params`
`+ x = initial_params === nothing ? rand(rng, eltype(sampler.γ), d) : initial_params`

By the way, `rand(rng, d)` doesn't seem a good choice in general? The algorithm requires that you start with a point in the support of the target distribution, and it's not clear whether the target density is zero at this point. I wonder if it requires something like https://github.com/TuringLang/EllipticalSliceSampling.jl/blob/3296ae3566d329207875216837e65eeec3b809b2/src/interface.jl#L20-L29 in EllipticalSliceSampling.
But this is using a RWMH as the main kernel, so IMO we're already assuming unconstrained support for this to be valid
And more generally, happy to deal with better initialisation, buuuuut prefer to do this in a separate PR as I'm imagining this will require some discussion
> so IMO we're already assuming unconstrained support

But why prefer `rand` over `randn` in that case? But I agree, probably this question (`rand`/`randn`/dedicated API) should be addressed separately.
> But why prefer rand over randn in that case?

Just did it quickly because I have some vague memory that it's generally preferred to do initialisation in a box near 0 for most of the linking transformations (believe this is the motivation behind `SampleFromUniform` in DPPL, though I think it technically initialises from a cube centered on 0?).

But it's w/e to me here; we need a better way in general for this, so I'll just change it to `randn` 👍
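Initialisation aside, the adaptation recursion discussed in this PR is easy to check end to end. Below is a minimal, hypothetical pure-Python sketch of RAM in one dimension, where the Cholesky factor reduces to the scalar proposal standard deviation `S` and Vihola's recursion becomes `S_n² = S_{n-1}² · (1 + η_n · (α_n − α*))`; all names (`ram_1d`, parameters) are illustrative, not part of this PR:

```python
import math
import random

def ram_1d(logpdf, x0, n_iters, target_accept=0.234, gamma=0.6, seed=1234):
    """Illustrative 1-D Robust Adaptive Metropolis sketch."""
    rng = random.Random(seed)
    x, S = x0, 1.0
    lp = logpdf(x)
    n_accept = 0
    for n in range(1, n_iters + 1):
        # Symmetric random-walk proposal scaled by the current S.
        u = rng.gauss(0.0, 1.0)
        x_new = x + S * u
        lp_new = logpdf(x_new)
        log_alpha = min(lp_new - lp, 0.0)
        if math.log(rng.random()) < log_alpha:
            x, lp = x_new, lp_new
            n_accept += 1
        # Coerce the acceptance rate towards `target_accept`
        # (1-D specialisation of the matrix recursion).
        eta = n ** (-gamma)
        S = math.sqrt(S * S * (1.0 + eta * (math.exp(log_alpha) - target_accept)))
    return x, S, n_accept / n_iters

# Standard-normal target (log-density up to an additive constant).
x, S, rate = ram_1d(lambda x: -0.5 * x * x, 0.0, 50_000)
```

Whatever the scale of the target, the empirical acceptance rate is driven towards `target_accept`, which is the "coerced acceptance rate" of the paper's title.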