Skip to content

Commit 5aa196c

Browse files
authored
Fix for issue #95 (#96)
* fix for issue #95 * fixed failing tests * remove usage of stack and stop logging so much progress * removed more progress logging in test suite * removed more progress logging * bump patch version
1 parent b7829cb commit 5aa196c

File tree

4 files changed

+94
-45
lines changed

4 files changed

+94
-45
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "AdvancedMH"
22
uuid = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
3-
version = "0.8.1"
3+
version = "0.8.2"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/MALA.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ function AbstractMCMC.step(
6262

6363
# Compute the log ratio of proposal densities.
6464
logratio_proposal_density = q(
65-
proposal(-gradient_logdensity_candidate), state, candidate
66-
) - q(proposal(-gradient_logdensity_state), candidate, state)
65+
proposal(gradient_logdensity_candidate), state, candidate
66+
) - q(proposal(gradient_logdensity_state), candidate, state)
6767

6868
# Compute the log acceptance probability.
6969
logα = logdensity_candidate - logdensity_state + logratio_proposal_density
@@ -72,10 +72,7 @@ function AbstractMCMC.step(
7272
transition = if -Random.randexp(rng) < logα
7373
GradientTransition(candidate, logdensity_candidate, gradient_logdensity_candidate, true)
7474
else
75-
candidate = transition_prev.params
76-
lp = transition_prev.lp
77-
gradient = transition_prev.gradient
78-
GradientTransition(candidate, lp, gradient, false)
75+
GradientTransition(transition_prev.params, transition_prev.lp, transition_prev.gradient, false)
7976
end
8077

8178
return transition, transition

test/emcee.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
sampler = Ensemble(1_000, StretchProposal([InverseGamma(2, 3), Normal(0, 1)]))
2020

2121
chain = sample(model, sampler, 1_000;
22-
param_names = ["s", "m"], chain_type = Chains)
22+
param_names = ["s", "m"], chain_type = Chains, progress=false)
2323
@test chain isa Chains
2424
@test range(chain) == 1:1_000
2525
@test mean(chain["s"]) 49/24 atol=0.1
@@ -33,6 +33,7 @@
3333
chain_type = Chains,
3434
discard_initial=25,
3535
thinning=4,
36+
progress=false
3637
)
3738
@test chain2 isa Chains
3839
@test range(chain2) == range(26; step=4, length=1_000)
@@ -59,7 +60,7 @@
5960
Random.seed!(100)
6061
sampler = Ensemble(1_000, StretchProposal(MvNormal(zeros(2), I)))
6162
chain = sample(model, sampler, 1_000;
62-
param_names = ["logs", "m"], chain_type = Chains)
63+
param_names = ["logs", "m"], chain_type = Chains, progress=false)
6364
@test chain isa Chains
6465
@test range(chain) == 1:1_000
6566
@test mean(exp, chain["logs"]) 49/24 atol=0.1
@@ -73,6 +74,7 @@
7374
chain_type = Chains,
7475
discard_initial=25,
7576
thinning=4,
77+
progress=false
7678
)
7779
@test chain2 isa Chains
7880
@test range(chain2) == range(26; step=4, length=1_000)

test/runtests.jl

Lines changed: 86 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,9 @@ include("util.jl")
4040
spl3 = StaticMH(2)
4141

4242
# Sample from the posterior.
43-
chain1 = sample(model, spl1, 100000; chain_type=StructArray, param_names=["μ", "σ"])
44-
chain2 = sample(model, spl2, 100000; chain_type=StructArray, param_names=["μ", "σ"])
45-
chain3 = sample(model, spl3, 100000; chain_type=StructArray, param_names=["μ", "σ"])
43+
chain1 = sample(model, spl1, 100000; chain_type=StructArray, param_names=["μ", "σ"], progress=false)
44+
chain2 = sample(model, spl2, 100000; chain_type=StructArray, param_names=["μ", "σ"], progress=false)
45+
chain3 = sample(model, spl3, 100000; chain_type=StructArray, param_names=["μ", "σ"], progress=false)
4646

4747
# chn_mean ≈ dist_mean atol=atol_v
4848
@test mean(chain1.μ) 0.0 atol=0.1
@@ -60,9 +60,9 @@ include("util.jl")
6060
spl3 = RWMH(2)
6161

6262
# Sample from the posterior.
63-
chain1 = sample(model, spl1, 100000; chain_type=StructArray, param_names=["μ", "σ"])
64-
chain2 = sample(model, spl2, 100000; chain_type=StructArray, param_names=["μ", "σ"])
65-
chain3 = sample(model, spl3, 200000; chain_type=StructArray, param_names=["μ", "σ"])
63+
chain1 = sample(model, spl1, 100000; chain_type=StructArray, param_names=["μ", "σ"], progress=false)
64+
chain2 = sample(model, spl2, 100000; chain_type=StructArray, param_names=["μ", "σ"], progress=false)
65+
chain3 = sample(model, spl3, 200000; chain_type=StructArray, param_names=["μ", "σ"], progress=false)
6666

6767
# chn_mean ≈ dist_mean atol=atol_v
6868
@test mean(chain1.μ) 0.0 atol=0.1
@@ -77,13 +77,13 @@ include("util.jl")
7777
spl1 = StaticMH([Normal(0,1), Normal(0, 1)])
7878

7979
chain1 = sample(model, spl1, MCMCDistributed(), 10000, 4;
80-
param_names=["μ", "σ"], chain_type=Chains)
80+
param_names=["μ", "σ"], chain_type=Chains, progress=false)
8181
@test mean(chain1["μ"]) 0.0 atol=0.1
8282
@test mean(chain1["σ"]) 1.0 atol=0.1
8383

8484
if VERSION >= v"1.3"
8585
chain2 = sample(model, spl1, MCMCThreads(), 10000, 4;
86-
param_names=["μ", "σ"], chain_type=Chains)
86+
param_names=["μ", "σ"], chain_type=Chains, progress=false)
8787
@test mean(chain2["μ"]) 0.0 atol=0.1
8888
@test mean(chain2["σ"]) 1.0 atol=0.1
8989
end
@@ -93,7 +93,7 @@ include("util.jl")
9393
# Array of parameters
9494
chain1 = sample(
9595
model, StaticMH([Normal(0,1), Normal(0, 1)]), 10_000;
96-
param_names=["μ", "σ"], chain_type=Chains
96+
param_names=["μ", "σ"], chain_type=Chains, progress=false
9797
)
9898
@test chain1 isa Chains
9999
@test range(chain1) == 1:10_000
@@ -103,6 +103,7 @@ include("util.jl")
103103
chain1b = sample(
104104
model, StaticMH([Normal(0,1), Normal(0, 1)]), 10_000;
105105
param_names=["μ", "σ"], chain_type=Chains, discard_initial=25, thinning=4,
106+
progress=false
106107
)
107108
@test chain1b isa Chains
108109
@test range(chain1b) == range(26; step=4, length=10_000)
@@ -115,7 +116,8 @@ include("util.jl")
115116
MetropolisHastings(
116117
= StaticProposal(Normal(0,1)), σ = StaticProposal(Normal(0, 1)))
117118
), 10_000;
118-
chain_type=Chains
119+
chain_type=Chains,
120+
progress=false
119121
)
120122
@test chain2 isa Chains
121123
@test range(chain2) == 1:10_000
@@ -128,6 +130,7 @@ include("util.jl")
128130
= StaticProposal(Normal(0,1)), σ = StaticProposal(Normal(0, 1)))
129131
), 10_000;
130132
chain_type=Chains, discard_initial=25, thinning=4,
133+
progress=false
131134
)
132135
@test chain2b isa Chains
133136
@test range(chain2b) == range(26; step=4, length=10_000)
@@ -137,7 +140,8 @@ include("util.jl")
137140
# Scalar parameter
138141
chain3 = sample(
139142
DensityModel(x -> loglikelihood(Normal(x, 1), data)),
140-
StaticMH(Normal(0, 1)), 10_000; param_names=["μ"], chain_type=Chains
143+
StaticMH(Normal(0, 1)), 10_000; param_names=["μ"], chain_type=Chains,
144+
progress=false
141145
)
142146
@test chain3 isa Chains
143147
@test range(chain3) == 1:10_000
@@ -147,6 +151,7 @@ include("util.jl")
147151
DensityModel(x -> loglikelihood(Normal(x, 1), data)),
148152
StaticMH(Normal(0, 1)), 10_000;
149153
param_names=["μ"], chain_type=Chains, discard_initial=25, thinning=4,
154+
progress=false
150155
)
151156
@test chain3b isa Chains
152157
@test range(chain3b) == range(26; step=4, length=10_000)
@@ -164,10 +169,10 @@ include("util.jl")
164169
p3 = (a=StaticProposal(Normal(0,1)), b=StaticProposal(InverseGamma(2,3)))
165170
p4 = StaticProposal((x=1.0) -> Normal(x, 1))
166171

167-
c1 = sample(m1, MetropolisHastings(p1), 100; chain_type=Vector{NamedTuple})
168-
c2 = sample(m2, MetropolisHastings(p2), 100; chain_type=Vector{NamedTuple})
169-
c3 = sample(m3, MetropolisHastings(p3), 100; chain_type=Vector{NamedTuple})
170-
c4 = sample(m4, MetropolisHastings(p4), 100; chain_type=Vector{NamedTuple})
172+
c1 = sample(m1, MetropolisHastings(p1), 100; chain_type=Vector{NamedTuple}, progress=false)
173+
c2 = sample(m2, MetropolisHastings(p2), 100; chain_type=Vector{NamedTuple}, progress=false)
174+
c3 = sample(m3, MetropolisHastings(p3), 100; chain_type=Vector{NamedTuple}, progress=false)
175+
c4 = sample(m4, MetropolisHastings(p4), 100; chain_type=Vector{NamedTuple}, progress=false)
171176

172177
@test keys(c1[1]) == (:param_1, :lp)
173178
@test keys(c2[1]) == (:param_1, :param_2, :lp)
@@ -182,7 +187,7 @@ include("util.jl")
182187
val = [0.4, 1.2]
183188

184189
# Sample from the posterior.
185-
chain1 = sample(model, spl1, 10, initial_params = val)
190+
chain1 = sample(model, spl1, 10, initial_params = val, progress=false)
186191

187192
@test chain1[1].params == val
188193
end
@@ -199,12 +204,12 @@ include("util.jl")
199204
p1 = RandomWalkProposal(CustomNormal())
200205
@test p1 isa RandomWalkProposal{false}
201206
@test_throws MethodError AdvancedMH.logratio_proposal_density(p1, randn(), randn())
202-
@test_throws MethodError sample(m1, MetropolisHastings(p1), 10)
207+
@test_throws MethodError sample(m1, MetropolisHastings(p1), 10, progress=false)
203208

204209
p1 = StaticProposal(x -> CustomNormal(x))
205210
@test p1 isa StaticProposal{false}
206211
@test_throws MethodError AdvancedMH.logratio_proposal_density(p1, randn(), randn())
207-
@test_throws MethodError sample(m1, MetropolisHastings(p1), 10)
212+
@test_throws MethodError sample(m1, MetropolisHastings(p1), 10, progress=false)
208213

209214
# If the proposal is declared to be symmetric, the log ratio of the proposal
210215
# density is not evaluated.
@@ -227,7 +232,8 @@ include("util.jl")
227232
))
228233
chain1 = sample(
229234
m1, MetropolisHastings(p2), 100000;
230-
chain_type=StructArray, param_names=["x"]
235+
chain_type=StructArray, param_names=["x"],
236+
progress=false
231237
)
232238
@test mean(chain1.x) mean(d1) atol=0.05
233239
@test std(chain1.x) std(d1) atol=0.05
@@ -260,29 +266,73 @@ include("util.jl")
260266
end
261267

262268
@testset "MALA" begin
263-
# Set up the sampler.
264-
σ² = 0.01
265-
spl1 = MALA(x -> MvNormal((σ² / 2) .* x, σ² * I))
269+
@testset "basic" begin
270+
# Set up the sampler.
271+
σ² = 1e-3
272+
spl1 = MALA(x -> MvNormal((σ² / 2) .* x, σ² * I))
266273

267-
# Sample from the posterior with initial parameters.
268-
chain1 = sample(model, spl1, 100000; initial_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"])
274+
# Sample from the posterior with initial parameters.
275+
chain1 = sample(
276+
model, spl1, 1000;
277+
initial_params=ones(2),
278+
chain_type=StructArray,
279+
param_names=["μ", "σ"],
280+
discard_initial=100,
281+
progress=false
282+
)
269283

270-
@test mean(chain1.μ) 0.0 atol=0.1
271-
@test mean(chain1.σ) 1.0 atol=0.1
284+
@test mean(chain1.μ) 0.0 atol = 0.1
285+
@test mean(chain1.σ) 1.0 atol = 0.1
286+
287+
@testset "LogDensityProblems interface" begin
288+
admodel = LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), density)
289+
chain2 = sample(
290+
admodel,
291+
spl1,
292+
1000;
293+
initial_params=ones(2),
294+
chain_type=StructArray,
295+
param_names=["μ", "σ"],
296+
discard_initial=100,
297+
progress=false
298+
)
299+
300+
@test mean(chain2.μ) 0.0 atol = 0.1
301+
@test mean(chain2.σ) 1.0 atol = 0.1
302+
end
303+
end
272304

273-
@testset "LogDensityProblems interface" begin
274-
admodel = LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), density)
275-
chain2 = sample(
276-
admodel,
277-
spl1,
278-
100000;
305+
@testset "issue #95" begin
306+
struct TheNormalLogDensity{M}
307+
A::M
308+
end
309+
310+
# can do gradient
311+
LogDensityProblems.capabilities(::Type{<:TheNormalLogDensity}) = LogDensityProblems.LogDensityOrder{1}()
312+
313+
LogDensityProblems.dimension(d::TheNormalLogDensity) = size(d.A, 1)
314+
LogDensityProblems.logdensity(d::TheNormalLogDensity, x) = -x' * d.A * x / 2
315+
316+
function LogDensityProblems.logdensity_and_gradient(d::TheNormalLogDensity, x)
317+
return -x' * d.A * x / 2, -d.A * x
318+
end
319+
320+
Σ = [1.5 0.35; 0.35 1.0]
321+
σ² = 0.5
322+
spl = AdvancedMH.MALA(g -> Distributions.MvNormal((σ² / 2) .* g, σ² * I))
323+
324+
chain = sample(
325+
TheNormalLogDensity(inv(Σ)),
326+
spl,
327+
500000;
279328
initial_params=ones(2),
280-
chain_type=StructArray,
281-
param_names=["μ", "σ"]
329+
progress=false
282330
)
331+
data = mapreduce(Base.Fix2(getproperty, :params), hcat, chain)
332+
Σ_est = cov(data, dims=2)
283333

284-
@test mean(chain2.μ) 0.0 atol=0.1
285-
@test mean(chain2.σ) 1.0 atol=0.1
334+
@test mean(data, dims=2) zeros(2) atol = 0.1
335+
@test Σ Σ_est atol = 2e-1
286336
end
287337
end
288338

0 commit comments

Comments
 (0)