Skip to content

Commit 5f28566

Browse files
committed
Update README
1 parent bf9cb0d commit 5f28566

File tree

1 file changed

+56
-23
lines changed

1 file changed

+56
-23
lines changed

README.md

Lines changed: 56 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,17 @@ AdvancedMH.jl currently provides a robust implementation of random walk Metropol
44

55
Further development aims to provide a suite of adaptive Metropolis-Hastings implementations.
66

7-
Currently there are two sampler types. The first is `RWMH`, which represents random-walk MH sampling, and the second is `StaticMH`, which draws proposals
8-
from only a prior distribution without incrementing the previous sample.
7+
AdvancedMH works by allowing users to define composable `Proposal` structs in different formats.
98

109
## Usage
1110

12-
AdvancedMH works by accepting some log density function which is used to construct a `DensityModel`. The `DensityModel` is then used in a `sample` call.
11+
First, construct a `DensityModel`, which is a wrapper around the log density function for your inference problem. The `DensityModel` is then used in a `sample` call.
1312

1413
```julia
1514
# Import the package.
1615
using AdvancedMH
1716
using Distributions
17+
using MCMCChains
1818

1919
# Generate a set of data from the posterior we want to estimate.
2020
data = rand(Normal(0, 1), 30)
@@ -27,11 +27,11 @@ density(θ) = insupport(θ) ? sum(logpdf.(dist(θ), data)) : -Inf
2727
# Construct a DensityModel.
2828
model = DensityModel(density)
2929

30-
# Set up our sampler with initial parameters.
31-
spl = RWMH([0.0, 0.0])
30+
# Set up our sampler with a joint multivariate Normal proposal.
31+
spl = RWMH(MvNormal(2,1))
3232

3333
# Sample from the posterior.
34-
chain = sample(model, spl, 100000; param_names=["μ", "σ"])
34+
chain = sample(model, spl, 100000; param_names=["μ", "σ"], chain_type=Chains)
3535
```
3636

3737
Output:
@@ -46,41 +46,74 @@ Samples per chain = 100000
4646
internals = lp
4747
parameters = μ, σ
4848

49-
2-element Array{MCMCChains.ChainDataFrame,1}
49+
2-element Array{ChainDataFrame,1}
5050

5151
Summary Statistics
5252

53-
│ Row │ parameters │ mean │ std │ naive_se │ mcse │ ess │ r_hat │
54-
│ │ Symbol │ Float64 │ Float64 │ Float64 │ Float64 │ Any │ Any │
55-
├─────┼────────────┼──────────┼──────────┼────────────┼────────────┼─────────┼─────────┤
56-
1 │ μ │ 0.08341880.241418 0.000763430.003410674693.041.00008
57-
2 │ σ │ 1.33116 0.1841110.000582210.002587784965.831.00001
53+
│ Row │ parameters │ mean │ std │ naive_se │ mcse │ ess │ r_hat │
54+
│ │ Symbol │ Float64 │ Float64 │ Float64 │ Float64 │ Any │ Any │
55+
├─────┼────────────┼──────────┼──────────┼────────────┼────────────┼─────────┼─────────┤
56+
1 │ μ │ 0.1561520.19963 0.0006312850.003230333911.731.00009
57+
2 │ σ │ 1.07493 0.1501110.0004746930.002403173707.731.00027
5858

5959
Quantiles
6060

61-
│ Row │ parameters │ 2.5%25.0%50.0%75.0%97.5%
62-
│ │ Symbol │ Float64 │ Float64 │ Float64 │ Float64 │ Float64 │
63-
├─────┼────────────┼───────────┼────────────┼───────────┼──────────┼──────────┤
64-
1 │ μ │ -0.393769-0.07711340.08016880.2411620.564331
65-
2 │ σ │ 1.036851.20441.309921.436091.75745
61+
│ Row │ parameters │ 2.5%25.0%50.0%75.0%97.5%
62+
│ │ Symbol │ Float64 │ Float64 │ Float64 │ Float64 │ Float64 │
63+
├─────┼────────────┼──────────┼───────────┼──────────┼──────────┼──────────┤
64+
1 │ μ │ -0.233610.02970060.1591390.2834930.558694
65+
2 │ σ │ 0.8282880.9726821.058041.161551.41349
66+
6667
```
6768

68-
## Custom proposals
69+
## Proposals
6970

70-
Custom proposal distributions can be specified by passing a distribution to `MetropolisHastings`:
71+
AdvancedMH offers various methods of defining your inference problem. Behind the scenes, a `MetropolisHastings` sampler simply holds
72+
some set of `Proposal` structs. AdvancedMH will return posterior samples in the "shape" of the proposal provided -- currently
73+
supported methods are `Array{Proposal}`, `Proposal`, and `NamedTuple{Proposal}`. For example, proposals can be created as:
7174

7275
```julia
73-
# Set up our sampler with initial parameters.
74-
spl1 = RWMH([0.0, 0.0], MvNormal(2, 0.5))
75-
spl2 = StaticMH([0.0, 0.0], MvNormal(2, 0.5))
76+
# Provide a univariate proposal.
77+
m1 = DensityModel(x -> logpdf(Normal(x,1), 1.0))
78+
p1 = Proposal(Static(), Normal(0,1))
79+
c1 = sample(m1, MetropolisHastings(p1), 100; chain_type=Vector{NamedTuple})
80+
81+
# Draw from a vector of distributions.
82+
m2 = DensityModel(x -> logpdf(Normal(x[1], x[2]), 1.0))
83+
p2 = Proposal(Static(), [Normal(0,1), InverseGamma(2,3)])
84+
c2 = sample(m2, MetropolisHastings(p2), 100; chain_type=Vector{NamedTuple})
85+
86+
# Draw from a `NamedTuple` of distributions.
87+
m3 = DensityModel(x -> logpdf(Normal(x.a, x.b), 1.0))
88+
p3 = (a=Proposal(Static(), Normal(0,1)), b=Proposal(Static(), InverseGamma(2,3)))
89+
c3 = sample(m3, MetropolisHastings(p3), 100; chain_type=Vector{NamedTuple})
90+
91+
# Draw from a functional proposal.
92+
m4 = DensityModel(x -> logpdf(Normal(x,1), 1.0))
93+
p4 = Proposal(Static(), (x=1.0) -> Normal(x, 1))
94+
c4 = sample(m4, MetropolisHastings(p4), 100; chain_type=Vector{NamedTuple})
7695
```
7796

97+
## Static vs. Random Walk
98+
99+
Currently there are only two methods of inference available. Static MH simply draws from the prior, with no
100+
conditioning on the previous sample. Random walk will add the proposal to the previously observed value.
101+
If you are constructing a `Proposal` by hand, you can determine whether the proposal is `Static` or `RandomWalk` using
102+
103+
```julia
104+
static_prop = Proposal(Static(), Normal(0,1))
105+
rw_prop = Proposal(RandomWalk(), Normal(0,1))
106+
```
107+
108+
Different methods are easily composeable. One parameter can be static and another can be a random walk,
109+
each of which may be drawn from separate distributions.
110+
78111
## Multithreaded sampling
79112

80113
AdvancedMH.jl implements the interface of [AbstractMCMC](https://github.com/TuringLang/AbstractMCMC.jl/), which means you get multiple chain sampling
81114
in parallel for free:
82115

83116
```julia
84117
# Sample 4 chains from the posterior.
85-
chain = psample(model, RWMH(init_params), 100000, 4; param_names=["μ","σ"])
118+
chain = psample(model, RWMH(init_params), 100000, 4; param_names=["μ","σ"], chain_type=Chains)
86119
```

0 commit comments

Comments
 (0)