Skip to content

Commit 10766d4

Browse files
authored
Merge pull request #31 from devmotion/mcmcchains
Update for MCMCChains 4
2 parents ba894e5 + a99fdac commit 10766d4

File tree

3 files changed

+20
-20
lines changed

3 files changed

+20
-20
lines changed

src/mcmcchains-connect.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,21 @@ function AbstractMCMC.bundle_samples(
1212
kwargs...
1313
)
1414
# Turn all the transitions into a vector-of-vectors.
15-
vals = copy(reduce(hcat,[vcat(t.params, t.lp) for t in ts])')
15+
vals = [vcat(t.params, t.lp) for t in ts]
1616

1717
# Check if we received any parameter names.
1818
if ismissing(param_names)
19-
param_names = ["param_$i" for i in 1:length(s.init_params)]
19+
param_names = [Symbol(:param_, i) for i in 1:length(s.init_params)]
2020
else
21-
# Deepcopy to be thread safe.
22-
param_names = deepcopy(param_names)
21+
# Generate new array to be thread safe.
22+
param_names = Symbol.(param_names)
2323
end
2424

2525
# Add the log density field to the parameter names.
26-
push!(param_names, "lp")
26+
push!(param_names, :lp)
2727

2828
# Bundle everything up and return a Chains struct.
29-
return Chains(vals, param_names, (internals=["lp"],))
29+
return Chains(vals, param_names, (internals = [:lp],))
3030
end
3131

3232
function AbstractMCMC.bundle_samples(
@@ -56,15 +56,15 @@ function AbstractMCMC.bundle_samples(
5656

5757
# Check if we received any parameter names.
5858
if ismissing(param_names)
59-
param_names = ["param_$i" for i in 1:length(ts[1][1].params)]
59+
param_names = [Symbol(:param_, i) for i in 1:length(ts[1][1].params)]
6060
else
61-
# Deepcopy to be thread safe.
62-
param_names = deepcopy(param_names)
61+
# Generate new array to be thread safe.
62+
param_names = Symbol.(param_names)
6363
end
6464

6565
# Add the log density field to the parameter names.
66-
push!(param_names, "lp")
66+
push!(param_names, :lp)
6767

6868
# Bundle everything up and return a Chains struct.
69-
return Chains(vals, param_names, (internals=["lp"],))
70-
end
69+
return Chains(vals, param_names, (internals=[:lp],))
70+
end

test/emcee.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
chain = sample(model, sampler, 1_000;
2121
param_names = ["s", "m"], chain_type = Chains)
2222

23-
@test mean(chain["s"].value) 49/24 atol=0.1
24-
@test mean(chain["m"].value) 7/6 atol=0.1
23+
@test mean(chain["s"]) 49/24 atol=0.1
24+
@test mean(chain["m"]) 7/6 atol=0.1
2525
end
2626

2727
@testset "transformed space" begin
@@ -45,8 +45,8 @@
4545
chain = sample(model, sampler, 1_000;
4646
param_names = ["logs", "m"], chain_type = Chains)
4747

48-
@test mean(exp.(chain["logs"].value)) 49/24 atol=0.1
49-
@test mean(chain["m"].value) 7/6 atol=0.1
48+
@test mean(exp, chain["logs"]) 49/24 atol=0.1
49+
@test mean(chain["m"]) 7/6 atol=0.1
5050
end
5151
end
5252
end

test/runtests.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,14 @@ using Test
5858

5959
chain1 = sample(model, spl1, MCMCDistributed(), 10000, 4;
6060
param_names=["μ", "σ"], chain_type=Chains)
61-
@test mean(chain1["μ"].value) 0.0 atol=0.1
62-
@test mean(chain1["σ"].value) 1.0 atol=0.1
61+
@test mean(chain1["μ"]) 0.0 atol=0.1
62+
@test mean(chain1["σ"]) 1.0 atol=0.1
6363

6464
if VERSION >= v"1.3"
6565
chain2 = sample(model, spl1, MCMCThreads(), 10000, 4;
6666
param_names=["μ", "σ"], chain_type=Chains)
67-
@test mean(chain2["μ"].value) 0.0 atol=0.1
68-
@test mean(chain2["σ"].value) 1.0 atol=0.1
67+
@test mean(chain2["μ"]) 0.0 atol=0.1
68+
@test mean(chain2["σ"]) 1.0 atol=0.1
6969
end
7070
end
7171

0 commit comments

Comments
 (0)