Skip to content

Commit 30c10c2

Browse files
Allow empty subsets of VarInfos (#692)
* Allow empty subsets of VarInfos * Run JuliaFormatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 0683088 commit 30c10c2

File tree

3 files changed

+23
-17
lines changed

3 files changed

+23
-17
lines changed

src/simple_varinfo.jl

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -429,22 +429,17 @@ function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName})
429429
return Accessors.@set varinfo.values = _subset(varinfo.values, vns)
430430
end
431431

432-
function _subset(x::AbstractDict, vns)
432+
function _subset(x::AbstractDict, vns::AbstractVector{VN}) where {VN<:VarName}
433433
vns_present = collect(keys(x))
434-
vns_found = mapreduce(vcat, vns) do vn
434+
vns_found = mapreduce(vcat, vns; init=VN[]) do vn
435435
return filter(Base.Fix1(subsumes, vn), vns_present)
436436
end
437-
438-
# NOTE: This `vns` to be subsume varnames explicitly present in `x`.
437+
C = ConstructionBase.constructorof(typeof(x))
439438
if isempty(vns_found)
440-
throw(
441-
ArgumentError(
442-
"Cannot subset `AbstractDict` with `VarName` which does not subsume any keys.",
443-
),
444-
)
439+
return C()
440+
else
441+
return C(vn => x[vn] for vn in vns_found)
445442
end
446-
C = ConstructionBase.constructorof(typeof(x))
447-
return C(vn => x[vn] for vn in vns_found)
448443
end
449444

450445
function _subset(x::NamedTuple, vns)

src/varinfo.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -264,20 +264,24 @@ function subset(varinfo::TypedVarInfo, vns::AbstractVector{<:VarName})
264264
return VarInfo(NamedTuple{syms}(metadatas), varinfo.logp, varinfo.num_produce)
265265
end
266266

267-
function subset(metadata::Metadata, vns_given::AbstractVector{<:VarName})
267+
function subset(metadata::Metadata, vns_given::AbstractVector{VN}) where {VN<:VarName}
268268
# TODO: Should we error if `vns` contains a variable that is not in `metadata`?
269269
# For each `vn` in `vns`, get the variables subsumed by `vn`.
270-
vns = mapreduce(vcat, vns_given) do vn
270+
vns = mapreduce(vcat, vns_given; init=VN[]) do vn
271271
filter(Base.Fix1(subsumes, vn), metadata.vns)
272272
end
273273
indices_for_vns = map(Base.Fix1(getindex, metadata.idcs), vns)
274-
indices = Dict(vn => i for (i, vn) in enumerate(vns))
274+
indices = if isempty(vns)
275+
Dict{VarName,Int}()
276+
else
277+
Dict(vn => i for (i, vn) in enumerate(vns))
278+
end
275279
# Construct new `vals` and `ranges`.
276280
vals_original = metadata.vals
277281
ranges_original = metadata.ranges
278282
# Allocate the new `vals`. and `ranges`.
279-
vals = similar(metadata.vals, sum(length, ranges_original[indices_for_vns]))
280-
ranges = similar(ranges_original)
283+
vals = similar(metadata.vals, sum(length, ranges_original[indices_for_vns]; init=0))
284+
ranges = similar(ranges_original, length(vns))
281285
# The new range `r` for `vns[i]` is offset by `offset` and
282286
# has the same length as the original range `r_original`.
283287
# The new `indices` (from above) ensures ordering according to `vns`.
@@ -311,7 +315,7 @@ function subset(metadata::Metadata, vns_given::AbstractVector{<:VarName})
311315
ranges,
312316
vals,
313317
metadata.dists[indices_for_vns],
314-
metadata.gids,
318+
metadata.gids[indices_for_vns],
315319
metadata.orders[indices_for_vns],
316320
flags,
317321
)

test/varinfo.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,13 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
511511
else
512512
vns_supported_standard
513513
end
514+
515+
@testset ("$(convert(Vector{VarName}, vns_subset)) empty") for vns_subset in
516+
vns_supported
517+
varinfo_subset = subset(varinfo, VarName[])
518+
@test isempty(varinfo_subset)
519+
end
520+
514521
@testset "$(convert(Vector{VarName}, vns_subset))" for vns_subset in
515522
vns_supported
516523
varinfo_subset = subset(varinfo, vns_subset)

0 commit comments

Comments
 (0)