Skip to content

Commit d4d97b7

Browse files
JaimeRZPtorfjeldedevmotion
authored
convinience constructors (#83)
* convinience constructors * bug * I --> ones(d) * Update src/mh-core.jl Co-authored-by: Tor Erlend Fjelde <[email protected]> * Update src/mh-core.jl Co-authored-by: Tor Erlend Fjelde <[email protected]> * Linear Alg + add previously removed test * project * Update src/AdvancedMH.jl Co-authored-by: David Widmann <[email protected]> * FillArrays * more samples for RWMH * Compat * Update Project.toml Co-authored-by: David Widmann <[email protected]> --------- Co-authored-by: Tor Erlend Fjelde <[email protected]> Co-authored-by: David Widmann <[email protected]>
1 parent ae82171 commit d4d97b7

File tree

4 files changed

+15
-0
lines changed

4 files changed

+15
-0
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ version = "0.7.4"
55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
77
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
8+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
9+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
810
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
911
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1012
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
@@ -24,6 +26,7 @@ AdvancedMHStructArraysExt = "StructArrays"
2426
AbstractMCMC = "4"
2527
DiffResults = "1"
2628
Distributions = "0.20, 0.21, 0.22, 0.23, 0.24, 0.25"
29+
FillArrays = "1"
2730
ForwardDiff = "0.10"
2831
LogDensityProblems = "2"
2932
MCMCChains = "5, 6"

src/AdvancedMH.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ module AdvancedMH
33
# Import the relevant libraries.
44
using AbstractMCMC
55
using Distributions
6+
using LinearAlgebra: I
7+
using FillArrays: Zeros
68

79
using LogDensityProblems: LogDensityProblems
810

src/mh-core.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ struct MetropolisHastings{D} <: MHSampler
4646
end
4747

4848
StaticMH(d) = MetropolisHastings(StaticProposal(d))
49+
StaticMH(d::Int) = MetropolisHastings(StaticProposal(MvNormal(Zeros(d), I)))
4950
RWMH(d) = MetropolisHastings(RandomWalkProposal(d))
51+
RWMH(d::Int) = MetropolisHastings(RandomWalkProposal(MvNormal(Zeros(d), I)))
5052

5153
function propose(rng::Random.AbstractRNG, sampler::MHSampler, model::DensityModelOrLogDensityModel)
5254
return propose(rng, sampler.proposal, model)

test/runtests.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,32 +37,40 @@ include("util.jl")
3737
# Set up our sampler with initial parameters.
3838
spl1 = StaticMH([Normal(0,1), Normal(0, 1)])
3939
spl2 = StaticMH(MvNormal(zeros(2), I))
40+
spl3 = StaticMH(2)
4041

4142
# Sample from the posterior.
4243
chain1 = sample(model, spl1, 100000; chain_type=StructArray, param_names=["μ", "σ"])
4344
chain2 = sample(model, spl2, 100000; chain_type=StructArray, param_names=["μ", "σ"])
45+
chain3 = sample(model, spl3, 100000; chain_type=StructArray, param_names=["μ", "σ"])
4446

4547
# chn_mean ≈ dist_mean atol=atol_v
4648
@test mean(chain1.μ) 0.0 atol=0.1
4749
@test mean(chain1.σ) 1.0 atol=0.1
4850
@test mean(chain2.μ) 0.0 atol=0.1
4951
@test mean(chain2.σ) 1.0 atol=0.1
52+
@test mean(chain3.μ) 0.0 atol=0.1
53+
@test mean(chain3.σ) 1.0 atol=0.1
5054
end
5155

5256
@testset "RandomWalk" begin
5357
# Set up our sampler with initial parameters.
5458
spl1 = RWMH([Normal(0,1), Normal(0, 1)])
5559
spl2 = RWMH(MvNormal(zeros(2), I))
60+
spl3 = RWMH(2)
5661

5762
# Sample from the posterior.
5863
chain1 = sample(model, spl1, 100000; chain_type=StructArray, param_names=["μ", "σ"])
5964
chain2 = sample(model, spl2, 100000; chain_type=StructArray, param_names=["μ", "σ"])
65+
chain3 = sample(model, spl3, 200000; chain_type=StructArray, param_names=["μ", "σ"])
6066

6167
# chn_mean ≈ dist_mean atol=atol_v
6268
@test mean(chain1.μ) 0.0 atol=0.1
6369
@test mean(chain1.σ) 1.0 atol=0.1
6470
@test mean(chain2.μ) 0.0 atol=0.1
6571
@test mean(chain2.σ) 1.0 atol=0.1
72+
@test mean(chain3.μ) 0.0 atol=0.1
73+
@test mean(chain3.σ) 1.0 atol=0.1
6674
end
6775

6876
@testset "parallel sampling" begin

0 commit comments

Comments
 (0)