Skip to content

Commit 06fbe70

Browse files
committed
Allow empty subsets of VarInfos
1 parent 1d10278 commit 06fbe70

File tree

4 files changed

+23
-18
lines changed

4 files changed

+23
-18
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.30"
3+
version = "0.30.1"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/simple_varinfo.jl

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -439,22 +439,17 @@ function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName})
439439
return Accessors.@set varinfo.values = _subset(varinfo.values, vns)
440440
end
441441

442-
function _subset(x::AbstractDict, vns)
442+
function _subset(x::AbstractDict, vns::AbstractVector{VN}) where {VN<:VarName}
443443
vns_present = collect(keys(x))
444-
vns_found = mapreduce(vcat, vns) do vn
444+
vns_found = mapreduce(vcat, vns; init=VN[]) do vn
445445
return filter(Base.Fix1(subsumes, vn), vns_present)
446446
end
447-
448-
# NOTE: This `vns` to be subsume varnames explicitly present in `x`.
447+
C = ConstructionBase.constructorof(typeof(x))
449448
if isempty(vns_found)
450-
throw(
451-
ArgumentError(
452-
"Cannot subset `AbstractDict` with `VarName` which does not subsume any keys.",
453-
),
454-
)
449+
return C()
450+
else
451+
return C(vn => x[vn] for vn in vns_found)
455452
end
456-
C = ConstructionBase.constructorof(typeof(x))
457-
return C(vn => x[vn] for vn in vns_found)
458453
end
459454

460455
function _subset(x::NamedTuple, vns)

src/varinfo.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -368,20 +368,24 @@ function subset(varinfo::TypedVarInfo, vns::AbstractVector{<:VarName})
368368
)
369369
end
370370

371-
function subset(metadata::Metadata, vns_given::AbstractVector{<:VarName})
371+
function subset(metadata::Metadata, vns_given::AbstractVector{VN}) where VN<:VarName
372372
# TODO: Should we error if `vns` contains a variable that is not in `metadata`?
373373
# For each `vn` in `vns`, get the variables subsumed by `vn`.
374-
vns = mapreduce(vcat, vns_given) do vn
374+
vns = mapreduce(vcat, vns_given; init=VN[]) do vn
375375
filter(Base.Fix1(subsumes, vn), metadata.vns)
376376
end
377377
indices_for_vns = map(Base.Fix1(getindex, metadata.idcs), vns)
378-
indices = Dict(vn => i for (i, vn) in enumerate(vns))
378+
indices = if isempty(vns)
379+
Dict{VarName,Int}()
380+
else
381+
Dict(vn => i for (i, vn) in enumerate(vns))
382+
end
379383
# Construct new `vals` and `ranges`.
380384
vals_original = metadata.vals
381385
ranges_original = metadata.ranges
382386
# Allocate the new `vals`. and `ranges`.
383-
vals = similar(metadata.vals, sum(length, ranges_original[indices_for_vns]))
384-
ranges = similar(ranges_original)
387+
vals = similar(metadata.vals, sum(length, ranges_original[indices_for_vns]; init=0))
388+
ranges = similar(ranges_original, length(vns))
385389
# The new range `r` for `vns[i]` is offset by `offset` and
386390
# has the same length as the original range `r_original`.
387391
# The new `indices` (from above) ensures ordering according to `vns`.
@@ -415,7 +419,7 @@ function subset(metadata::Metadata, vns_given::AbstractVector{<:VarName})
415419
ranges,
416420
vals,
417421
metadata.dists[indices_for_vns],
418-
metadata.gids,
422+
metadata.gids[indices_for_vns],
419423
metadata.orders[indices_for_vns],
420424
flags,
421425
)

test/varinfo.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,12 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
566566
else
567567
vns_supported_standard
568568
end
569+
570+
@testset ("$(convert(Vector{VarName}, vns_subset)) empty") for vns_subset in vns_supported
571+
varinfo_subset = subset(varinfo, VarName[])
572+
@test isempty(varinfo_subset)
573+
end
574+
569575
@testset "$(convert(Vector{VarName}, vns_subset))" for vns_subset in
570576
vns_supported
571577
varinfo_subset = subset(varinfo, vns_subset)

0 commit comments

Comments
 (0)