Skip to content

Commit 0fd7b9e

Browse files
committed
Fix Chains output for scalar parameters
1 parent 1911b9d commit 0fd7b9e

File tree

3 files changed

+27
-12
lines changed

3 files changed

+27
-12
lines changed

src/mcmcchains-connect.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import .MCMCChains: Chains
22

33
# A basic chains constructor that works with the Transition struct we defined.
44
function AbstractMCMC.bundle_samples(
5-
ts::Vector{<:Transition{<:AbstractArray}},
5+
ts::Vector{<:AbstractTransition},
66
model::DensityModel,
77
sampler::MHSampler,
88
state,
@@ -63,7 +63,7 @@ function AbstractMCMC.bundle_samples(
6363
end
6464

6565
function AbstractMCMC.bundle_samples(
66-
ts::Vector{<:Vector{<:Transition}},
66+
ts::Vector{<:Vector{<:AbstractTransition}},
6767
model::DensityModel,
6868
sampler::Ensemble,
6969
state,

src/structarray-connect.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import .StructArrays: StructArray
22

33
# A basic chains constructor that works with the Transition struct we defined.
44
function AbstractMCMC.bundle_samples(
5-
ts,
5+
ts::Vector{<:AbstractTransition},
66
model::DensityModel,
77
sampler::MHSampler,
88
state,

test/runtests.jl

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -74,17 +74,34 @@ include("util.jl")
7474
end
7575

7676
@testset "MCMCChains" begin
77-
spl1 = StaticMH([Normal(0,1), Normal(0, 1)])
78-
spl2 = MetropolisHastings((μ = StaticProposal(Normal(0,1)), σ = StaticProposal(Normal(0, 1))))
79-
80-
chain1 = sample(model, spl1, 10_000; param_names=["μ", "σ"], chain_type=Chains)
81-
chain2 = sample(model, spl2, 10_000; chain_type=Chains)
82-
77+
# Array of parameters
78+
chain1 = sample(
79+
model, StaticMH([Normal(0,1), Normal(0, 1)]), 10_000;
80+
param_names=["μ", "σ"], chain_type=Chains
81+
)
82+
@test chain1 isa Chains
8383
@test mean(chain1["μ"]) 0.0 atol=0.1
8484
@test mean(chain1["σ"]) 1.0 atol=0.1
8585

86+
# NamedTuple of parameters
87+
chain2 = sample(
88+
model,
89+
MetropolisHastings(
90+
= StaticProposal(Normal(0,1)), σ = StaticProposal(Normal(0, 1)))
91+
), 10_000;
92+
chain_type=Chains
93+
)
94+
@test chain2 isa Chains
8695
@test mean(chain2["μ"]) 0.0 atol=0.1
8796
@test mean(chain2["σ"]) 1.0 atol=0.1
97+
98+
# Scalar parameter
99+
chain3 = sample(
100+
DensityModel(x -> loglikelihood(Normal(x, 1), data)),
101+
StaticMH(Normal(0, 1)), 10_000; param_names=["μ"], chain_type=Chains
102+
)
103+
@test chain3 isa Chains
104+
@test mean(chain3["μ"]) 0.0 atol=0.1
88105
end
89106

90107
@testset "Proposal styles" begin
@@ -194,7 +211,6 @@ include("util.jl")
194211
end
195212

196213
@testset "MALA" begin
197-
198214
# Set up the sampler.
199215
sigma = 1e-1
200216
spl1 = MALA(x -> MvNormal((sigma^2 / 2) .* x, sigma))
@@ -203,9 +219,8 @@ include("util.jl")
203219
chain1 = sample(model, spl1, 100000; init_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"])
204220

205221
@test mean(chain1.μ) 0.0 atol=0.1
206-
@test mean(chain1.σ) 1.0 atol=0.1
222+
@test mean(chain1.σ) 1.0 atol=0.1
207223
end
208224

209225
@testset "EMCEE" begin include("emcee.jl") end
210-
211226
end

0 commit comments

Comments
 (0)