Commit 728c7b5

Merge pull request #12 from TuringLang/csp/turing-connect
Make AdvancedMH better
2 parents e277e86 + 4206d91 commit 728c7b5

11 files changed

+485
-172
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1 +1,2 @@
11
Manifest.toml
2+
.vscode/settings.json

Project.toml

Lines changed: 15 additions & 8 deletions
@@ -1,21 +1,28 @@
11
name = "AdvancedMH"
22
uuid = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
3-
version = "0.2.0"
3+
version = "0.3.0"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
77
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
8-
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
98
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
109
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
10+
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
11+
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
1112
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1213

13-
[extras]
14-
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
15-
1614
[compat]
17-
AbstractMCMC = "0.1"
18-
Distributions = "0.21"
19-
MCMCChains = "1, 0.4"
15+
AbstractMCMC = "0.3"
16+
Distributions = "0.20, 0.21, 0.22, 0.23"
2017
Reexport = "0.2"
18+
Requires = "1.0"
19+
StructArrays = "^0"
2120
julia = "1"
21+
22+
[extras]
23+
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
24+
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
25+
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
26+
27+
[targets]
28+
test = ["StructArrays", "MCMCChains"]

README.md

Lines changed: 56 additions & 23 deletions
@@ -4,17 +4,17 @@ AdvancedMH.jl currently provides a robust implementation of random walk Metropol
44

55
Further development aims to provide a suite of adaptive Metropolis-Hastings implementations.
66

7-
Currently there are two sampler types. The first is `RWMH`, which represents random-walk MH sampling, and the second is `StaticMH`, which draws proposals
8-
from only a prior distribution without incrementing the previous sample.
7+
AdvancedMH works by allowing users to define composable `Proposal` structs in different formats.
98

109
## Usage
1110

12-
AdvancedMH works by accepting some log density function which is used to construct a `DensityModel`. The `DensityModel` is then used in a `sample` call.
11+
First, construct a `DensityModel`, which is a wrapper around the log density function for your inference problem. The `DensityModel` is then used in a `sample` call.
1312

1413
```julia
1514
# Import the package.
1615
using AdvancedMH
1716
using Distributions
17+
using MCMCChains
1818

1919
# Generate a set of data from the posterior we want to estimate.
2020
data = rand(Normal(0, 1), 30)
@@ -27,11 +27,11 @@ density(θ) = insupport(θ) ? sum(logpdf.(dist(θ), data)) : -Inf
2727
# Construct a DensityModel.
2828
model = DensityModel(density)
2929

30-
# Set up our sampler with initial parameters.
31-
spl = RWMH([0.0, 0.0])
30+
# Set up our sampler with a joint multivariate Normal proposal.
31+
spl = RWMH(MvNormal(2,1))
3232

3333
# Sample from the posterior.
34-
chain = sample(model, spl, 100000; param_names=["μ", "σ"])
34+
chain = sample(model, spl, 100000; param_names=["μ", "σ"], chain_type=Chains)
3535
```
3636

3737
Output:
@@ -46,41 +46,74 @@ Samples per chain = 100000
4646
internals = lp
4747
parameters = μ, σ
4848

49-
2-element Array{MCMCChains.ChainDataFrame,1}
49+
2-element Array{ChainDataFrame,1}
5050

5151
Summary Statistics
5252

53-
│ Row │ parameters │ mean │ std │ naive_se │ mcse │ ess │ r_hat │
54-
│ │ Symbol │ Float64 │ Float64 │ Float64 │ Float64 │ Any │ Any │
55-
├─────┼────────────┼──────────┼──────────┼────────────┼────────────┼─────────┼─────────┤
56-
│ 1 │ μ │ 0.0834188 │ 0.241418 │ 0.00076343 │ 0.00341067 │ 4693.04 │ 1.00008 │
57-
│ 2 │ σ │ 1.33116 │ 0.184111 │ 0.00058221 │ 0.00258778 │ 4965.83 │ 1.00001 │
53+
│ Row │ parameters │ mean │ std │ naive_se │ mcse │ ess │ r_hat │
54+
│ │ Symbol │ Float64 │ Float64 │ Float64 │ Float64 │ Any │ Any │
55+
├─────┼────────────┼──────────┼──────────┼────────────┼────────────┼─────────┼─────────┤
56+
│ 1 │ μ │ 0.156152 │ 0.19963 │ 0.000631285 │ 0.00323033 │ 3911.73 │ 1.00009 │
57+
│ 2 │ σ │ 1.07493 │ 0.150111 │ 0.000474693 │ 0.00240317 │ 3707.73 │ 1.00027 │
5858

5959
Quantiles
6060

61-
│ Row │ parameters │ 2.5% │ 25.0% │ 50.0% │ 75.0% │ 97.5% │
62-
│ │ Symbol │ Float64 │ Float64 │ Float64 │ Float64 │ Float64 │
63-
├─────┼────────────┼───────────┼────────────┼───────────┼──────────┼──────────┤
64-
│ 1 │ μ │ -0.393769 │ -0.0771134 │ 0.0801688 │ 0.241162 │ 0.564331 │
65-
│ 2 │ σ │ 1.03685 │ 1.2044 │ 1.30992 │ 1.43609 │ 1.75745 │
61+
│ Row │ parameters │ 2.5% │ 25.0% │ 50.0% │ 75.0% │ 97.5% │
62+
│ │ Symbol │ Float64 │ Float64 │ Float64 │ Float64 │ Float64 │
63+
├─────┼────────────┼──────────┼───────────┼──────────┼──────────┼──────────┤
64+
│ 1 │ μ │ -0.23361 │ 0.0297006 │ 0.159139 │ 0.283493 │ 0.558694 │
65+
│ 2 │ σ │ 0.828288 │ 0.972682 │ 1.05804 │ 1.16155 │ 1.41349 │
66+
6667
```
6768

68-
## Custom proposals
69+
## Proposals
6970

70-
Custom proposal distributions can be specified by passing a distribution to `MetropolisHastings`:
71+
AdvancedMH offers various methods of defining your inference problem. Behind the scenes, a `MetropolisHastings` sampler simply holds
72+
some set of `Proposal` structs. AdvancedMH will return posterior samples in the "shape" of the proposal provided -- currently
73+
supported methods are `Array{Proposal}`, `Proposal`, and `NamedTuple{Proposal}`. For example, proposals can be created as:
7174

7275
```julia
73-
# Set up our sampler with initial parameters.
74-
spl1 = RWMH([0.0, 0.0], MvNormal(2, 0.5))
75-
spl2 = StaticMH([0.0, 0.0], MvNormal(2, 0.5))
76+
# Provide a univariate proposal.
77+
m1 = DensityModel(x -> logpdf(Normal(x,1), 1.0))
78+
p1 = Proposal(Static(), Normal(0,1))
79+
c1 = sample(m1, MetropolisHastings(p1), 100; chain_type=Vector{NamedTuple})
80+
81+
# Draw from a vector of distributions.
82+
m2 = DensityModel(x -> logpdf(Normal(x[1], x[2]), 1.0))
83+
p2 = Proposal(Static(), [Normal(0,1), InverseGamma(2,3)])
84+
c2 = sample(m2, MetropolisHastings(p2), 100; chain_type=Vector{NamedTuple})
85+
86+
# Draw from a `NamedTuple` of distributions.
87+
m3 = DensityModel(x -> logpdf(Normal(x.a, x.b), 1.0))
88+
p3 = (a=Proposal(Static(), Normal(0,1)), b=Proposal(Static(), InverseGamma(2,3)))
89+
c3 = sample(m3, MetropolisHastings(p3), 100; chain_type=Vector{NamedTuple})
90+
91+
# Draw from a functional proposal.
92+
m4 = DensityModel(x -> logpdf(Normal(x,1), 1.0))
93+
p4 = Proposal(Static(), (x=1.0) -> Normal(x, 1))
94+
c4 = sample(m4, MetropolisHastings(p4), 100; chain_type=Vector{NamedTuple})
7695
```
7796

97+
## Static vs. Random Walk
98+
99+
Currently there are only two methods of inference available. Static MH simply draws from the prior, with no
100+
conditioning on the previous sample. Random walk will add the proposal to the previously observed value.
101+
If you are constructing a `Proposal` by hand, you can specify whether the proposal is `Static` or `RandomWalk` using
102+
103+
```julia
104+
static_prop = Proposal(Static(), Normal(0,1))
105+
rw_prop = Proposal(RandomWalk(), Normal(0,1))
106+
```
107+
108+
Different methods are easily composable. One parameter can be static and another can be a random walk,
109+
each of which may be drawn from separate distributions.
110+
78111
## Multithreaded sampling
79112

80113
AdvancedMH.jl implements the interface of [AbstractMCMC](https://github.com/TuringLang/AbstractMCMC.jl/), which means you get multiple chain sampling
81114
in parallel for free:
82115

83116
```julia
84117
# Sample 4 chains from the posterior.
85-
chain = psample(model, RWMH(init_params), 100000, 4; param_names=["μ","σ"])
118+
chain = psample(model, RWMH(init_params), 100000, 4; param_names=["μ","σ"], chain_type=Chains)
86119
```
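
Reviewer note on the README's "Static vs. Random Walk" section: the difference between the two proposal styles can be sketched in plain Julia without the package. Everything below is a hypothetical illustration, not AdvancedMH's implementation: the names `SketchStatic`, `SketchRandomWalk`, `propose`, `logq` and `sketch_mh` are invented here, and the standard normal proposal and the Normal(1, 0.5) target are assumptions chosen to keep the example short.

```julia
using Random

# Hypothetical stand-ins for illustration; these are not the package's types.
abstract type SketchStyle end
struct SketchStatic <: SketchStyle end
struct SketchRandomWalk <: SketchStyle end

# Both styles draw from a standard normal (an assumption of this sketch).
# Static: the candidate ignores the current state.
propose(rng::AbstractRNG, ::SketchStatic, current::Float64) = randn(rng)
# Random walk: the draw is added to the current state.
propose(rng::AbstractRNG, ::SketchRandomWalk, current::Float64) = current + randn(rng)

# Log proposal density q(to | from), up to an additive constant.
logq(::SketchStatic, to::Float64, from::Float64) = -0.5 * to^2
logq(::SketchRandomWalk, to::Float64, from::Float64) = -0.5 * (to - from)^2

# One Metropolis-Hastings chain targeting the log density `logp`.
function sketch_mh(rng::AbstractRNG, logp, style::SketchStyle, n::Int; init::Float64 = 0.0)
    draws = Vector{Float64}(undef, n)
    x = init
    for i in 1:n
        y = propose(rng, style, x)
        # Hastings ratio: target ratio times reverse/forward proposal ratio.
        log_ratio = logp(y) - logp(x) + logq(style, x, y) - logq(style, y, x)
        if log(rand(rng)) < log_ratio
            x = y
        end
        draws[i] = x
    end
    return draws
end

# Unnormalised log density of Normal(1, 0.5), the assumed target.
logp(x) = -0.5 * ((x - 1.0) / 0.5)^2

rng = MersenneTwister(1)
static_draws = sketch_mh(rng, logp, SketchStatic(), 20_000)
rw_draws = sketch_mh(rng, logp, SketchRandomWalk(), 20_000)
```

Both chains target the same density, so their sample means should land near 1. For the random walk the two `logq` terms cancel because the proposal is symmetric; for the static style they do not, which is why the sketch keeps them explicit.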

src/AdvancedMH.jl

Lines changed: 42 additions & 23 deletions
@@ -1,28 +1,31 @@
11
module AdvancedMH
22

33
# Import the relevant libraries.
4-
using Reexport
54
using AbstractMCMC
6-
using Distributions
75
using Random
6+
using Requires
7+
using Distributions
88

99
# Import specific functions and types to use or overload.
10-
import MCMCChains: Chains
11-
import AbstractMCMC: step!, AbstractSampler, AbstractTransition, transition_type, bundle_samples
10+
import AbstractMCMC: step!, AbstractSampler, AbstractTransition,
11+
transition_type, bundle_samples
1212

1313
# Exports
14-
export MetropolisHastings, DensityModel, sample, psample, RWMH, StaticMH
14+
export MetropolisHastings, DensityModel, sample, psample, RWMH, StaticMH, Proposal, Static, RandomWalk
1515

1616
# Abstract type for MH-style samplers.
1717
abstract type Metropolis <: AbstractSampler end
1818
abstract type ProposalStyle end
1919

20+
struct RandomWalk <: ProposalStyle end
21+
struct Static <: ProposalStyle end
22+
2023
# Define a model type. Stores the log density function and the data to
2124
# evaluate the log density on.
2225
"""
2326
DensityModel{F<:Function} <: AbstractModel
2427
25-
`DensityModel` wraps around a self-contained log-liklihood function `ℓπ`.
28+
`DensityModel` wraps around a self-contained log-likelihood function `logdensity`.
2629
2730
Example:
2831
@@ -32,57 +35,73 @@ DensityModel
3235
```
3336
"""
3437
struct DensityModel{F<:Function} <: AbstractModel
35-
ℓπ :: F
38+
logdensity :: F
3639
end
3740

3841
# Create a very basic Transition type, only stores the
3942
# parameter draws and the log probability of the draw.
40-
struct Transition{T<:Union{Vector{<:Real}, <:Real}, L<:Real} <: AbstractTransition
41-
θ :: T
43+
struct Transition{T<:Union{Vector, Real, NamedTuple}, L<:Real} <: AbstractTransition
44+
params :: T
4245
lp :: L
4346
end
4447

4548
# Store the new draw and its log density.
46-
Transition(model::M, θ::T) where {M<:DensityModel, T} = Transition(θ, ℓπ(model, θ))
49+
Transition(model::M, params::T) where {M<:DensityModel, T} = Transition(params, logdensity(model, params))
4750

4851
# Tell the interface what transition type we would like to use.
49-
transition_type(model::DensityModel, spl::Metropolis) = typeof(Transition(spl.init_θ, ℓπ(model, spl.init_θ)))
52+
transition_type(model::DensityModel, spl::Metropolis) = typeof(Transition(spl.init_params, logdensity(model, spl.init_params)))
5053

5154
# Calculate the density of the model given some parameterization.
52-
ℓπ(model::DensityModel, θ::T) where T = model.ℓπ(θ)
53-
ℓπ(model::DensityModel, t::Transition) = t.lp
55+
logdensity(model::DensityModel, params) = model.logdensity(params)
56+
logdensity(model::DensityModel, t::Transition) = t.lp
5457

5558
# A basic chains constructor that works with the Transition struct we defined.
5659
function bundle_samples(
5760
rng::AbstractRNG,
58-
::DensityModel,
61+
model::DensityModel,
5962
s::Metropolis,
6063
N::Integer,
61-
ts::Vector{T};
64+
ts::Type{Any};
6265
param_names=missing,
6366
kwargs...
6467
) where {ModelType<:AbstractModel, T<:AbstractTransition}
65-
# Turn all the transitions into a vector-of-vectors.
66-
vals = copy(reduce(hcat,[vcat(t.θ, t.lp) for t in ts])')
68+
return ts
69+
end
6770

71+
function bundle_samples(
72+
rng::AbstractRNG,
73+
model::DensityModel,
74+
s::Metropolis,
75+
N::Integer,
76+
ts::Vector{T},
77+
chain_type::Type{Vector{NamedTuple}};
78+
param_names=missing,
79+
kwargs...
80+
) where {ModelType<:AbstractModel, T<:AbstractTransition}
6881
# Check if we received any parameter names.
6982
if ismissing(param_names)
70-
param_names = ["Parameter $i" for i in 1:length(s.init_θ)]
83+
param_names = ["param_$i" for i in 1:length(keys(ts[1].params))]
7184
else
7285
# Deepcopy to be thread safe.
7386
param_names = deepcopy(param_names)
7487
end
7588

76-
# Add the log density field to the parameter names.
7789
push!(param_names, "lp")
7890

79-
# Bundle everything up and return a Chains struct.
80-
return Chains(vals, param_names, (internals=["lp"],))
91+
# Turn all the transitions into a vector-of-NamedTuple.
92+
ks = tuple(Symbol.(param_names)...)
93+
nts = [NamedTuple{ks}(tuple(t.params..., t.lp)) for t in ts]
94+
95+
return nts
96+
end
97+
98+
function __init__()
99+
@require MCMCChains="c7f686f2-ff18-58e9-bc7b-31028e88f75d" include("mcmcchains-connect.jl")
100+
@require StructArrays="09ab397b-f2b6-538f-b94a-2f83cf4a842a" include("structarray-connect.jl")
81101
end
82102

83103
# Include inference methods.
104+
include("proposal.jl")
84105
include("mh-core.jl")
85-
include("rwmh.jl")
86-
include("staticmh.jl")
87106

88107
end # module AdvancedMH
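
Reviewer note on the `Vector{NamedTuple}` method of `bundle_samples` above: the name handling and the `NamedTuple{ks}(tuple(t.params..., t.lp))` conversion can be tried in isolation. The sketch below is a minimal stand-alone version under stated assumptions: `SketchTransition` and `sketch_bundle` are hypothetical names that mirror the `params` and `lp` fields in the diff, and the sampler, model and RNG arguments are left out.

```julia
# Hypothetical transition type for illustration; it mirrors the `params` and
# `lp` fields in the diff but is not the package's `Transition`.
struct SketchTransition{T}
    params::T
    lp::Float64
end

# Same pattern as the diff: fall back to generated names, then append "lp".
function sketch_bundle(ts::Vector{<:SketchTransition}; param_names = missing)
    if ismissing(param_names)
        pnames = ["param_$i" for i in 1:length(keys(ts[1].params))]
    else
        # Copy so that the caller's vector is not mutated by push!.
        pnames = copy(param_names)
    end
    push!(pnames, "lp")

    # Splat the parameters and the log density into one tuple per draw.
    ks = tuple(Symbol.(pnames)...)
    return [NamedTuple{ks}(tuple(t.params..., t.lp)) for t in ts]
end

ts = [SketchTransition([0.1, 1.2], -3.4), SketchTransition([0.2, 1.1], -3.1)]
named = sketch_bundle(ts; param_names = ["μ", "σ"])
unnamed = sketch_bundle(ts)
```

With names supplied, each draw can be read by field, for example `named[1].μ`. Without names, the fields are `param_1`, `param_2` and `lp`.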

src/mcmcchains-connect.jl

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
1+
import .MCMCChains: Chains
2+
3+
# A basic chains constructor that works with the Transition struct we defined.
4+
function bundle_samples(
5+
rng::AbstractRNG,
6+
model::DensityModel,
7+
s::Metropolis,
8+
N::Integer,
9+
ts::Vector{T},
10+
chain_type::Type{Chains};
11+
param_names=missing,
12+
kwargs...
13+
) where {ModelType<:AbstractModel, T<:AbstractTransition}
14+
# Turn all the transitions into a vector-of-vectors.
15+
vals = copy(reduce(hcat,[vcat(t.params, t.lp) for t in ts])')
16+
17+
# Check if we received any parameter names.
18+
if ismissing(param_names)
19+
param_names = ["Parameter $i" for i in 1:length(s.init_params)]
20+
else
21+
# Deepcopy to be thread safe.
22+
param_names = deepcopy(param_names)
23+
end
24+
25+
# Add the log density field to the parameter names.
26+
push!(param_names, "lp")
27+
28+
# Bundle everything up and return a Chains struct.
29+
return Chains(vals, param_names, (internals=["lp"],))
30+
end
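
Reviewer note on the `vals` line in `mcmcchains-connect.jl`: the expression builds one column per draw and then transposes, so rows index samples and the last column holds `lp`. A minimal sketch with made-up numbers, where `draws` is a hypothetical stand-in for the vector of transitions:

```julia
# Hypothetical draws: each entry is (params, lp).
draws = [([0.1, 1.2], -3.4), ([0.2, 1.1], -3.1),
         ([0.3, 1.0], -2.9), ([0.4, 0.9], -2.8)]

# Each draw becomes one column [params...; lp]; hcat joins the columns and the
# adjoint flips the result so that rows index samples. `copy` materialises the
# lazy adjoint as an ordinary Matrix.
vals = copy(reduce(hcat, [vcat(p, lp) for (p, lp) in draws])')

size(vals)   # (4, 3): four samples by (two parameters + lp)
```

This is the samples-by-columns layout that the diff passes to `Chains` together with the parameter names.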
