Skip to content

Commit 6909f74

Browse files
authored
Make chainsstack([c]) put info fields into vectors (#492)
* Fix chainscat for single-chain * Bump patch * Bump minor instead, export chainsstack, add tests
1 parent 1a03a99 commit 6909f74

File tree

4 files changed

+28
-7
lines changed

4 files changed

+28
-7
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
33
keywords = ["markov chain monte carlo", "probablistic programming"]
44
license = "MIT"
55
desc = "Chain types and utility functions for MCMC simulations."
6-
version = "7.2.2"
6+
version = "7.3.0"
77

88
[deps]
99
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/MCMCChains.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module MCMCChains
33
using AxisArrays
44
const axes = Base.axes
55
import AbstractMCMC
6-
import AbstractMCMC: chainscat
6+
import AbstractMCMC: chainscat, chainsstack
77
using Distributions
88
using RecipesBase
99
using Dates
@@ -35,7 +35,7 @@ import LinearAlgebra
3535
import Random
3636
import Statistics: std, cor, mean, var, mean!
3737

38-
export Chains, chains, chainscat
38+
export Chains, chains, chainscat, chainsstack
3939
export setrange, resetrange
4040
export set_section, get_params, sections, sort_sections, setinfo
4141
export replacenames, namesingroup, group

src/chains.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -749,6 +749,7 @@ Base.hcat(c::Chains, cs::Chains...) = _cat(Val(2), c, cs...)
749749
Base.hcat(c::T, cs::T...) where {T<:Chains} = _cat(Val(2), c, cs...)
750750

751751
AbstractMCMC.chainscat(c::Chains, cs::Chains...) = _cat(Val(3), c, cs...)
752+
AbstractMCMC.chainsstack(c::AbstractVector{<:Chains}) = AbstractMCMC.chainscat(c...)
752753

753754
_cat(dim::Int, cs::Chains...) = _cat(Val(dim), cs...)
754755

@@ -822,21 +823,21 @@ function _cat(::Val{3}, c1::Chains, args::Chains...)
822823
c -> get(c.info, :start_time, nothing),
823824
vcat,
824825
args,
825-
init = get(c1.info, :start_time, nothing),
826+
init = [get(c1.info, :start_time, nothing)],
826827
)
827828
stops = mapreduce(
828829
c -> get(c.info, :stop_time, nothing),
829830
vcat,
830831
args,
831-
init = get(c1.info, :stop_time, nothing),
832+
init = [get(c1.info, :stop_time, nothing)],
832833
)
833834
# Concatenate sampler states too. This is hacky(!) but required upstream in Turing.jl
834835
# because otherwise you cannot resume multiple-chain sampling.
835836
spl_states = mapreduce(
836837
c -> get(c.info, :samplerstate, nothing),
837838
vcat,
838839
args,
839-
init = get(c1.info, :samplerstate, nothing),
840+
init = [get(c1.info, :samplerstate, nothing)],
840841
)
841842
other_props = filter(
842843
x -> !(x in [:start_time, :stop_time, :samplerstate]),

test/concatenation_tests.jl

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ end
174174
@test chn2b.name_map == chn2.name_map
175175
@test chn2b.info == chn2.info
176176

177-
# check merging of info field
177+
# check merging of info field for multiple-chain concatenation
178178
chn = Chains(
179179
rand(10, 3, 1),
180180
["a", "b", "c"],
@@ -202,4 +202,24 @@ end
202202
@test chn3.info.samplerstate == ["state1", "state2"]
203203
# other fields should just be taken from the first chain
204204
@test chn3.info.otherinfo == "info1"
205+
206+
# for single-chain concatenation too
207+
chn = Chains(
208+
rand(10, 3, 1),
209+
["a", "b", "c"],
210+
info = (
211+
start_time = 1,
212+
stop_time = 2,
213+
samplerstate = "state1",
214+
otherinfo = "info1",
215+
),
216+
)
217+
for new_chn in [chainscat(chn), chainsstack([chn])]
218+
@test new_chn.value == chn.value
219+
@test new_chn.name_map == chn.name_map
220+
@test new_chn.info.start_time == [chn.info.start_time]
221+
@test new_chn.info.stop_time == [chn.info.stop_time]
222+
@test new_chn.info.samplerstate == [chn.info.samplerstate]
223+
@test new_chn.info.otherinfo == chn.info.otherinfo
224+
end
205225
end

0 commit comments

Comments
 (0)