Skip to content

Commit f8209e6

Browse files
committed
Add tests
1 parent e132adc commit f8209e6

File tree

2 files changed

+54
-0
lines changed

2 files changed

+54
-0
lines changed

test/emcee.jl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
@testset "emcee.jl" begin
2+
@testset "example" begin
3+
@testset "untransformed space" begin
4+
# define model
5+
function logprob(θ)
6+
s, m = θ
7+
s > 0 || return -Inf
8+
9+
mdist = Normal(0, sqrt(s))
10+
obsdist = Normal(m, sqrt(s))
11+
12+
return logpdf(InverseGamma(2, 3), s) + logpdf(mdist, m) +
13+
logpdf(obsdist, 1.5) + logpdf(obsdist, 2.0)
14+
end
15+
model = DensityModel(logprob)
16+
17+
# perform stretch move and sample from prior in initial step
18+
Random.seed!(100)
19+
sampler = Ensemble(1_000, StretchProposal([InverseGamma(2, 3), Normal(0, 1)]))
20+
chain = sample(model, sampler, 1_000;
21+
param_names = ["s", "m"], chain_type = Chains)
22+
23+
@test mean(chain["s"].value) 49/24 atol=0.1
24+
@test mean(chain["m"].value) 7/6 atol=0.1
25+
end
26+
27+
@testset "transformed space" begin
28+
# define model
29+
function logprob(θ)
30+
logs, m = θ
31+
s = exp(logs)
32+
sqrts = sqrt(s)
33+
34+
mdist = Normal(0, sqrts)
35+
obsdist = Normal(m, sqrts)
36+
37+
return logpdf(InverseGamma(2, 3), s) + logpdf(mdist, m) +
38+
logpdf(obsdist, 1.5) + logpdf(obsdist, 2.0) + logs
39+
end
40+
model = DensityModel(logprob)
41+
42+
# perform stretch move and sample from normal distribution in initial step
43+
Random.seed!(100)
44+
sampler = Ensemble(1_000, StretchProposal(MvNormal(2, 1)))
45+
chain = sample(model, sampler, 1_000;
46+
param_names = ["logs", "m"], chain_type = Chains)
47+
48+
@test mean(exp.(chain["logs"].value)) 49/24 atol=0.1
49+
@test mean(chain["m"].value) 7/6 atol=0.1
50+
end
51+
end
52+
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,5 +97,7 @@ using Test
9797

9898
@test chain1[1].params == val
9999
end
100+
101+
@testset "EMCEE" begin include("emcee.jl") end
100102
end
101103

0 commit comments

Comments
 (0)