Skip to content

Commit e994196

Browse files
authored
Merge branch 'master' into py/cherry-pick-0.28.5
2 parents 4fbae44 + 54691bf commit e994196

File tree

4 files changed

+78
-25
lines changed

4 files changed

+78
-25
lines changed

src/simple_varinfo.jl

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -439,22 +439,17 @@ function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName})
439439
return Accessors.@set varinfo.values = _subset(varinfo.values, vns)
440440
end
441441

442-
function _subset(x::AbstractDict, vns)
442+
function _subset(x::AbstractDict, vns::AbstractVector{VN}) where {VN<:VarName}
443443
vns_present = collect(keys(x))
444-
vns_found = mapreduce(vcat, vns) do vn
444+
vns_found = mapreduce(vcat, vns; init=VN[]) do vn
445445
return filter(Base.Fix1(subsumes, vn), vns_present)
446446
end
447-
448-
# NOTE: This `vns` to be subsume varnames explicitly present in `x`.
447+
C = ConstructionBase.constructorof(typeof(x))
449448
if isempty(vns_found)
450-
throw(
451-
ArgumentError(
452-
"Cannot subset `AbstractDict` with `VarName` which does not subsume any keys.",
453-
),
454-
)
449+
return C()
450+
else
451+
return C(vn => x[vn] for vn in vns_found)
455452
end
456-
C = ConstructionBase.constructorof(typeof(x))
457-
return C(vn => x[vn] for vn in vns_found)
458453
end
459454

460455
function _subset(x::NamedTuple, vns)

src/varinfo.jl

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -368,20 +368,24 @@ function subset(varinfo::TypedVarInfo, vns::AbstractVector{<:VarName})
368368
)
369369
end
370370

371-
function subset(metadata::Metadata, vns_given::AbstractVector{<:VarName})
371+
function subset(metadata::Metadata, vns_given::AbstractVector{VN}) where {VN<:VarName}
372372
# TODO: Should we error if `vns` contains a variable that is not in `metadata`?
373373
# For each `vn` in `vns`, get the variables subsumed by `vn`.
374-
vns = mapreduce(vcat, vns_given) do vn
374+
vns = mapreduce(vcat, vns_given; init=VN[]) do vn
375375
filter(Base.Fix1(subsumes, vn), metadata.vns)
376376
end
377377
indices_for_vns = map(Base.Fix1(getindex, metadata.idcs), vns)
378-
indices = Dict(vn => i for (i, vn) in enumerate(vns))
378+
indices = if isempty(vns)
379+
Dict{VarName,Int}()
380+
else
381+
Dict(vn => i for (i, vn) in enumerate(vns))
382+
end
379383
# Construct new `vals` and `ranges`.
380384
vals_original = metadata.vals
381385
ranges_original = metadata.ranges
382386
# Allocate the new `vals`. and `ranges`.
383-
vals = similar(metadata.vals, sum(length, ranges_original[indices_for_vns]))
384-
ranges = similar(ranges_original)
387+
vals = similar(metadata.vals, sum(length, ranges_original[indices_for_vns]; init=0))
388+
ranges = similar(ranges_original, length(vns))
385389
# The new range `r` for `vns[i]` is offset by `offset` and
386390
# has the same length as the original range `r_original`.
387391
# The new `indices` (from above) ensures ordering according to `vns`.
@@ -415,7 +419,7 @@ function subset(metadata::Metadata, vns_given::AbstractVector{<:VarName})
415419
ranges,
416420
vals,
417421
metadata.dists[indices_for_vns],
418-
metadata.gids,
422+
metadata.gids[indices_for_vns],
419423
metadata.orders[indices_for_vns],
420424
flags,
421425
)
@@ -490,7 +494,7 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
490494
ranges = Vector{UnitRange{Int}}()
491495
vals = T[]
492496
dists = D[]
493-
gids = metadata_right.gids # NOTE: giving precedence to `metadata_right`
497+
gids = Set{Selector}[]
494498
orders = Int[]
495499
flags = Dict{String,BitVector}()
496500
# Initialize the `flags`.
@@ -520,6 +524,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
520524
dist_right = getdist(metadata_right, vn)
521525
# Give precedence to `metadata_right`.
522526
push!(dists, dist_right)
527+
gid = metadata_right.gids[getidx(metadata_right, vn)]
528+
push!(gids, gid)
523529
# `orders`: giving precedence to `metadata_right`
524530
push!(orders, getorder(metadata_right, vn))
525531
# `flags`
@@ -539,6 +545,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
539545
# `dists`
540546
dist_left = getdist(metadata_left, vn)
541547
push!(dists, dist_left)
548+
gid = metadata_left.gids[getidx(metadata_left, vn)]
549+
push!(gids, gid)
542550
# `orders`
543551
push!(orders, getorder(metadata_left, vn))
544552
# `flags`
@@ -557,6 +565,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
557565
# `dists`
558566
dist_right = getdist(metadata_right, vn)
559567
push!(dists, dist_right)
568+
gid = metadata_right.gids[getidx(metadata_right, vn)]
569+
push!(gids, gid)
560570
# `orders`
561571
push!(orders, getorder(metadata_right, vn))
562572
# `flags`
@@ -1826,14 +1836,31 @@ function BangBang.push!!(
18261836
vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector}
18271837
)
18281838
if vi isa UntypedVarInfo
1829-
@assert ~(vn in keys(vi)) "[push!!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset"
1839+
@assert ~(vn in keys(vi)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset"
18301840
elseif vi isa TypedVarInfo
1831-
@assert ~(haskey(vi, vn)) "[push!!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist, gid=$gidset"
1841+
@assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist, gid=$gidset"
1842+
end
1843+
1844+
sym = getsym(vn)
1845+
if vi isa TypedVarInfo && ~haskey(vi.metadata, sym)
1846+
# The NamedTuple doesn't have an entry for this variable, let's add one.
1847+
val = tovec(r)
1848+
md = Metadata(
1849+
Dict(vn => 1),
1850+
[vn],
1851+
[1:length(val)],
1852+
val,
1853+
[dist],
1854+
[gidset],
1855+
[get_num_produce(vi)],
1856+
Dict{String,BitVector}("trans" => [false], "del" => [false]),
1857+
)
1858+
vi = Accessors.@set vi.metadata[sym] = md
1859+
else
1860+
meta = getmetadata(vi, vn)
1861+
push!(meta, vn, r, dist, gidset, get_num_produce(vi))
18321862
end
18331863

1834-
meta = getmetadata(vi, vn)
1835-
push!(meta, vn, r, dist, gidset, get_num_produce(vi))
1836-
18371864
return vi
18381865
end
18391866

@@ -1864,7 +1891,6 @@ function Base.push!(meta::Metadata, vn, r, dist, gidset, num_produce)
18641891
push!(meta.orders, num_produce)
18651892
push!(meta.flags["del"], false)
18661893
push!(meta.flags["trans"], false)
1867-
18681894
return meta
18691895
end
18701896

test/turing/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
1111

1212
[compat]
1313
Distributions = "0.25"
14-
DynamicPPL = "0.24, 0.25, 0.26, 0.27, 0.28, 0.29"
14+
DynamicPPL = "0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.30"
1515
HypothesisTests = "0.11"
1616
MCMCChains = "6"
1717
ReverseDiff = "1.15"

test/varinfo.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,18 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
154154
test_varinfo!(vi)
155155
test_varinfo!(empty!!(TypedVarInfo(vi)))
156156
end
157+
158+
@testset "push!! to TypedVarInfo" begin
159+
vn_x = @varname x
160+
vn_y = @varname y
161+
untyped_vi = VarInfo()
162+
untyped_vi = push!!(untyped_vi, vn_x, 1.0, Normal(0, 1), Selector())
163+
typed_vi = TypedVarInfo(untyped_vi)
164+
typed_vi = push!!(typed_vi, vn_y, 2.0, Normal(0, 1), Selector())
165+
@test typed_vi[vn_x] == 1.0
166+
@test typed_vi[vn_y] == 2.0
167+
end
168+
157169
@testset "setgid!" begin
158170
vi = VarInfo(DynamicPPL.Metadata())
159171
meta = vi.metadata
@@ -566,6 +578,13 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
566578
else
567579
vns_supported_standard
568580
end
581+
582+
@testset ("$(convert(Vector{VarName}, vns_subset)) empty") for vns_subset in
583+
vns_supported
584+
varinfo_subset = subset(varinfo, VarName[])
585+
@test isempty(varinfo_subset)
586+
end
587+
569588
@testset "$(convert(Vector{VarName}, vns_subset))" for vns_subset in
570589
vns_supported
571590
varinfo_subset = subset(varinfo, vns_subset)
@@ -694,6 +713,19 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
694713
@test varinfo_merged[@varname(x)] == varinfo_right[@varname(x)]
695714
@test DynamicPPL.istrans(varinfo_merged, @varname(x))
696715
end
716+
717+
# The below used to error, testing to avoid regression.
718+
@testset "merge gids" begin
719+
gidset_left = Set([Selector(1)])
720+
vi_left = VarInfo()
721+
vi_left = push!!(vi_left, @varname(x), 1.0, Normal(), gidset_left)
722+
gidset_right = Set([Selector(2)])
723+
vi_right = VarInfo()
724+
vi_right = push!!(vi_right, @varname(y), 2.0, Normal(), gidset_right)
725+
varinfo_merged = merge(vi_left, vi_right)
726+
@test DynamicPPL.getgid(varinfo_merged, @varname(x)) == gidset_left
727+
@test DynamicPPL.getgid(varinfo_merged, @varname(y)) == gidset_right
728+
end
697729
end
698730

699731
@testset "VarInfo with selectors" begin

0 commit comments

Comments
 (0)