Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.35.0"
version = "0.35.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
7 changes: 1 addition & 6 deletions src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -331,12 +331,7 @@ has_varnamedvector(vi::AbstractVarInfo) = false

Subset a `varinfo` to only contain the variables `vns`.

!!! warning
The ordering of the variables in the resulting `varinfo` is _not_
guaranteed to follow the ordering of the variables in `varinfo`.
Hence care must be taken, in particular when used in conjunction with
other methods which uses the vector-representation of the `varinfo`,
e.g. `getindex(varinfo, sampler)`.
The ordering of variables in the return value will be the same as in `varinfo`.

# Examples
```jldoctest varinfo-subset; setup = :(using Distributions, DynamicPPL)
Expand Down
9 changes: 5 additions & 4 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -416,9 +416,9 @@ end

function _subset(x::AbstractDict, vns::AbstractVector{VN}) where {VN<:VarName}
vns_present = collect(keys(x))
vns_found = mapreduce(vcat, vns; init=VN[]) do vn
return filter(Base.Fix1(subsumes, vn), vns_present)
end
vns_found = filter(
vn_present -> any(subsumes(vn, vn_present) for vn in vns), vns_present
)
C = ConstructionBase.constructorof(typeof(x))
if isempty(vns_found)
return C()
Expand All @@ -439,7 +439,8 @@ function _subset(x::NamedTuple, vns)
end

syms = map(getsym, vns)
return NamedTuple{Tuple(syms)}(Tuple(map(Base.Fix1(getindex, x), syms)))
x_syms = filter(Base.Fix2(in, syms), keys(x))
return NamedTuple{Tuple(x_syms)}(Tuple(map(Base.Fix1(getindex, x), x_syms)))
end

_subset(x::VarNamedVector, vns) = subset(x, vns)
Expand Down
48 changes: 23 additions & 25 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -326,43 +326,41 @@
_tail(nt::NamedTuple) = Base.tail(nt)
end

function subset(varinfo::UntypedVarInfo, vns::AbstractVector{<:VarName})
function subset(varinfo::VarInfo, vns::AbstractVector{<:VarName})
metadata = subset(varinfo.metadata, vns)
return VarInfo(metadata, deepcopy(varinfo.logp), deepcopy(varinfo.num_produce))
end

function subset(varinfo::VectorVarInfo, vns::AbstractVector{<:VarName})
metadata = subset(varinfo.metadata, vns)
return VarInfo(metadata, deepcopy(varinfo.logp), deepcopy(varinfo.num_produce))
function subset(metadata::NamedTuple, vns::AbstractVector{<:VarName})
vns_syms = Set(unique(map(getsym, vns)))
syms = filter(Base.Fix2(in, vns_syms), keys(metadata))
metadatas = map(syms) do sym
subset(getfield(metadata, sym), filter(==(sym) getsym, vns))
end
return NamedTuple{syms}(metadatas)
end

function subset(varinfo::TypedVarInfo, vns::AbstractVector{<:VarName{sym}}) where {sym}
# If all the variables are using the same symbol, then we can just extract that field from the metadata.
metadata = subset(getfield(varinfo.metadata, sym), vns)
return VarInfo(
NamedTuple{(sym,)}(tuple(metadata)),
deepcopy(varinfo.logp),
deepcopy(varinfo.num_produce),
)
end
# The above method is type unstable since we don't know which symbols are in `vns`.
# In the below special case, when all `vns` have the same symbol, we can write a type stable
# version.

function subset(varinfo::TypedVarInfo, vns::AbstractVector{<:VarName})
syms = Tuple(unique(map(getsym, vns)))
metadatas = map(syms) do sym
subset(getfield(varinfo.metadata, sym), filter(==(sym) getsym, vns))
@generated function subset(
metadata::NamedTuple{names}, vns::AbstractVector{<:VarName{sym}}
) where {names,sym}
return if (sym in names)
# TODO(mhauru) Note that this could still generate an empty metadata object if none
# of the lenses in `vns` are in `metadata`. Not sure if that's okay. Checking for
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is fine. But maybe we could add some isempty check where we use subset?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that would be preferable, but I don't know how to do it in a type stable manner. Specifically, I would like to implement it so that if we get a TypedVarInfo where for some symbol :dada the corresponding Metadata object is empty, we would drop :dada from the NamedTuple. But that means that the type of the return value depends on the value of the input (and not just the type of the input), since the keys of a NamedTuple are a part of its type.

# emptiness would make this type unstable again.
:((; $sym=subset(metadata.$sym, vns)))
else
:(NamedTuple{}())

Check warning on line 356 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L356

Added line #L356 was not covered by tests
end

return VarInfo(
NamedTuple{syms}(metadatas), deepcopy(varinfo.logp), deepcopy(varinfo.num_produce)
)
end

function subset(metadata::Metadata, vns_given::AbstractVector{VN}) where {VN<:VarName}
# TODO: Should we error if `vns` contains a variable that is not in `metadata`?
# For each `vn` in `vns`, get the variables subsumed by `vn`.
vns = mapreduce(vcat, vns_given; init=VN[]) do vn
filter(Base.Fix1(subsumes, vn), metadata.vns)
end
# Find all the vns in metadata that are subsumed by one of the given vns.
vns = filter(vn -> any(subsumes(vn_given, vn) for vn_given in vns_given), metadata.vns)
indices_for_vns = map(Base.Fix1(getindex, metadata.idcs), vns)
indices = if isempty(vns)
Dict{VarName,Int}()
Expand Down
15 changes: 8 additions & 7 deletions src/varnamedvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1138,6 +1138,8 @@ Return a new `VarNamedVector` containing the values from `vnv` for variables in
Which variables to include is determined by the `VarName`'s `subsumes` relation, meaning
that e.g. `subset(vnv, [@varname(x)])` will include variables like `@varname(x.a[1])`.

Preserves the order of variables in `vnv`.

# Examples

```jldoctest varnamedvector-subset
Expand All @@ -1151,18 +1153,17 @@ true
julia> subset(vnv, [@varname(x[2])]) == VarNamedVector(@varname(x[2]) => [2.0])
true
"""
function subset(vnv::VarNamedVector, vns_given::AbstractVector{VN}) where {VN<:VarName}
function subset(vnv::VarNamedVector, vns_given::AbstractVector{<:VarName})
# NOTE: This does not specialize types when possible.
vns = mapreduce(vcat, vns_given; init=VN[]) do vn
filter(Base.Fix1(subsumes, vn), vnv.varnames)
end
vnv_new = similar(vnv)
# Return early if possible.
isempty(vnv) && return vnv_new

for vn in vns
insert_internal!(vnv_new, getindex_internal(vnv, vn), vn, gettransform(vnv, vn))
settrans!(vnv_new, istrans(vnv, vn), vn)
for vn in vnv.varnames
if any(subsumes(vn_given, vn) for vn_given in vns_given)
insert_internal!(vnv_new, getindex_internal(vnv, vn), vn, gettransform(vnv, vn))
settrans!(vnv_new, istrans(vnv, vn), vn)
end
end

return vnv_new
Expand Down
10 changes: 10 additions & 0 deletions test/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,16 @@ end
# Values should be the same.
@test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns]
end

@testset "$(convert(Vector{VarName}, vns_subset)) order" for vns_subset in
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you feel confident that all the methods we have for VarInfo will work correctly with "empty" NamedTuple entries in VarInfo? 👀

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As in, is there more to be tested here or do we feel confident about this converage?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you feel confident that all the methods we have for VarInfo will work correctly with "empty" NamedTuple entries in VarInfo? 👀

Confident, nope. Hopeful at best. I think this PR makes the situation marginally better though, in that it is a bit harder to create VarInfos with empty Metadatas using subset now. We should do all sorts of testing around such pathological VarInfos, but I think that would be a different PR.

I was about to say that delete! probably creates similar issues, but actually, seems delete! is broken/unimplemented for TypedVarInfo anyway. Should fix that too.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, definitively think this PR is making the situation better:) I was mainly trying to elude to whether we should add a few more test cases here / add an "empty" varinfo to be tested somewhere else.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Opened an issue about it here: #839

vns_supported
varinfo_subset = subset(varinfo, vns_subset)
vns_subset_reversed = reverse(vns_subset)
varinfo_subset_reversed = subset(varinfo, vns_subset_reversed)
@test varinfo_subset[:] == varinfo_subset_reversed[:]
ground_truth = [varinfo[vn] for vn in vns_subset]
@test varinfo_subset[:] == ground_truth
end
end

# For certain varinfos we should have errors.
Expand Down
Loading