diff --git a/Project.toml b/Project.toml index a9463a821..863ae9e8b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.35.0" +version = "0.35.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 66b098370..aa4c3f98d 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -331,12 +331,7 @@ has_varnamedvector(vi::AbstractVarInfo) = false Subset a `varinfo` to only contain the variables `vns`. -!!! warning - The ordering of the variables in the resulting `varinfo` is _not_ - guaranteed to follow the ordering of the variables in `varinfo`. - Hence care must be taken, in particular when used in conjunction with - other methods which uses the vector-representation of the `varinfo`, - e.g. `getindex(varinfo, sampler)`. +The ordering of variables in the return value will be the same as in `varinfo`. # Examples ```jldoctest varinfo-subset; setup = :(using Distributions, DynamicPPL) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 173eaa9e1..a49213642 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -416,9 +416,9 @@ end function _subset(x::AbstractDict, vns::AbstractVector{VN}) where {VN<:VarName} vns_present = collect(keys(x)) - vns_found = mapreduce(vcat, vns; init=VN[]) do vn - return filter(Base.Fix1(subsumes, vn), vns_present) - end + vns_found = filter( + vn_present -> any(subsumes(vn, vn_present) for vn in vns), vns_present + ) C = ConstructionBase.constructorof(typeof(x)) if isempty(vns_found) return C() @@ -439,7 +439,8 @@ function _subset(x::NamedTuple, vns) end syms = map(getsym, vns) - return NamedTuple{Tuple(syms)}(Tuple(map(Base.Fix1(getindex, x), syms))) + x_syms = filter(Base.Fix2(in, syms), keys(x)) + return NamedTuple{Tuple(x_syms)}(Tuple(map(Base.Fix1(getindex, x), x_syms))) end _subset(x::VarNamedVector, vns) = subset(x, vns) diff --git a/src/varinfo.jl b/src/varinfo.jl index ca143ea63..f70582428 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -326,43 +326,41 @@ else _tail(nt::NamedTuple) = Base.tail(nt) end -function subset(varinfo::UntypedVarInfo, vns::AbstractVector{<:VarName}) +function subset(varinfo::VarInfo, vns::AbstractVector{<:VarName}) metadata = subset(varinfo.metadata, vns) return VarInfo(metadata, deepcopy(varinfo.logp), deepcopy(varinfo.num_produce)) end -function subset(varinfo::VectorVarInfo, vns::AbstractVector{<:VarName}) - metadata = subset(varinfo.metadata, vns) - return VarInfo(metadata, deepcopy(varinfo.logp), deepcopy(varinfo.num_produce)) +function subset(metadata::NamedTuple, vns::AbstractVector{<:VarName}) + vns_syms = Set(unique(map(getsym, vns))) + syms = filter(Base.Fix2(in, vns_syms), keys(metadata)) + metadatas = map(syms) do sym + subset(getfield(metadata, sym), filter(==(sym) ∘ getsym, vns)) + end + return NamedTuple{syms}(metadatas) end -function subset(varinfo::TypedVarInfo, vns::AbstractVector{<:VarName{sym}}) where {sym} - # If all the variables are using the same symbol, then we can just extract that field from the metadata. - metadata = subset(getfield(varinfo.metadata, sym), vns) - return VarInfo( - NamedTuple{(sym,)}(tuple(metadata)), - deepcopy(varinfo.logp), - deepcopy(varinfo.num_produce), - ) -end +# The above method is type unstable since we don't know which symbols are in `vns`. +# In the below special case, when all `vns` have the same symbol, we can write a type stable +# version. -function subset(varinfo::TypedVarInfo, vns::AbstractVector{<:VarName}) - syms = Tuple(unique(map(getsym, vns))) - metadatas = map(syms) do sym - subset(getfield(varinfo.metadata, sym), filter(==(sym) ∘ getsym, vns)) +@generated function subset( + metadata::NamedTuple{names}, vns::AbstractVector{<:VarName{sym}} +) where {names,sym} + return if (sym in names) + # TODO(mhauru) Note that this could still generate an empty metadata object if none + # of the lenses in `vns` are in `metadata`. Not sure if that's okay. Checking for + # emptiness would make this type unstable again. + :((; $sym=subset(metadata.$sym, vns))) + else + :(NamedTuple{}()) end - - return VarInfo( - NamedTuple{syms}(metadatas), deepcopy(varinfo.logp), deepcopy(varinfo.num_produce) - ) end function subset(metadata::Metadata, vns_given::AbstractVector{VN}) where {VN<:VarName} # TODO: Should we error if `vns` contains a variable that is not in `metadata`? - # For each `vn` in `vns`, get the variables subsumed by `vn`. - vns = mapreduce(vcat, vns_given; init=VN[]) do vn - filter(Base.Fix1(subsumes, vn), metadata.vns) - end + # Find all the vns in metadata that are subsumed by one of the given vns. + vns = filter(vn -> any(subsumes(vn_given, vn) for vn_given in vns_given), metadata.vns) indices_for_vns = map(Base.Fix1(getindex, metadata.idcs), vns) indices = if isempty(vns) Dict{VarName,Int}() diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index 6b7c82859..965db96d5 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -1138,6 +1138,8 @@ Return a new `VarNamedVector` containing the values from `vnv` for variables in Which variables to include is determined by the `VarName`'s `subsumes` relation, meaning that e.g. `subset(vnv, [@varname(x)])` will include variables like `@varname(x.a[1])`. +Preserves the order of variables in `vnv`. + # Examples ```jldoctest varnamedvector-subset @@ -1151,18 +1153,17 @@ true julia> subset(vnv, [@varname(x[2])]) == VarNamedVector(@varname(x[2]) => [2.0]) true """ -function subset(vnv::VarNamedVector, vns_given::AbstractVector{VN}) where {VN<:VarName} +function subset(vnv::VarNamedVector, vns_given::AbstractVector{<:VarName}) # NOTE: This does not specialize types when possible. - vns = mapreduce(vcat, vns_given; init=VN[]) do vn - filter(Base.Fix1(subsumes, vn), vnv.varnames) - end vnv_new = similar(vnv) # Return early if possible. isempty(vnv) && return vnv_new - for vn in vns - insert_internal!(vnv_new, getindex_internal(vnv, vn), vn, gettransform(vnv, vn)) - settrans!(vnv_new, istrans(vnv, vn), vn) + for vn in vnv.varnames + if any(subsumes(vn_given, vn) for vn_given in vns_given) + insert_internal!(vnv_new, getindex_internal(vnv, vn), vn, gettransform(vnv, vn)) + settrans!(vnv_new, istrans(vnv, vn), vn) + end end return vnv_new diff --git a/test/varinfo.jl b/test/varinfo.jl index d689a1bf4..80eb05480 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -729,6 +729,16 @@ end # Values should be the same. @test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns] end + + @testset "$(convert(Vector{VarName}, vns_subset)) order" for vns_subset in + vns_supported + varinfo_subset = subset(varinfo, vns_subset) + vns_subset_reversed = reverse(vns_subset) + varinfo_subset_reversed = subset(varinfo, vns_subset_reversed) + @test varinfo_subset[:] == varinfo_subset_reversed[:] + ground_truth = [varinfo[vn] for vn in vns_subset] + @test varinfo_subset[:] == ground_truth + end end # For certain varinfos we should have errors.