Skip to content

Commit df41420

Browse files
committed
Fix treatment of gid in merge(::Metadata)
1 parent 24a7380 commit df41420

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

src/varinfo.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
392392
ranges = Vector{UnitRange{Int}}()
393393
vals = T[]
394394
dists = D[]
395-
gids = metadata_right.gids # NOTE: giving precedence to `metadata_right`
395+
gids = Set{Selector}[]
396396
orders = Int[]
397397
flags = Dict{String,BitVector}()
398398
# Initialize the `flags`.
@@ -422,6 +422,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
422422
dist_right = getdist(metadata_right, vn)
423423
# Give precedence to `metadata_right`.
424424
push!(dists, dist_right)
425+
gid = metadata_right.gids[getidx(metadata_right, vn)]
426+
push!(gids, gid)
425427
# `orders`: giving precedence to `metadata_right`
426428
push!(orders, getorder(metadata_right, vn))
427429
# `flags`
@@ -441,6 +443,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
441443
# `dists`
442444
dist_left = getdist(metadata_left, vn)
443445
push!(dists, dist_left)
446+
gid = metadata_left.gids[getidx(metadata_left, vn)]
447+
push!(gids, gid)
444448
# `orders`
445449
push!(orders, getorder(metadata_left, vn))
446450
# `flags`
@@ -459,6 +463,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
459463
# `dists`
460464
dist_right = getdist(metadata_right, vn)
461465
push!(dists, dist_right)
466+
gid = metadata_right.gids[getidx(metadata_right, vn)]
467+
push!(gids, gid)
462468
# `orders`
463469
push!(orders, getorder(metadata_right, vn))
464470
# `flags`

test/varinfo.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -645,6 +645,19 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
645645
@test varinfo_merged[@varname(x)] == varinfo_right[@varname(x)]
646646
@test DynamicPPL.getdist(varinfo_merged, @varname(x)) isa Normal
647647
end
648+
649+
# The below used to error, testing to avoid regression.
650+
@testset "merge gids" begin
651+
gidset_left = Set([Selector(1)])
652+
vi_left = VarInfo()
653+
vi_left = push!!(vi_left, @varname(x), 1.0, Normal(), gidset_left)
654+
gidset_right = Set([Selector(2)])
655+
vi_right = VarInfo()
656+
vi_right = push!!(vi_right, @varname(y), 2.0, Normal(), gidset_right)
657+
varinfo_merged = merge(vi_left, vi_right)
658+
@test DynamicPPL.getgid(varinfo_merged, @varname(x)) == gidset_left
659+
@test DynamicPPL.getgid(varinfo_merged, @varname(y)) == gidset_right
660+
end
648661
end
649662

650663
@testset "VarInfo with selectors" begin

0 commit comments

Comments
 (0)