Skip to content

Commit 00da11a

Browse files
authored
Use correct iteration numbers in chain (#61)
1 parent 7e70b1c commit 00da11a

File tree

4 files changed

+82
-6
lines changed

4 files changed

+82
-6
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "AdvancedMH"
22
uuid = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
3-
version = "0.6.2"
3+
version = "0.6.3"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/mcmcchains-connect.jl

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ function AbstractMCMC.bundle_samples(
77
sampler::MHSampler,
88
state,
99
chain_type::Type{Chains};
10+
discard_initial=0,
11+
thinning=1,
1012
param_names=missing,
1113
kwargs...
1214
)
@@ -25,7 +27,9 @@ function AbstractMCMC.bundle_samples(
2527
push!(param_names, :lp)
2628

2729
# Bundle everything up and return a Chains struct.
28-
return Chains(vals, param_names, (internals = [:lp],))
30+
return Chains(
31+
vals, param_names, (internals = [:lp],); start=discard_initial + 1, thin=thinning,
32+
)
2933
end
3034

3135
function AbstractMCMC.bundle_samples(
@@ -34,6 +38,8 @@ function AbstractMCMC.bundle_samples(
3438
sampler::MHSampler,
3539
state,
3640
chain_type::Type{Chains};
41+
discard_initial=0,
42+
thinning=1,
3743
param_names=missing,
3844
kwargs...
3945
)
@@ -59,7 +65,9 @@ function AbstractMCMC.bundle_samples(
5965
end
6066

6167
# Bundle everything up and return a Chains struct.
62-
return Chains(vals, param_names, (internals = [:lp],))
68+
return Chains(
69+
vals, param_names, (internals = [:lp],); start=discard_initial + 1, thin=thinning,
70+
)
6371
end
6472

6573
function AbstractMCMC.bundle_samples(
@@ -68,6 +76,8 @@ function AbstractMCMC.bundle_samples(
6876
sampler::Ensemble,
6977
state,
7078
chain_type::Type{Chains};
79+
discard_initial=0,
80+
thinning=1,
7181
param_names=missing,
7282
kwargs...
7383
)
@@ -100,5 +110,7 @@ function AbstractMCMC.bundle_samples(
100110
push!(param_names, :lp)
101111

102112
# Bundle everything up and return a Chains struct.
103-
return Chains(vals, param_names, (internals=[:lp],))
113+
return Chains(
114+
vals, param_names, (internals = [:lp],); start=discard_initial + 1, thin=thinning,
115+
)
104116
end

test/emcee.jl

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,27 @@
1717
# perform stretch move and sample from prior in initial step
1818
Random.seed!(100)
1919
sampler = Ensemble(1_000, StretchProposal([InverseGamma(2, 3), Normal(0, 1)]))
20+
2021
chain = sample(model, sampler, 1_000;
2122
param_names = ["s", "m"], chain_type = Chains)
22-
23+
@test chain isa Chains
24+
@test range(chain) == 1:1_000
2325
@test mean(chain["s"]) 49/24 atol=0.1
2426
@test mean(chain["m"]) 7/6 atol=0.1
27+
28+
chain2 = sample(
29+
model,
30+
sampler,
31+
1_000;
32+
param_names = ["s", "m"],
33+
chain_type = Chains,
34+
discard_initial=25,
35+
thinning=4,
36+
)
37+
@test chain2 isa Chains
38+
@test range(chain2) == range(26; step=4, length=1_000)
39+
@test mean(chain2["s"]) 49/24 atol=0.1
40+
@test mean(chain2["m"]) 7/6 atol=0.1
2541
end
2642

2743
@testset "transformed space" begin
@@ -44,9 +60,24 @@
4460
sampler = Ensemble(1_000, StretchProposal(MvNormal(2, 1)))
4561
chain = sample(model, sampler, 1_000;
4662
param_names = ["logs", "m"], chain_type = Chains)
47-
63+
@test chain isa Chains
64+
@test range(chain) == 1:1_000
4865
@test mean(exp, chain["logs"]) 49/24 atol=0.1
4966
@test mean(chain["m"]) 7/6 atol=0.1
67+
68+
chain2 = sample(
69+
model,
70+
sampler,
71+
1_000;
72+
param_names = ["logs", "m"],
73+
chain_type = Chains,
74+
discard_initial=25,
75+
thinning=4,
76+
)
77+
@test chain2 isa Chains
78+
@test range(chain2) == range(26; step=4, length=1_000)
79+
@test mean(exp, chain2["logs"]) 49/24 atol=0.1
80+
@test mean(chain2["m"]) 7/6 atol=0.1
5081
end
5182
end
5283
end

test/runtests.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,19 @@ include("util.jl")
8080
param_names=["μ", "σ"], chain_type=Chains
8181
)
8282
@test chain1 isa Chains
83+
@test range(chain1) == 1:10_000
8384
@test mean(chain1["μ"]) 0.0 atol=0.1
8485
@test mean(chain1["σ"]) 1.0 atol=0.1
8586

87+
chain1b = sample(
88+
model, StaticMH([Normal(0,1), Normal(0, 1)]), 10_000;
89+
param_names=["μ", "σ"], chain_type=Chains, discard_initial=25, thinning=4,
90+
)
91+
@test chain1b isa Chains
92+
@test range(chain1b) == range(26; step=4, length=10_000)
93+
@test mean(chain1b["μ"]) 0.0 atol=0.1
94+
@test mean(chain1b["σ"]) 1.0 atol=0.1
95+
8696
# NamedTuple of parameters
8797
chain2 = sample(
8898
model,
@@ -92,16 +102,39 @@ include("util.jl")
92102
chain_type=Chains
93103
)
94104
@test chain2 isa Chains
105+
@test range(chain2) == 1:10_000
95106
@test mean(chain2["μ"]) 0.0 atol=0.1
96107
@test mean(chain2["σ"]) 1.0 atol=0.1
97108

109+
chain2b = sample(
110+
model,
111+
MetropolisHastings(
112+
= StaticProposal(Normal(0,1)), σ = StaticProposal(Normal(0, 1)))
113+
), 10_000;
114+
chain_type=Chains, discard_initial=25, thinning=4,
115+
)
116+
@test chain2b isa Chains
117+
@test range(chain2b) == range(26; step=4, length=10_000)
118+
@test mean(chain2b["μ"]) 0.0 atol=0.1
119+
@test mean(chain2b["σ"]) 1.0 atol=0.1
120+
98121
# Scalar parameter
99122
chain3 = sample(
100123
DensityModel(x -> loglikelihood(Normal(x, 1), data)),
101124
StaticMH(Normal(0, 1)), 10_000; param_names=["μ"], chain_type=Chains
102125
)
103126
@test chain3 isa Chains
127+
@test range(chain3) == 1:10_000
104128
@test mean(chain3["μ"]) 0.0 atol=0.1
129+
130+
chain3b = sample(
131+
DensityModel(x -> loglikelihood(Normal(x, 1), data)),
132+
StaticMH(Normal(0, 1)), 10_000;
133+
param_names=["μ"], chain_type=Chains, discard_initial=25, thinning=4,
134+
)
135+
@test chain3b isa Chains
136+
@test range(chain3b) == range(26; step=4, length=10_000)
137+
@test mean(chain3b["μ"]) 0.0 atol=0.1
105138
end
106139

107140
@testset "Proposal styles" begin

0 commit comments

Comments
 (0)