Skip to content

Commit 9f5266a

Browse files
authored
Use subsumes in subset (#587)
* use `subsumes` in `subset` to allow more flexibility in subsetting varinfos * fixed bug in `subset` for `AbstractDict` * fixed bug where `subset` wasn't properly tested on `SimpleVarInfo`
1 parent 6a2454f commit 9f5266a

File tree

4 files changed

+53
-25
lines changed

4 files changed

+53
-25
lines changed

src/simple_varinfo.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -430,18 +430,21 @@ function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName})
430430
end
431431

432432
function _subset(x::AbstractDict, vns)
433-
# NOTE: This requires `vns` to be explicitly present in `x`.
434-
if any(!Base.Fix1(haskey, x), vns)
433+
vns_present = collect(keys(x))
434+
vns_found = mapreduce(vcat, vns) do vn
435+
return filter(Base.Fix1(subsumes, vn), vns_present)
436+
end
437+
438+
# NOTE: This `vns` to be subsume varnames explicitly present in `x`.
439+
if isempty(vns_found)
435440
throw(
436441
ArgumentError(
437-
"Cannot subset `AbstractDict` with `VarName` that is not an explicit key. " *
438-
"For example, if `keys(x) == [@varname(x[1])]`, then subsetting with " *
439-
"`@varname(x[1])` is allowed, but subsetting with `@varname(x)` is not.",
442+
"Cannot subset `AbstractDict` with `VarName` which does not subsume any keys.",
440443
),
441444
)
442445
end
443446
C = ConstructionBase.constructorof(typeof(x))
444-
return C(vn => x[vn] for vn in vns)
447+
return C(vn => x[vn] for vn in vns_found)
445448
end
446449

447450
function _subset(x::NamedTuple, vns)
@@ -456,7 +459,7 @@ function _subset(x::NamedTuple, vns)
456459
end
457460

458461
syms = map(getsym, vns)
459-
return NamedTuple{Tuple(syms)}(Tuple(map(Base.Fix2(getindex, x), syms)))
462+
return NamedTuple{Tuple(syms)}(Tuple(map(Base.Fix1(getindex, x), syms)))
460463
end
461464

462465
# `merge`

src/threadsafe.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ function BangBang.empty!!(vi::ThreadSafeVarInfo)
200200
return resetlogp!!(Setfield.@set!(vi.varinfo = empty!!(vi.varinfo)))
201201
end
202202

203+
values_as(vi::ThreadSafeVarInfo) = values_as(vi.varinfo)
203204
values_as(vi::ThreadSafeVarInfo, ::Type{T}) where {T} = values_as(vi.varinfo, T)
204205

205206
function unset_flag!(vi::ThreadSafeVarInfo, vn::VarName, flag::String)

src/varinfo.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,8 +264,12 @@ function subset(varinfo::TypedVarInfo, vns::AbstractVector{<:VarName})
264264
return VarInfo(NamedTuple{syms}(metadatas), varinfo.logp, varinfo.num_produce)
265265
end
266266

267-
function subset(metadata::Metadata, vns::AbstractVector{<:VarName})
267+
function subset(metadata::Metadata, vns_given::AbstractVector{<:VarName})
268268
# TODO: Should we error if `vns` contains a variable that is not in `metadata`?
269+
# For each `vn` in `vns`, get the variables subsumed by `vn`.
270+
vns = mapreduce(vcat, vns_given) do vn
271+
filter(Base.Fix1(subsumes, vn), metadata.vns)
272+
end
269273
indices_for_vns = map(Base.Fix1(getindex, metadata.idcs), vns)
270274
indices = Dict(vn => i for (i, vn) in enumerate(vns))
271275
# Construct new `vals` and `ranges`.

test/varinfo.jl

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -483,23 +483,34 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
483483
[@varname(s), @varname(m), @varname(x[2])],
484484
[@varname(s), @varname(x[1]), @varname(x[2])],
485485
[@varname(m), @varname(x[1]), @varname(x[2])],
486-
[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])],
487486
]
488487

489-
# `SimpleaVarInfo` only supports subsetting using the varnames as they appear
488+
# Patterns requiring `subsumes`.
489+
vns_supported_with_subsumes = [
490+
[@varname(s), @varname(x)] => [@varname(s), @varname(x[1]), @varname(x[2])],
491+
[@varname(m), @varname(x)] => [@varname(m), @varname(x[1]), @varname(x[2])],
492+
[@varname(s), @varname(m), @varname(x)] =>
493+
[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])],
494+
]
495+
496+
# `SimpleVarInfo` only supports subsetting using the varnames as they appear
490497
# in the model.
491498
vns_supported_simple = filter((vns), vns_supported_standard)
492499

493-
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos_standard
500+
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
494501
# All variables.
495502
check_varinfo_keys(varinfo, vns)
496503

497504
# Added a `convert` to make the naming of the testsets a bit more readable.
498-
vns_supported = if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple
499-
vns_supported_simple
500-
else
501-
vns_supported_standard
502-
end
505+
# `SimpleVarInfo{<:NamedTuple}` only supports subsetting with "simple" varnames,
506+
## i.e. `VarName{sym}()` without any indexing, etc.
507+
vns_supported =
508+
if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple &&
509+
values_as(varinfo) isa NamedTuple
510+
vns_supported_simple
511+
else
512+
vns_supported_standard
513+
end
503514
@testset "$(convert(Vector{VarName}, vns_subset))" for vns_subset in
504515
vns_supported
505516
varinfo_subset = subset(varinfo, vns_subset)
@@ -516,6 +527,24 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
516527
# Values should be the same.
517528
@test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns]
518529
end
530+
531+
@testset "$(convert(Vector{VarName}, vns_subset))" for (
532+
vns_subset, vns_target
533+
) in vns_supported_with_subsumes
534+
varinfo_subset = subset(varinfo, vns_subset)
535+
# Should now only contain the variables in `vns_subset`.
536+
check_varinfo_keys(varinfo_subset, vns_target)
537+
# Values should be the same.
538+
@test [varinfo_subset[vn] for vn in vns_target] == [varinfo[vn] for vn in vns_target]
539+
540+
# `merge` with the original.
541+
varinfo_merged = merge(varinfo, varinfo_subset)
542+
vns_merged = keys(varinfo_merged)
543+
# Should be equivalent.
544+
check_varinfo_keys(varinfo_merged, vns)
545+
# Values should be the same.
546+
@test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns]
547+
end
519548
end
520549

521550
# For certain varinfos we should have errors.
@@ -526,15 +555,6 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
526555
varinfo, [@varname(s), @varname(m), @varname(x[1])]
527556
)
528557
end
529-
# `SimpleVarInfo{<:AbstractDict}` can only handle varnames as they appear in the model.
530-
varinfo = varinfos[findfirst(
531-
Base.Fix2(isa, SimpleVarInfo{<:AbstractDict}), varinfos
532-
)]
533-
@testset "$(short_varinfo_name(varinfo)): failure cases" begin
534-
@test_throws ArgumentError subset(
535-
varinfo, [@varname(s), @varname(m), @varname(x)]
536-
)
537-
end
538558
end
539559

540560
@testset "merge" begin

0 commit comments

Comments
 (0)