Skip to content

Commit de2d3ec

Browse files
authored
concatenate sampler states from chains (#488)
* concatenate sampler states from chains * Format * Add tests * Bump patch (this _is_ a bugfix, right...)
1 parent a45c09e commit de2d3ec

File tree

3 files changed

+61
-15
lines changed

3 files changed

+61
-15
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
33
keywords = ["markov chain monte carlo", "probablistic programming"]
44
license = "MIT"
55
desc = "Chain types and utility functions for MCMC simulations."
6-
version = "7.2.0"
6+
version = "7.2.1"
77

88
[deps]
99
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/chains.jl

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -830,11 +830,22 @@ function _cat(::Val{3}, c1::Chains, args::Chains...)
830830
args,
831831
init = get(c1.info, :stop_time, nothing),
832832
)
833-
nontime_props =
834-
filter(x -> !(x in [:start_time, :stop_time]), [propertynames(c1.info)...])
833+
# Concatenate sampler states too. This is hacky(!) but required upstream in Turing.jl
834+
# because otherwise you cannot resume multiple-chain sampling.
835+
spl_states = mapreduce(
836+
c -> get(c.info, :samplerstate, nothing),
837+
vcat,
838+
args,
839+
init = get(c1.info, :samplerstate, nothing),
840+
)
841+
other_props = filter(
842+
x -> !(x in [:start_time, :stop_time, :samplerstate]),
843+
[propertynames(c1.info)...],
844+
)
845+
new_info =
846+
NamedTuple{tuple(other_props...)}(tuple([c1.info[n] for n in other_props]...))
835847
new_info =
836-
NamedTuple{tuple(nontime_props...)}(tuple([c1.info[n] for n in nontime_props]...))
837-
new_info = merge(new_info, (start_time = starts, stop_time = stops))
848+
merge(new_info, (start_time = starts, stop_time = stops, samplerstate = spl_states))
838849

839850
return Chains(value, missing, c1.name_map, new_info)
840851
end

test/concatenation_tests.jl

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using Test
33

44
@testset "merge_union" begin
55
@test @inferred(MCMCChains.merge_union((a = [1], b = [3.0]), (c = [3], a = [2.5]))) ==
6-
(a = [1.0, 2.5], b = [3.0], c = [3])
6+
(a = [1.0, 2.5], b = [3.0], c = [3])
77
end
88

99
@testset "concatenation tests" begin
@@ -19,23 +19,23 @@ end
1919

2020
# Test dim 1
2121
c1_2 = cat(c1, c2; dims = 1)
22-
@test c1_2.value.data == cat(v1, v2, dims=1)
22+
@test c1_2.value.data == cat(v1, v2, dims = 1)
2323
@test range(c1_2) == 1:1:1000
2424
@test names(c1_2) == names(c1) == names(c2)
2525
@test chains(c1_2) == chains(c1) == chains(c2)
2626
@test c1_2.value == vcat(c1, c2).value
2727

2828
# Test dim 2
2929
c1_3 = cat(c1, c3; dims = 2)
30-
@test c1_3.value.data == cat(v1, v3, dims=2)
30+
@test c1_3.value.data == cat(v1, v3, dims = 2)
3131
@test range(c1_3) == 1:1:500
32-
@test names(c1_3) == cat(names(c1), names(c3), dims=1)
32+
@test names(c1_3) == cat(names(c1), names(c3), dims = 1)
3333
@test chains(c1_3) == chains(c1) == chains(c3)
3434
@test c1_3.value == hcat(c1, c3).value
3535

3636
# Test dim 3
3737
c1_4 = cat(c1, c4; dims = 3)
38-
@test c1_4.value.data == cat(v1, v4, dims=3)
38+
@test c1_4.value.data == cat(v1, v4, dims = 3)
3939
@test range(c1_4) == 1:1:500
4040
@test names(c1_4) == names(c1) == names(c4)
4141
@test length(chains(c1_4)) == length(chains(c1)) + length(chains(c4))
@@ -50,10 +50,16 @@ end
5050
@test_throws ArgumentError vcat(chn, Chains(rand(2, 5, 2)))
5151

5252
# incorrect names
53-
@test_throws ArgumentError vcat(chn, Chains(rand(10, 5, 2), ["a", "b", "c", "d", "f"]; start=11))
53+
@test_throws ArgumentError vcat(
54+
chn,
55+
Chains(rand(10, 5, 2), ["a", "b", "c", "d", "f"]; start = 11),
56+
)
5457

5558
# incorrect number of chains
56-
@test_throws ArgumentError vcat(chn, Chains(rand(10, 5, 3), ["a", "b", "c", "d", "e"]; start=11))
59+
@test_throws ArgumentError vcat(
60+
chn,
61+
Chains(rand(10, 5, 3), ["a", "b", "c", "d", "e"]; start = 11),
62+
)
5763

5864
# concate the same chain
5965
chn_shifted = setrange(chn, 11:20)
@@ -126,7 +132,7 @@ end
126132
@test names(chn2) == vcat(names(chn), names(chn1))
127133
@test range(chn2) == 1:10
128134
@test chn2.name_map == (parameters = [:a, :b, :e], internal = [:c, :d])
129-
135+
130136
chn2a = cat(chn, chn1; dims = Val(2))
131137
@test chn2a.value == chn2.value
132138
@test chn2a.name_map == chn2.name_map
@@ -157,7 +163,7 @@ end
157163
@test range(chn2) == 1:10
158164
# just keep the name map of the first argument
159165
@test chn2.name_map == (parameters = [:a, :b], internal = [:c])
160-
166+
161167
chn2a = cat(chn, chn1; dims = Val(3))
162168
@test chn2a.value == chn2.value
163169
@test chn2a.name_map == chn2.name_map
@@ -167,4 +173,33 @@ end
167173
@test chn2b.value == chn2.value
168174
@test chn2b.name_map == chn2.name_map
169175
@test chn2b.info == chn2.info
170-
end
176+
177+
# check merging of info field
178+
chn = Chains(
179+
rand(10, 3, 1),
180+
["a", "b", "c"],
181+
info = (
182+
start_time = 1,
183+
stop_time = 2,
184+
samplerstate = "state1",
185+
otherinfo = "info1",
186+
),
187+
)
188+
chn1 = Chains(
189+
rand(10, 3, 1),
190+
["a", "b", "c"],
191+
info = (
192+
start_time = 3,
193+
stop_time = 4,
194+
samplerstate = "state2",
195+
otherinfo = "info2",
196+
),
197+
)
198+
chn3 = chainscat(chn, chn1)
199+
# these three fields should be concatenated
200+
@test chn3.info.start_time == [1, 3]
201+
@test chn3.info.stop_time == [2, 4]
202+
@test chn3.info.samplerstate == ["state1", "state2"]
203+
# other fields should just be taken from the first chain
204+
@test chn3.info.otherinfo == "info1"
205+
end

0 commit comments

Comments
 (0)