Skip to content

Commit dfa6338

Browse files
committed
Allowing pushing new symbols to TypedVarInfo
1 parent df41420 commit dfa6338

File tree

2 files changed

+33
-5
lines changed

2 files changed

+33
-5
lines changed

src/varinfo.jl

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1624,13 +1624,30 @@ function BangBang.push!!(
16241624
vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector}
16251625
)
16261626
if vi isa UntypedVarInfo
1627-
@assert ~(vn in keys(vi)) "[push!!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset"
1627+
@assert ~(vn in keys(vi)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset"
16281628
elseif vi isa TypedVarInfo
1629-
@assert ~(haskey(vi, vn)) "[push!!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist, gid=$gidset"
1629+
@assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist, gid=$gidset"
16301630
end
16311631

1632-
meta = getmetadata(vi, vn)
1633-
push!(meta, vn, r, dist, gidset, get_num_produce(vi))
1632+
sym = getsym(vn)
1633+
if vi isa TypedVarInfo && ~haskey(vi.metadata, sym)
1634+
# The NamedTuple doesn't have an entry for this variable, let's add one.
1635+
val = tovec(r)
1636+
md = Metadata(
1637+
Dict(vn => 1),
1638+
[vn],
1639+
[1:length(val)],
1640+
val,
1641+
[dist],
1642+
[gidset],
1643+
[get_num_produce(vi)],
1644+
Dict{String,BitVector}("trans" => [false], "del" => [false]),
1645+
)
1646+
vi = Accessors.@set vi.metadata[sym] = md
1647+
else
1648+
meta = getmetadata(vi, vn)
1649+
push!(meta, vn, r, dist, gidset, get_num_produce(vi))
1650+
end
16341651

16351652
return vi
16361653
end
@@ -1648,7 +1665,6 @@ function Base.push!(meta::Metadata, vn, r, dist, gidset, num_produce)
16481665
push!(meta.orders, num_produce)
16491666
push!(meta.flags["del"], false)
16501667
push!(meta.flags["trans"], false)
1651-
16521668
return meta
16531669
end
16541670

test/varinfo.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,18 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
145145
test_varinfo!(vi)
146146
test_varinfo!(empty!!(TypedVarInfo(vi)))
147147
end
148+
149+
@testset "push!! to TypedVarInfo" begin
150+
vn_x = @varname x
151+
vn_y = @varname y
152+
untyped_vi = VarInfo()
153+
untyped_vi = push!!(untyped_vi, vn_x, 1.0, Normal(0, 1), Selector())
154+
typed_vi = TypedVarInfo(untyped_vi)
155+
typed_vi = push!!(typed_vi, vn_y, 2.0, Normal(0, 1), Selector())
156+
@test typed_vi[vn_x] == 1.0
157+
@test typed_vi[vn_y] == 2.0
158+
end
159+
148160
@testset "setgid!" begin
149161
vi = VarInfo()
150162
meta = vi.metadata

0 commit comments

Comments
 (0)