Skip to content

Commit d804ef1

Browse files
committed
Allowing pushing new symbols to TypedVarInfo
1 parent bd4baf1 commit d804ef1

File tree

2 files changed

+34
-6
lines changed

2 files changed

+34
-6
lines changed

src/varinfo.jl

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1832,14 +1832,31 @@ function BangBang.push!!(
18321832
vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector}
18331833
)
18341834
if vi isa UntypedVarInfo
1835-
@assert ~(vn in keys(vi)) "[push!!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset"
1835+
@assert ~(vn in keys(vi)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset"
18361836
elseif vi isa TypedVarInfo
1837-
@assert ~(haskey(vi, vn)) "[push!!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist, gid=$gidset"
1837+
@assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist, gid=$gidset"
1838+
end
1839+
1840+
sym = getsym(vn)
1841+
if vi isa TypedVarInfo && ~haskey(vi.metadata, sym)
1842+
# The NamedTuple doesn't have an entry for this variable, let's add one.
1843+
val = tovec(r)
1844+
md = Metadata(
1845+
Dict(vn => 1),
1846+
[vn],
1847+
[1:length(val)],
1848+
val,
1849+
[dist],
1850+
[gidset],
1851+
[get_num_produce(vi)],
1852+
Dict{String,BitVector}("trans" => [false], "del" => [false]),
1853+
)
1854+
vi = Accessors.@set vi.metadata[sym] = md
1855+
else
1856+
meta = getmetadata(vi, vn)
1857+
push!(meta, vn, r, dist, gidset, get_num_produce(vi))
18381858
end
18391859

1840-
meta = getmetadata(vi, vn)
1841-
push!(meta, vn, r, dist, gidset, get_num_produce(vi))
1842-
18431860
return vi
18441861
end
18451862

@@ -1870,7 +1887,6 @@ function Base.push!(meta::Metadata, vn, r, dist, gidset, num_produce)
18701887
push!(meta.orders, num_produce)
18711888
push!(meta.flags["del"], false)
18721889
push!(meta.flags["trans"], false)
1873-
18741890
return meta
18751891
end
18761892

test/varinfo.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,18 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
154154
test_varinfo!(vi)
155155
test_varinfo!(empty!!(TypedVarInfo(vi)))
156156
end
157+
158+
@testset "push!! to TypedVarInfo" begin
159+
vn_x = @varname x
160+
vn_y = @varname y
161+
untyped_vi = VarInfo()
162+
untyped_vi = push!!(untyped_vi, vn_x, 1.0, Normal(0, 1), Selector())
163+
typed_vi = TypedVarInfo(untyped_vi)
164+
typed_vi = push!!(typed_vi, vn_y, 2.0, Normal(0, 1), Selector())
165+
@test typed_vi[vn_x] == 1.0
166+
@test typed_vi[vn_y] == 2.0
167+
end
168+
157169
@testset "setgid!" begin
158170
vi = VarInfo(DynamicPPL.Metadata())
159171
meta = vi.metadata

0 commit comments

Comments
 (0)