diff --git a/Project.toml b/Project.toml index 252e2e2da..5e495f92e 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "Chain types and utility functions for MCMC simulations." -version = "7.2.2" +version = "7.3.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/MCMCChains.jl b/src/MCMCChains.jl index 0592a5014..141cade9b 100644 --- a/src/MCMCChains.jl +++ b/src/MCMCChains.jl @@ -3,7 +3,7 @@ module MCMCChains using AxisArrays const axes = Base.axes import AbstractMCMC -import AbstractMCMC: chainscat +import AbstractMCMC: chainscat, chainsstack using Distributions using RecipesBase using Dates @@ -35,7 +35,7 @@ import LinearAlgebra import Random import Statistics: std, cor, mean, var, mean! -export Chains, chains, chainscat +export Chains, chains, chainscat, chainsstack export setrange, resetrange export set_section, get_params, sections, sort_sections, setinfo export replacenames, namesingroup, group diff --git a/src/chains.jl b/src/chains.jl index ee01675bf..39eeb7186 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -749,6 +749,7 @@ Base.hcat(c::Chains, cs::Chains...) = _cat(Val(2), c, cs...) Base.hcat(c::T, cs::T...) where {T<:Chains} = _cat(Val(2), c, cs...) AbstractMCMC.chainscat(c::Chains, cs::Chains...) = _cat(Val(3), c, cs...) +AbstractMCMC.chainsstack(c::AbstractVector{<:Chains}) = AbstractMCMC.chainscat(c...) _cat(dim::Int, cs::Chains...) = _cat(Val(dim), cs...) @@ -822,13 +823,13 @@ function _cat(::Val{3}, c1::Chains, args::Chains...) c -> get(c.info, :start_time, nothing), vcat, args, - init = get(c1.info, :start_time, nothing), + init = [get(c1.info, :start_time, nothing)], ) stops = mapreduce( c -> get(c.info, :stop_time, nothing), vcat, args, - init = get(c1.info, :stop_time, nothing), + init = [get(c1.info, :stop_time, nothing)], ) # Concatenate sampler states too. This is hacky(!) but required upstream in Turing.jl # because otherwise you cannot resume multiple-chain sampling. @@ -836,7 +837,7 @@ function _cat(::Val{3}, c1::Chains, args::Chains...) c -> get(c.info, :samplerstate, nothing), vcat, args, - init = get(c1.info, :samplerstate, nothing), + init = [get(c1.info, :samplerstate, nothing)], ) other_props = filter( x -> !(x in [:start_time, :stop_time, :samplerstate]), diff --git a/test/concatenation_tests.jl b/test/concatenation_tests.jl index 2b2b6772f..0fe3166b7 100644 --- a/test/concatenation_tests.jl +++ b/test/concatenation_tests.jl @@ -174,7 +174,7 @@ end @test chn2b.name_map == chn2.name_map @test chn2b.info == chn2.info - # check merging of info field + # check merging of info field for multiple-chain concatenation chn = Chains( rand(10, 3, 1), ["a", "b", "c"], @@ -202,4 +202,24 @@ end @test chn3.info.samplerstate == ["state1", "state2"] # other fields should just be taken from the first chain @test chn3.info.otherinfo == "info1" + + # for single-chain concatenation too + chn = Chains( + rand(10, 3, 1), + ["a", "b", "c"], + info = ( + start_time = 1, + stop_time = 2, + samplerstate = "state1", + otherinfo = "info1", + ), + ) + for new_chn in [chainscat(chn), chainsstack([chn])] + @test new_chn.value == chn.value + @test new_chn.name_map == chn.name_map + @test new_chn.info.start_time == [chn.info.start_time] + @test new_chn.info.stop_time == [chn.info.stop_time] + @test new_chn.info.samplerstate == [chn.info.samplerstate] + @test new_chn.info.otherinfo == chn.info.otherinfo + end end