Skip to content

Commit bd4baf1

Browse files
committed
Fix treatment of gid in merge(::Metadata)
1 parent 1d10278 commit bd4baf1

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

src/varinfo.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,7 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
490490
ranges = Vector{UnitRange{Int}}()
491491
vals = T[]
492492
dists = D[]
493-
gids = metadata_right.gids # NOTE: giving precedence to `metadata_right`
493+
gids = Set{Selector}[]
494494
orders = Int[]
495495
flags = Dict{String,BitVector}()
496496
# Initialize the `flags`.
@@ -520,6 +520,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
520520
dist_right = getdist(metadata_right, vn)
521521
# Give precedence to `metadata_right`.
522522
push!(dists, dist_right)
523+
gid = metadata_right.gids[getidx(metadata_right, vn)]
524+
push!(gids, gid)
523525
# `orders`: giving precedence to `metadata_right`
524526
push!(orders, getorder(metadata_right, vn))
525527
# `flags`
@@ -539,6 +541,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
539541
# `dists`
540542
dist_left = getdist(metadata_left, vn)
541543
push!(dists, dist_left)
544+
gid = metadata_left.gids[getidx(metadata_left, vn)]
545+
push!(gids, gid)
542546
# `orders`
543547
push!(orders, getorder(metadata_left, vn))
544548
# `flags`
@@ -557,6 +561,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
557561
# `dists`
558562
dist_right = getdist(metadata_right, vn)
559563
push!(dists, dist_right)
564+
gid = metadata_right.gids[getidx(metadata_right, vn)]
565+
push!(gids, gid)
560566
# `orders`
561567
push!(orders, getorder(metadata_right, vn))
562568
# `flags`

test/varinfo.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -694,6 +694,19 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
694694
@test varinfo_merged[@varname(x)] == varinfo_right[@varname(x)]
695695
@test DynamicPPL.istrans(varinfo_merged, @varname(x))
696696
end
697+
698+
# The below used to error, testing to avoid regression.
699+
@testset "merge gids" begin
700+
gidset_left = Set([Selector(1)])
701+
vi_left = VarInfo()
702+
vi_left = push!!(vi_left, @varname(x), 1.0, Normal(), gidset_left)
703+
gidset_right = Set([Selector(2)])
704+
vi_right = VarInfo()
705+
vi_right = push!!(vi_right, @varname(y), 2.0, Normal(), gidset_right)
706+
varinfo_merged = merge(vi_left, vi_right)
707+
@test DynamicPPL.getgid(varinfo_merged, @varname(x)) == gidset_left
708+
@test DynamicPPL.getgid(varinfo_merged, @varname(y)) == gidset_right
709+
end
697710
end
698711

699712
@testset "VarInfo with selectors" begin

0 commit comments

Comments
 (0)