@@ -40,9 +40,9 @@ include("util.jl")
4040 spl3 = StaticMH (2 )
4141
4242 # Sample from the posterior.
43- chain1 = sample (model, spl1, 100000 ; chain_type= StructArray, param_names= [" μ" , " σ" ])
44- chain2 = sample (model, spl2, 100000 ; chain_type= StructArray, param_names= [" μ" , " σ" ])
45- chain3 = sample (model, spl3, 100000 ; chain_type= StructArray, param_names= [" μ" , " σ" ])
43+ chain1 = sample (model, spl1, 100000 ; chain_type= StructArray, param_names= [" μ" , " σ" ], progress = false )
44+ chain2 = sample (model, spl2, 100000 ; chain_type= StructArray, param_names= [" μ" , " σ" ], progress = false )
45+ chain3 = sample (model, spl3, 100000 ; chain_type= StructArray, param_names= [" μ" , " σ" ], progress = false )
4646
4747 # chn_mean ≈ dist_mean atol=atol_v
4848 @test mean (chain1. μ) ≈ 0.0 atol= 0.1
@@ -60,9 +60,9 @@ include("util.jl")
6060 spl3 = RWMH (2 )
6161
6262 # Sample from the posterior.
63- chain1 = sample (model, spl1, 100000 ; chain_type= StructArray, param_names= [" μ" , " σ" ])
64- chain2 = sample (model, spl2, 100000 ; chain_type= StructArray, param_names= [" μ" , " σ" ])
65- chain3 = sample (model, spl3, 200000 ; chain_type= StructArray, param_names= [" μ" , " σ" ])
63+ chain1 = sample (model, spl1, 100000 ; chain_type= StructArray, param_names= [" μ" , " σ" ], progress = false )
64+ chain2 = sample (model, spl2, 100000 ; chain_type= StructArray, param_names= [" μ" , " σ" ], progress = false )
65+ chain3 = sample (model, spl3, 200000 ; chain_type= StructArray, param_names= [" μ" , " σ" ], progress = false )
6666
6767 # chn_mean ≈ dist_mean atol=atol_v
6868 @test mean (chain1. μ) ≈ 0.0 atol= 0.1
@@ -77,13 +77,13 @@ include("util.jl")
7777 spl1 = StaticMH ([Normal (0 ,1 ), Normal (0 , 1 )])
7878
7979 chain1 = sample (model, spl1, MCMCDistributed (), 10000 , 4 ;
80- param_names= [" μ" , " σ" ], chain_type= Chains)
80+ param_names= [" μ" , " σ" ], chain_type= Chains, progress = false )
8181 @test mean (chain1[" μ" ]) ≈ 0.0 atol= 0.1
8282 @test mean (chain1[" σ" ]) ≈ 1.0 atol= 0.1
8383
8484 if VERSION >= v " 1.3"
8585 chain2 = sample (model, spl1, MCMCThreads (), 10000 , 4 ;
86- param_names= [" μ" , " σ" ], chain_type= Chains)
86+ param_names= [" μ" , " σ" ], chain_type= Chains, progress = false )
8787 @test mean (chain2[" μ" ]) ≈ 0.0 atol= 0.1
8888 @test mean (chain2[" σ" ]) ≈ 1.0 atol= 0.1
8989 end
@@ -93,7 +93,7 @@ include("util.jl")
9393 # Array of parameters
9494 chain1 = sample (
9595 model, StaticMH ([Normal (0 ,1 ), Normal (0 , 1 )]), 10_000 ;
96- param_names= [" μ" , " σ" ], chain_type= Chains
96+ param_names= [" μ" , " σ" ], chain_type= Chains, progress = false
9797 )
9898 @test chain1 isa Chains
9999 @test range (chain1) == 1 : 10_000
@@ -103,6 +103,7 @@ include("util.jl")
103103 chain1b = sample (
104104 model, StaticMH ([Normal (0 ,1 ), Normal (0 , 1 )]), 10_000 ;
105105 param_names= [" μ" , " σ" ], chain_type= Chains, discard_initial= 25 , thinning= 4 ,
106+ progress= false
106107 )
107108 @test chain1b isa Chains
108109 @test range (chain1b) == range (26 ; step= 4 , length= 10_000 )
@@ -115,7 +116,8 @@ include("util.jl")
115116 MetropolisHastings (
116117 (μ = StaticProposal (Normal (0 ,1 )), σ = StaticProposal (Normal (0 , 1 )))
117118 ), 10_000 ;
118- chain_type= Chains
119+ chain_type= Chains,
120+ progress= false
119121 )
120122 @test chain2 isa Chains
121123 @test range (chain2) == 1 : 10_000
@@ -128,6 +130,7 @@ include("util.jl")
128130 (μ = StaticProposal (Normal (0 ,1 )), σ = StaticProposal (Normal (0 , 1 )))
129131 ), 10_000 ;
130132 chain_type= Chains, discard_initial= 25 , thinning= 4 ,
133+ progress= false
131134 )
132135 @test chain2b isa Chains
133136 @test range (chain2b) == range (26 ; step= 4 , length= 10_000 )
@@ -137,7 +140,8 @@ include("util.jl")
137140 # Scalar parameter
138141 chain3 = sample (
139142 DensityModel (x -> loglikelihood (Normal (x, 1 ), data)),
140- StaticMH (Normal (0 , 1 )), 10_000 ; param_names= [" μ" ], chain_type= Chains
143+ StaticMH (Normal (0 , 1 )), 10_000 ; param_names= [" μ" ], chain_type= Chains,
144+ progress= false
141145 )
142146 @test chain3 isa Chains
143147 @test range (chain3) == 1 : 10_000
@@ -147,6 +151,7 @@ include("util.jl")
147151 DensityModel (x -> loglikelihood (Normal (x, 1 ), data)),
148152 StaticMH (Normal (0 , 1 )), 10_000 ;
149153 param_names= [" μ" ], chain_type= Chains, discard_initial= 25 , thinning= 4 ,
154+ progress= false
150155 )
151156 @test chain3b isa Chains
152157 @test range (chain3b) == range (26 ; step= 4 , length= 10_000 )
@@ -164,10 +169,10 @@ include("util.jl")
164169 p3 = (a= StaticProposal (Normal (0 ,1 )), b= StaticProposal (InverseGamma (2 ,3 )))
165170 p4 = StaticProposal ((x= 1.0 ) -> Normal (x, 1 ))
166171
167- c1 = sample (m1, MetropolisHastings (p1), 100 ; chain_type= Vector{NamedTuple})
168- c2 = sample (m2, MetropolisHastings (p2), 100 ; chain_type= Vector{NamedTuple})
169- c3 = sample (m3, MetropolisHastings (p3), 100 ; chain_type= Vector{NamedTuple})
170- c4 = sample (m4, MetropolisHastings (p4), 100 ; chain_type= Vector{NamedTuple})
172+ c1 = sample (m1, MetropolisHastings (p1), 100 ; chain_type= Vector{NamedTuple}, progress = false )
173+ c2 = sample (m2, MetropolisHastings (p2), 100 ; chain_type= Vector{NamedTuple}, progress = false )
174+ c3 = sample (m3, MetropolisHastings (p3), 100 ; chain_type= Vector{NamedTuple}, progress = false )
175+ c4 = sample (m4, MetropolisHastings (p4), 100 ; chain_type= Vector{NamedTuple}, progress = false )
171176
172177 @test keys (c1[1 ]) == (:param_1 , :lp )
173178 @test keys (c2[1 ]) == (:param_1 , :param_2 , :lp )
@@ -182,7 +187,7 @@ include("util.jl")
182187 val = [0.4 , 1.2 ]
183188
184189 # Sample from the posterior.
185- chain1 = sample (model, spl1, 10 , initial_params = val)
190+ chain1 = sample (model, spl1, 10 , initial_params = val, progress = false )
186191
187192 @test chain1[1 ]. params == val
188193 end
@@ -199,12 +204,12 @@ include("util.jl")
199204 p1 = RandomWalkProposal (CustomNormal ())
200205 @test p1 isa RandomWalkProposal{false }
201206 @test_throws MethodError AdvancedMH. logratio_proposal_density (p1, randn (), randn ())
202- @test_throws MethodError sample (m1, MetropolisHastings (p1), 10 )
207+ @test_throws MethodError sample (m1, MetropolisHastings (p1), 10 , progress = false )
203208
204209 p1 = StaticProposal (x -> CustomNormal (x))
205210 @test p1 isa StaticProposal{false }
206211 @test_throws MethodError AdvancedMH. logratio_proposal_density (p1, randn (), randn ())
207- @test_throws MethodError sample (m1, MetropolisHastings (p1), 10 )
212+ @test_throws MethodError sample (m1, MetropolisHastings (p1), 10 , progress = false )
208213
209214 # If the proposal is declared to be symmetric, the log ratio of the proposal
210215 # density is not evaluated.
@@ -227,7 +232,8 @@ include("util.jl")
227232 ))
228233 chain1 = sample (
229234 m1, MetropolisHastings (p2), 100000 ;
230- chain_type= StructArray, param_names= [" x" ]
235+ chain_type= StructArray, param_names= [" x" ],
236+ progress= false
231237 )
232238 @test mean (chain1. x) ≈ mean (d1) atol= 0.05
233239 @test std (chain1. x) ≈ std (d1) atol= 0.05
@@ -260,29 +266,73 @@ include("util.jl")
260266 end
261267
262268 @testset " MALA" begin
263- # Set up the sampler.
264- σ² = 0.01
265- spl1 = MALA (x -> MvNormal ((σ² / 2 ) .* x, σ² * I))
269+ @testset " basic" begin
270+ # Set up the sampler.
271+ σ² = 1e-3
272+ spl1 = MALA (x -> MvNormal ((σ² / 2 ) .* x, σ² * I))
266273
267- # Sample from the posterior with initial parameters.
268- chain1 = sample (model, spl1, 100000 ; initial_params= ones (2 ), chain_type= StructArray, param_names= [" μ" , " σ" ])
274+ # Sample from the posterior with initial parameters.
275+ chain1 = sample (
276+ model, spl1, 1000 ;
277+ initial_params= ones (2 ),
278+ chain_type= StructArray,
279+ param_names= [" μ" , " σ" ],
280+ discard_initial= 100 ,
281+ progress= false
282+ )
269283
270- @test mean (chain1. μ) ≈ 0.0 atol= 0.1
271- @test mean (chain1. σ) ≈ 1.0 atol= 0.1
284+ @test mean (chain1. μ) ≈ 0.0 atol = 0.1
285+ @test mean (chain1. σ) ≈ 1.0 atol = 0.1
286+
287+ @testset " LogDensityProblems interface" begin
288+ admodel = LogDensityProblemsAD. ADgradient (Val (:ForwardDiff ), density)
289+ chain2 = sample (
290+ admodel,
291+ spl1,
292+ 1000 ;
293+ initial_params= ones (2 ),
294+ chain_type= StructArray,
295+ param_names= [" μ" , " σ" ],
296+ discard_initial= 100 ,
297+ progress= false
298+ )
299+
300+ @test mean (chain2. μ) ≈ 0.0 atol = 0.1
301+ @test mean (chain2. σ) ≈ 1.0 atol = 0.1
302+ end
303+ end
272304
273- @testset " LogDensityProblems interface" begin
274- admodel = LogDensityProblemsAD. ADgradient (Val (:ForwardDiff ), density)
275- chain2 = sample (
276- admodel,
277- spl1,
278- 100000 ;
305+ @testset " issue #95" begin
306+ struct TheNormalLogDensity{M}
307+ A:: M
308+ end
309+
310+ # can do gradient
311+ LogDensityProblems. capabilities (:: Type{<:TheNormalLogDensity} ) = LogDensityProblems. LogDensityOrder {1} ()
312+
313+ LogDensityProblems. dimension (d:: TheNormalLogDensity ) = size (d. A, 1 )
314+ LogDensityProblems. logdensity (d:: TheNormalLogDensity , x) = - x' * d. A * x / 2
315+
316+ function LogDensityProblems. logdensity_and_gradient (d:: TheNormalLogDensity , x)
317+ return - x' * d. A * x / 2 , - d. A * x
318+ end
319+
320+ Σ = [1.5 0.35 ; 0.35 1.0 ]
321+ σ² = 0.5
322+ spl = AdvancedMH. MALA (g -> Distributions. MvNormal ((σ² / 2 ) .* g, σ² * I))
323+
324+ chain = sample (
325+ TheNormalLogDensity (inv (Σ)),
326+ spl,
327+ 500000 ;
279328 initial_params= ones (2 ),
280- chain_type= StructArray,
281- param_names= [" μ" , " σ" ]
329+ progress= false
282330 )
331+ data = mapreduce (Base. Fix2 (getproperty, :params ), hcat, chain)
332+ Σ_est = cov (data, dims= 2 )
283333
284- @test mean (chain2 . μ ) ≈ 0.0 atol= 0.1
285- @test mean (chain2 . σ) ≈ 1.0 atol= 0. 1
334+ @test mean (data, dims = 2 ) ≈ zeros ( 2 ) atol = 0.1
335+ @test Σ ≈ Σ_est atol = 2e- 1
286336 end
287337 end
288338
0 commit comments