Skip to content

Commit 4650230

Browse files
committed
For VarInfo, fix merge and allow push!!ing new Symbols (#690)
* Fix treatment of gid in merge(::Metadata) * Allowing pushing new symbols to TypedVarInfo * Bump patch version to 0.30.1
1 parent 30c10c2 commit 4650230

File tree

3 files changed

+63
-17
lines changed

3 files changed

+63
-17
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.28.5"
3+
version = "0.28.6"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/varinfo.jl

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
386386
ranges = Vector{UnitRange{Int}}()
387387
vals = T[]
388388
dists = D[]
389-
gids = metadata_right.gids # NOTE: giving precedence to `metadata_right`
389+
gids = Set{Selector}[]
390390
orders = Int[]
391391
flags = Dict{String,BitVector}()
392392
# Initialize the `flags`.
@@ -416,6 +416,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
416416
dist_right = getdist(metadata_right, vn)
417417
# Give precedence to `metadata_right`.
418418
push!(dists, dist_right)
419+
gid = metadata_right.gids[getidx(metadata_right, vn)]
420+
push!(gids, gid)
419421
# `orders`: giving precedence to `metadata_right`
420422
push!(orders, getorder(metadata_right, vn))
421423
# `flags`
@@ -435,6 +437,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
435437
# `dists`
436438
dist_left = getdist(metadata_left, vn)
437439
push!(dists, dist_left)
440+
gid = metadata_left.gids[getidx(metadata_left, vn)]
441+
push!(gids, gid)
438442
# `orders`
439443
push!(orders, getorder(metadata_left, vn))
440444
# `flags`
@@ -453,6 +457,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
453457
# `dists`
454458
dist_right = getdist(metadata_right, vn)
455459
push!(dists, dist_right)
460+
gid = metadata_right.gids[getidx(metadata_right, vn)]
461+
push!(gids, gid)
456462
# `orders`
457463
push!(orders, getorder(metadata_right, vn))
458464
# `flags`
@@ -1598,25 +1604,40 @@ function BangBang.push!!(
15981604
vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector}
15991605
)
16001606
if vi isa UntypedVarInfo
1601-
@assert ~(vn in keys(vi)) "[push!!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset"
1607+
@assert ~(vn in keys(vi)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset"
16021608
elseif vi isa TypedVarInfo
1603-
@assert ~(haskey(vi, vn)) "[push!!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist, gid=$gidset"
1609+
@assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist, gid=$gidset"
16041610
end
16051611

16061612
val = vectorize(dist, r)
1607-
1608-
meta = getmetadata(vi, vn)
1609-
meta.idcs[vn] = length(meta.idcs) + 1
1610-
push!(meta.vns, vn)
1611-
l = length(meta.vals)
1612-
n = length(val)
1613-
push!(meta.ranges, (l + 1):(l + n))
1614-
append!(meta.vals, val)
1615-
push!(meta.dists, dist)
1616-
push!(meta.gids, gidset)
1617-
push!(meta.orders, get_num_produce(vi))
1618-
push!(meta.flags["del"], false)
1619-
push!(meta.flags["trans"], false)
1613+
sym = getsym(vn)
1614+
if vi isa TypedVarInfo && ~haskey(vi.metadata, sym)
1615+
# The NamedTuple doesn't have an entry for this variable, let's add one.
1616+
md = Metadata(
1617+
Dict(vn => 1),
1618+
[vn],
1619+
[1:length(val)],
1620+
val,
1621+
[dist],
1622+
[gidset],
1623+
[get_num_produce(vi)],
1624+
Dict{String,BitVector}("trans" => [false], "del" => [false]),
1625+
)
1626+
vi = Accessors.@set vi.metadata[sym] = md
1627+
else
1628+
meta = getmetadata(vi, vn)
1629+
meta.idcs[vn] = length(meta.idcs) + 1
1630+
push!(meta.vns, vn)
1631+
l = length(meta.vals)
1632+
n = length(val)
1633+
push!(meta.ranges, (l + 1):(l + n))
1634+
append!(meta.vals, val)
1635+
push!(meta.dists, dist)
1636+
push!(meta.gids, gidset)
1637+
push!(meta.orders, get_num_produce(vi))
1638+
push!(meta.flags["del"], false)
1639+
push!(meta.flags["trans"], false)
1640+
end
16201641

16211642
return vi
16221643
end

test/varinfo.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,18 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
145145
test_varinfo!(vi)
146146
test_varinfo!(empty!!(TypedVarInfo(vi)))
147147
end
148+
149+
@testset "push!! to TypedVarInfo" begin
150+
vn_x = @varname x
151+
vn_y = @varname y
152+
untyped_vi = VarInfo()
153+
untyped_vi = push!!(untyped_vi, vn_x, 1.0, Normal(0, 1), Selector())
154+
typed_vi = TypedVarInfo(untyped_vi)
155+
typed_vi = push!!(typed_vi, vn_y, 2.0, Normal(0, 1), Selector())
156+
@test typed_vi[vn_x] == 1.0
157+
@test typed_vi[vn_y] == 2.0
158+
end
159+
148160
@testset "setgid!" begin
149161
vi = VarInfo()
150162
meta = vi.metadata
@@ -645,6 +657,19 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
645657
@test varinfo_merged[@varname(x)] == varinfo_right[@varname(x)]
646658
@test DynamicPPL.getdist(varinfo_merged, @varname(x)) isa Normal
647659
end
660+
661+
# The below used to error, testing to avoid regression.
662+
@testset "merge gids" begin
663+
gidset_left = Set([Selector(1)])
664+
vi_left = VarInfo()
665+
vi_left = push!!(vi_left, @varname(x), 1.0, Normal(), gidset_left)
666+
gidset_right = Set([Selector(2)])
667+
vi_right = VarInfo()
668+
vi_right = push!!(vi_right, @varname(y), 2.0, Normal(), gidset_right)
669+
varinfo_merged = merge(vi_left, vi_right)
670+
@test DynamicPPL.getgid(varinfo_merged, @varname(x)) == gidset_left
671+
@test DynamicPPL.getgid(varinfo_merged, @varname(y)) == gidset_right
672+
end
648673
end
649674

650675
@testset "VarInfo with selectors" begin

0 commit comments

Comments
 (0)