Skip to content

Commit e277e86

Browse files
authored
Merge pull request #10 from TuringLang/csp/ahm2
Add static MH and general code improvements
2 parents c356cf8 + 6669ff4 commit e277e86

File tree

7 files changed

+202
-93
lines changed

7 files changed

+202
-93
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "AdvancedMH"
22
uuid = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
3-
version = "0.1.0"
3+
version = "0.2.0"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

README.md

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ AdvancedMH.jl currently provides a robust implementation of random walk Metropol
44

55
Further development aims to provide a suite of adaptive Metropolis-Hastings implementations.
66

7+
Currently there are two sampler types. The first is `RWMH`, which represents random-walk MH sampling, and the second is `StaticMH`, which draws proposals
8+
from only a prior distribution without incrementing the previous sample.
9+
710
## Usage
811

912
AdvancedMH works by accepting some log density function which is used to construct a `DensityModel`. The `DensityModel` is then used in a `sample` call.
@@ -25,7 +28,7 @@ density(θ) = insupport(θ) ? sum(logpdf.(dist(θ), data)) : -Inf
2528
model = DensityModel(density)
2629

2730
# Set up our sampler with initial parameters.
28-
spl = MetropolisHastings([0.0, 0.0])
31+
spl = RWMH([0.0, 0.0])
2932

3033
# Sample from the posterior.
3134
chain = sample(model, spl, 100000; param_names=["μ", "σ"])
@@ -68,5 +71,16 @@ Custom proposal distributions can be specified by passing a distribution to `Met
6871

6972
```julia
7073
# Set up our sampler with initial parameters.
71-
spl = MetropolisHastings([0.0, 0.0], MvNormal(2, 0.5))
74+
spl1 = RWMH([0.0, 0.0], MvNormal(2, 0.5))
75+
spl2 = StaticMH([0.0, 0.0], MvNormal(2, 0.5))
76+
```
77+
78+
## Multithreaded sampling
79+
80+
AdvancedMH.jl implements the interface of [AbstractMCMC](https://github.com/TuringLang/AbstractMCMC.jl/), which means you get multiple chain sampling
81+
in parallel for free:
82+
83+
```julia
84+
# Sample 4 chains from the posterior.
85+
chain = psample(model, RWMH(init_params), 100000, 4; param_names=["μ","σ"])
7286
```

src/AdvancedMH.jl

Lines changed: 16 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -11,30 +11,11 @@ import MCMCChains: Chains
1111
import AbstractMCMC: step!, AbstractSampler, AbstractTransition, transition_type, bundle_samples
1212

1313
# Exports
14-
export MetropolisHastings, DensityModel, sample
14+
export MetropolisHastings, DensityModel, sample, psample, RWMH, StaticMH
1515

16-
"""
17-
MetropolisHastings{T, F<:Function}
18-
19-
Fields:
20-
21-
- `init_θ` is the vector form of the parameters needed for the likelihood function.
22-
- `proposal` is a function that dynamically constructs a conditional distribution.
23-
24-
Example:
25-
26-
```julia
27-
MetropolisHastings([0.0, 0.0], x -> MvNormal(x, 1.0))
28-
````
29-
"""
30-
struct MetropolisHastings{T, D} <: AbstractSampler
31-
init_θ :: T
32-
proposal :: D
33-
end
34-
35-
# Default constructors.
36-
MetropolisHastings(init_θ::Real) = MetropolisHastings(init_θ, Normal(0,1))
37-
MetropolisHastings(init_θ::Vector{<:Real}) = MetropolisHastings(init_θ, MvNormal(length(init_θ),1))
16+
# Abstract type for MH-style samplers.
17+
abstract type Metropolis <: AbstractSampler end
18+
abstract type ProposalStyle end
3819

3920
# Define a model type. Stores the log density function and the data to
4021
# evaluate the log density on.
@@ -65,74 +46,17 @@ end
6546
Transition(model::M, θ::T) where {M<:DensityModel, T} = Transition(θ, ℓπ(model, θ))
6647

6748
# Tell the interface what transition type we would like to use.
68-
transition_type(model::DensityModel, spl::MetropolisHastings) = typeof(Transition(spl.init_θ, ℓπ(model, spl.init_θ)))
69-
70-
# Define a function that makes a basic proposal depending on a univariate
71-
# parameterization or a multivariate parameterization.
72-
propose(spl::MetropolisHastings, model::DensityModel, θ::Real) =
73-
Transition(model, θ + rand(spl.proposal))
74-
propose(spl::MetropolisHastings, model::DensityModel, θ::Vector{<:Real}) =
75-
Transition(model, θ + rand(spl.proposal))
76-
propose(spl::MetropolisHastings, model::DensityModel, t::Transition) = propose(spl, model, t.θ)
77-
78-
"""
79-
q(θ::Real, dist::Sampleable)
80-
q(θ::Vector{<:Real}, dist::Sampleable)
81-
q(t1::Transition, dist::Sampleable)
82-
83-
Calculates the probability `q(θ | θcond)`, using the proposal distribution `spl.proposal`.
84-
"""
85-
q(spl::MetropolisHastings, θ::Real, θcond::Real) = logpdf(spl.proposal, θ - θcond)
86-
q(spl::MetropolisHastings, θ::Vector{<:Real}, θcond::Vector{<:Real}) = logpdf(spl.proposal, θ - θcond)
87-
q(spl::MetropolisHastings, t1::Transition, t2::Transition) = q(spl, t1.θ, t2.θ)
49+
transition_type(model::DensityModel, spl::Metropolis) = typeof(Transition(spl.init_θ, ℓπ(model, spl.init_θ)))
8850

8951
# Calculate the density of the model given some parameterization.
9052
ℓπ(model::DensityModel, θ::T) where T = model.ℓπ(θ)
9153
ℓπ(model::DensityModel, t::Transition) = t.lp
9254

93-
# Define the first step! function, which is called at the
94-
# beginning of sampling. Return the initial parameter used
95-
# to define the sampler.
96-
function step!(
97-
rng::AbstractRNG,
98-
model::DensityModel,
99-
spl::MetropolisHastings,
100-
N::Integer;
101-
kwargs...
102-
)
103-
return Transition(model, spl.init_θ)
104-
end
105-
106-
# Define the other step functions. Returns a Transition containing
107-
# either a new proposal (if accepted) or the previous proposal
108-
# (if not accepted).
109-
function step!(
110-
rng::AbstractRNG,
111-
model::DensityModel,
112-
spl::MetropolisHastings,
113-
::Integer,
114-
θ_prev::Transition;
115-
kwargs...
116-
)
117-
# Generate a new proposal.
118-
θ = propose(spl, model, θ_prev)
119-
120-
# Calculate the log acceptance probability.
121-
α = ℓπ(model, θ) - ℓπ(model, θ_prev) + q(spl, θ_prev, θ) - q(spl, θ, θ_prev)
122-
123-
# Decide whether to return the previous θ or the new one.
124-
if log(rand(rng)) < min(α, 0.0)
125-
return θ
126-
else
127-
return θ_prev
128-
end
129-
end
130-
13155
# A basic chains constructor that works with the Transition struct we defined.
13256
function bundle_samples(
13357
rng::AbstractRNG,
13458
::DensityModel,
135-
s::MetropolisHastings,
59+
s::Metropolis,
13660
N::Integer,
13761
ts::Vector{T};
13862
param_names=missing,
@@ -143,7 +67,10 @@ function bundle_samples(
14367

14468
# Check if we received any parameter names.
14569
if ismissing(param_names)
146-
param_names = ["Parameter $i" for i in 1:(length(first(vals))-1)]
70+
param_names = ["Parameter $i" for i in 1:length(s.init_θ)]
71+
else
72+
# Deepcopy to be thread safe.
73+
param_names = deepcopy(param_names)
14774
end
14875

14976
# Add the log density field to the parameter names.
@@ -153,4 +80,9 @@ function bundle_samples(
15380
return Chains(vals, param_names, (internals=["lp"],))
15481
end
15582

156-
end # module AdvancedMH
83+
# Include inference methods.
84+
include("mh-core.jl")
85+
include("rwmh.jl")
86+
include("staticmh.jl")
87+
88+
end # module AdvancedMH

src/mh-core.jl

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
"""
2+
MetropolisHastings{P<:ProposalStyle, T, D}
3+
4+
Fields:
5+
6+
- `init_θ` is the vector form of the parameters needed for the likelihood function.
7+
- `proposal` is a function that dynamically constructs a conditional distribution.
8+
9+
Example:
10+
11+
```julia
12+
MetropolisHastings([0.0, 0.0], x -> MvNormal(x, 1.0))
13+
````
14+
"""
15+
struct MetropolisHastings{P<:ProposalStyle, T, D} <: Metropolis
16+
proposal_type :: P
17+
init_θ :: T
18+
proposal :: D
19+
end
20+
21+
# Default constructors.
22+
MetropolisHastings(init_θ::Real) = MetropolisHastings(init_θ, Normal(0,1))
23+
MetropolisHastings(init_θ::Vector{<:Real}) = MetropolisHastings(init_θ, MvNormal(length(init_θ),1))
24+
25+
"""
26+
propose(spl::MetropolisHastings, model::DensityModel, t::Transition)
27+
28+
Generates a new parameter proposal conditional on the model, the sampler, and the previous
29+
sample.
30+
"""
31+
@inline propose(spl::MetropolisHastings, model::DensityModel, t::Transition) = propose(spl, model, t.θ)
32+
33+
"""
34+
q(spl::MetropolisHastings, t1::Transition, t2::Transition)
35+
36+
Calculates the probability `q(θ | θcond)`, using the proposal distribution `spl.proposal`.
37+
"""
38+
@inline q(spl::MetropolisHastings, t1::Transition, t2::Transition) = q(spl, t1.θ, t2.θ)
39+
40+
# Define the first step! function, which is called at the
41+
# beginning of sampling. Return the initial parameter used
42+
# to define the sampler.
43+
function step!(
44+
rng::AbstractRNG,
45+
model::DensityModel,
46+
spl::MetropolisHastings,
47+
N::Integer;
48+
kwargs...
49+
)
50+
return Transition(model, spl.init_θ)
51+
end
52+
53+
# Define the other step functions. Returns a Transition containing
54+
# either a new proposal (if accepted) or the previous proposal
55+
# (if not accepted).
56+
function step!(
57+
rng::AbstractRNG,
58+
model::DensityModel,
59+
spl::MetropolisHastings,
60+
::Integer,
61+
θ_prev::Transition;
62+
kwargs...
63+
)
64+
# Generate a new proposal.
65+
θ = propose(spl, model, θ_prev)
66+
67+
# Calculate the log acceptance probability.
68+
α = ℓπ(model, θ) - ℓπ(model, θ_prev) + q(spl, θ_prev, θ) - q(spl, θ, θ_prev)
69+
70+
# Decide whether to return the previous θ or the new one.
71+
if log(rand(rng)) < min(α, 0.0)
72+
return θ
73+
else
74+
return θ_prev
75+
end
76+
end

src/rwmh.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
struct RandomWalk <: ProposalStyle end
2+
3+
"""
4+
RWMH(init_theta::Real, proposal = Normal(init_theta, 1))
5+
RWMH(init_theta::Vector{Real}, proposal = MvNormal(init_theta, 1))
6+
7+
Random walk Metropolis-Hastings.
8+
9+
Fields:
10+
11+
- `init_θ` is the vector form of the parameters needed for the likelihood function.
12+
- `proposal` is a function that dynamically constructs a conditional distribution.
13+
14+
Example:
15+
16+
```julia
17+
RWMH([0.0, 0.0], x -> MvNormal(x, 1.0))
18+
````
19+
"""
20+
RWMH(init_theta::Real, proposal = Normal(init_theta, 1)) = MetropolisHastings(RandomWalk(), init_theta, proposal)
21+
RWMH(init_theta::Vector{<:Real}, proposal = MvNormal(init_theta, 1)) = MetropolisHastings(RandomWalk(), init_theta, proposal)
22+
23+
# Define a function that makes a basic proposal depending on a univariate
24+
# parameterization or a multivariate parameterization.
25+
propose(spl::MetropolisHastings{RandomWalk}, model::DensityModel, θ::Real) = Transition(model, θ + rand(spl.proposal))
26+
propose(spl::MetropolisHastings{RandomWalk}, model::DensityModel, θ::Vector{<:Real}) = Transition(model, θ + rand(spl.proposal))
27+
28+
"""
29+
q(θ::Real, dist::Sampleable)
30+
q(θ::Vector{<:Real}, dist::Sampleable)
31+
q(t1::Transition, dist::Sampleable)
32+
33+
Calculates the probability `q(θ | θcond)`, using the proposal distribution `spl.proposal`.
34+
"""
35+
@inline q(spl::MetropolisHastings{RandomWalk}, θ::Real, θcond::Real) = logpdf(spl.proposal, θ - θcond)
36+
@inline q(spl::MetropolisHastings{RandomWalk}, θ::Vector{<:Real}, θcond::Vector{<:Real}) = logpdf(spl.proposal, θ - θcond)

src/staticmh.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
struct Static <: ProposalStyle end
2+
3+
"""
4+
StaticMH(init_theta::Real, proposal = Normal(init_theta, 1))
5+
StaticMH(init_theta::Vector{Real}, proposal = MvNormal(init_theta, 1))
6+
7+
Static Metropolis-Hastings. Proposes only from the prior distribution.
8+
9+
Fields:
10+
11+
- `init_θ` is the vector form of the parameters needed for the likelihood function.
12+
- `proposal` is a distribution.
13+
14+
Example:
15+
16+
```julia
17+
RWMH([0.0, 0.0], MvNormal(x, 1.0))
18+
````
19+
"""
20+
StaticMH(init_theta::Real, proposal = Normal(init_theta, 1)) = MetropolisHastings(Static(), init_theta, proposal)
21+
StaticMH(init_theta::Vector{<:Real}, proposal = MvNormal(init_theta, 1)) = MetropolisHastings(Static(), init_theta, proposal)
22+
23+
# Define a function that makes a basic proposal depending on a univariate
24+
# parameterization or a multivariate parameterization.
25+
propose(spl::MetropolisHastings{Static}, model::DensityModel, θ::Real) = Transition(model, rand(spl.proposal))
26+
propose(spl::MetropolisHastings{Static}, model::DensityModel, θ::Vector{<:Real}) = Transition(model, rand(spl.proposal))
27+
28+
"""
29+
q(θ::Real, dist::Sampleable)
30+
q(θ::Vector{<:Real}, dist::Sampleable)
31+
q(t1::Transition, dist::Sampleable)
32+
33+
Calculates the probability `q(θ | θcond)`, using the proposal distribution `spl.proposal`.
34+
"""
35+
q(spl::MetropolisHastings{Static}, θ::Real, θcond::Real) = logpdf(spl.proposal, θ)
36+
q(spl::MetropolisHastings{Static}, θ::Vector{<:Real}, θcond::Vector{<:Real}) = logpdf(spl.proposal, θ)
37+

test/runtests.jl

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,26 @@ using Random
2020
model = DensityModel(density)
2121

2222
# Set up our sampler with initial parameters.
23-
spl = MetropolisHastings([0.0, 0.0])
23+
spl1 = RWMH([0.0, 0.0])
24+
spl2 = StaticMH([0.0, 0.0], MvNormal([0.0, 0.0], 1))
2425

25-
# Sample from the posterior.
26-
chain = sample(model, spl, 100000; param_names=["μ", "σ"])
26+
@testset "Inference" begin
2727

28-
# chn_mean ≈ dist_mean atol=atol_v
29-
@test mean(chain["μ"].value) 0.0 atol=0.1
30-
@test mean(chain["σ"].value) 1.0 atol=0.1
28+
# Sample from the posterior.
29+
chain1 = sample(model, spl1, 100000; param_names=["μ", "σ"])
30+
chain2 = sample(model, spl2, 100000; param_names=["μ", "σ"])
31+
32+
# chn_mean ≈ dist_mean atol=atol_v
33+
@test mean(chain1["μ"].value) 0.0 atol=0.1
34+
@test mean(chain1["σ"].value) 1.0 atol=0.1
35+
@test mean(chain2["μ"].value) 0.0 atol=0.1
36+
@test mean(chain2["σ"].value) 1.0 atol=0.1
37+
end
38+
39+
@testset "psample" begin
40+
chain1 = psample(model, spl1, 10000, 4; param_names=["μ", "σ"])
41+
@test mean(chain1["μ"].value) 0.0 atol=0.1
42+
@test mean(chain1["σ"].value) 1.0 atol=0.1
43+
end
3144
end
45+

0 commit comments

Comments
 (0)