-
Notifications
You must be signed in to change notification settings - Fork 37
Improvements to VarNamedVector #1098
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
150de71
30ac1d0
4ae0c6d
4c8b006
c8b0b88
1f7152b
61c96b0
2a10be9
bb83d93
99a7d32
513edc5
d30eca8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -116,24 +116,24 @@ like `setindex!` and `getindex!` rather than directly accessing `vnv.vals`. | |
|
|
||
| ```jldoctest varnamedvector-struct | ||
| julia> vnv[@varname(x)] | ||
| 2-element Vector{Real}: | ||
| 2-element Vector{Float64}: | ||
| 46.0 | ||
| 48.0 | ||
|
|
||
| julia> getindex_internal(vnv, :) | ||
| 8-element Vector{Real}: | ||
| 8-element Vector{Float64}: | ||
| 46.0 | ||
| 48.0 | ||
| 1 | ||
| 2 | ||
| 3 | ||
| 4 | ||
| 5 | ||
| 6 | ||
| 1.0 | ||
| 2.0 | ||
| 3.0 | ||
| 4.0 | ||
| 5.0 | ||
| 6.0 | ||
|
||
| ``` | ||
| """ | ||
| struct VarNamedVector{ | ||
| K<:VarName,V,TVN<:AbstractVector{K},TVal<:AbstractVector{V},TTrans<:AbstractVector | ||
| K<:VarName,V,T,KVec<:AbstractVector{K},VVec<:AbstractVector{V},TVec<:AbstractVector{T} | ||
| } | ||
| """ | ||
| mapping from a `VarName` to its integer index in `varnames`, `ranges` and `transforms` | ||
|
|
@@ -143,7 +143,7 @@ struct VarNamedVector{ | |
| """ | ||
| vector of `VarNames` for the variables, where `varnames[varname_to_index[vn]] == vn` | ||
| """ | ||
| varnames::TVN # AbstractVector{<:VarName} | ||
| varnames::KVec | ||
|
|
||
| """ | ||
| vector of index ranges in `vals` corresponding to `varnames`; each `VarName` `vn` has | ||
|
|
@@ -156,14 +156,14 @@ struct VarNamedVector{ | |
| vector of values of all variables; the value(s) of `vn` is/are | ||
| `vals[ranges[varname_to_index[vn]]]` | ||
| """ | ||
| vals::TVal # AbstractVector{<:Real} | ||
| vals::VVec | ||
|
|
||
| """ | ||
| vector of transformations, so that `transforms[varname_to_index[vn]]` is a callable | ||
| that transforms the value of `vn` back to its original space, undoing any linking and | ||
| vectorisation | ||
| """ | ||
| transforms::TTrans | ||
| transforms::TVec | ||
|
|
||
| """ | ||
| vector of booleans indicating whether a variable has been explicitly transformed to | ||
|
|
@@ -186,14 +186,14 @@ struct VarNamedVector{ | |
|
|
||
| function VarNamedVector( | ||
| varname_to_index, | ||
| varnames::TVN, | ||
| varnames::KVec, | ||
| ranges, | ||
| vals::TVal, | ||
| transforms::TTrans, | ||
| vals::VVec, | ||
| transforms::TVec, | ||
| is_unconstrained=fill!(BitVector(undef, length(varnames)), 0), | ||
| num_inactive=Dict{Int,Int}(); | ||
| check_consistency::Bool=CHECK_CONSISTENCY_DEFAULT, | ||
| ) where {K,V,TVN<:AbstractVector{K},TVal<:AbstractVector{V},TTrans<:AbstractVector} | ||
| ) where {K,V,T,KVec<:AbstractVector{K},VVec<:AbstractVector{V},TVec<:AbstractVector{T}} | ||
| if check_consistency | ||
| if length(varnames) != length(ranges) || | ||
| length(varnames) != length(transforms) || | ||
|
|
@@ -257,7 +257,7 @@ struct VarNamedVector{ | |
| # tiny bit of thought. | ||
| end | ||
|
|
||
| return new{K,V,TVN,TVal,TTrans}( | ||
| return new{K,V,T,KVec,VVec,TVec}( | ||
| varname_to_index, | ||
| varnames, | ||
| ranges, | ||
|
|
@@ -269,9 +269,9 @@ struct VarNamedVector{ | |
| end | ||
| end | ||
|
|
||
| function VarNamedVector{K,V}() where {K,V} | ||
| function VarNamedVector{K,V,T}() where {K,V,T} | ||
| return VarNamedVector( | ||
| OrderedDict{K,Int}(), K[], UnitRange{Int}[], V[], Any[]; check_consistency=false | ||
| Dict{K,Int}(), K[], UnitRange{Int}[], V[], T[]; check_consistency=false | ||
| ) | ||
| end | ||
|
|
||
|
|
@@ -280,7 +280,7 @@ end | |
| # VarName and element types only as necessary, which would help keep them concrete. However, | ||
| # making that change here opens some other cans of worms related to how VarInfo uses | ||
| # BangBang, that I don't want to deal with right now. | ||
| VarNamedVector() = VarNamedVector{VarName,Real}() | ||
| VarNamedVector() = VarNamedVector{VarName,Real,Any}() | |
| function VarNamedVector(xs::Pair...; check_consistency=CHECK_CONSISTENCY_DEFAULT) | ||
| return VarNamedVector(OrderedDict(xs...); check_consistency=check_consistency) | ||
| end | ||
|
|
@@ -345,6 +345,12 @@ function ==(vnv_left::VarNamedVector, vnv_right::VarNamedVector) | |
| vnv_left.num_inactive == vnv_right.num_inactive | ||
| end | ||
|
|
||
| function is_concretely_typed(vnv::VarNamedVector) | ||
| return isconcretetype(eltype(vnv.varnames)) && | ||
| isconcretetype(eltype(vnv.vals)) && | ||
| isconcretetype(eltype(vnv.transforms)) | ||
| end | ||
|
|
||
| getidx(vnv::VarNamedVector, vn::VarName) = vnv.varname_to_index[vn] | ||
|
|
||
| getrange(vnv::VarNamedVector, idx::Int) = vnv.ranges[idx] | ||
|
|
@@ -562,7 +568,7 @@ to be the default vectorisation transform. This undoes any possible linking. | |
| ```jldoctest varnamedvector-reset | ||
| julia> using DynamicPPL: VarNamedVector, @varname, reset! | ||
|
|
||
| julia> vnv = VarNamedVector(); | ||
| julia> vnv = VarNamedVector{VarName,Any,Any}(); | ||
|
|
||
| julia> vnv[@varname(x)] = reshape(1:9, (3, 3)); | ||
|
|
||
|
|
@@ -810,7 +816,7 @@ end | |
| # with every ! call replaced with a !! call. | ||
|
|
||
| """ | ||
| loosen_types!!(vnv::VarNamedVector{K,V,TVN,TVal,TTrans}, ::Type{KNew}, ::Type{TransNew}) | ||
| loosen_types!!(vnv::VarNamedVector, ::Type{KNew}, ::Type{VNew}, ::Type{TNew}) | ||
|
|
||
| Loosen the types of `vnv` to allow varname type `KNew`, value type `VNew`, and transformation type `TNew`. | |
|
|
||
|
|
@@ -821,7 +827,7 @@ transformations of type `TransNew` can be pushed to it. Some of the underlying s | |
| shared between `vnv` and the return value, and thus mutating one may affect the other. | ||
|
|
||
| # See also | ||
| [`tighten_types`](@ref) | ||
| [`tighten_types!!`](@ref) | ||
|
|
||
| # Examples | ||
|
|
||
|
|
@@ -836,7 +842,9 @@ julia> setindex_internal!(vnv, collect(1:4), @varname(y), y_trans) | |
| ERROR: MethodError: Cannot `convert` an object of type | ||
| [...] | ||
|
|
||
| julia> vnv_loose = DynamicPPL.loosen_types!!(vnv, typeof(@varname(y)), typeof(y_trans)); | ||
| julia> vnv_loose = DynamicPPL.loosen_types!!( | ||
| vnv, typeof(@varname(y)), Float64, typeof(y_trans) | ||
| ); | ||
|
|
||
| julia> setindex_internal!(vnv_loose, collect(1:4), @varname(y), y_trans) | ||
|
|
||
|
|
@@ -847,40 +855,57 @@ julia> vnv_loose[@varname(y)] | |
| ``` | ||
| """ | ||
| function loosen_types!!( | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've checked and both this and I should probably write a test to check this. Will need to think about how to do that. |
||
| vnv::VarNamedVector, ::Type{KNew}, ::Type{TransNew} | ||
| ) where {KNew,TransNew} | ||
| vnv::VarNamedVector, ::Type{KNew}, ::Type{VNew}, ::Type{TNew} | ||
| ) where {KNew,VNew,TNew} | ||
| K = eltype(vnv.varnames) | ||
| Trans = eltype(vnv.transforms) | ||
| if KNew <: K && TransNew <: Trans | ||
| V = eltype(vnv.vals) | ||
| T = eltype(vnv.transforms) | ||
| if KNew <: K && VNew <: V && TNew <: T | ||
| return vnv | ||
| else | ||
| vn_type = promote_type(K, KNew) | ||
| transform_type = promote_type(Trans, TransNew) | ||
| return VarNamedVector( | ||
| Dict{vn_type,Int}(vnv.varname_to_index), | ||
| Vector{vn_type}(vnv.varnames), | ||
| vnv.ranges, | ||
| vnv.vals, | ||
| Vector{transform_type}(vnv.transforms), | ||
| vnv.is_unconstrained, | ||
| vnv.num_inactive; | ||
| check_consistency=false, | ||
| ) | ||
| val_type = promote_type(V, VNew) | ||
| transform_type = promote_type(T, TNew) | ||
| # This function would work the same way if the first if statement a few lines above | ||
| # was skipped, and we only checked for the below condition. However, the first one | ||
| # is constant propagated away at compile time (at least on Julia v1.11.7), whereas | ||
| # this one isn't. Hence we keep both for performance. | ||
| return if vn_type == K && val_type == V && transform_type == T | ||
| vnv | ||
| elseif isempty(vnv) | ||
| VarNamedVector(vn_type[], val_type[], transform_type[]) | ||
| else | ||
| # TODO(mhauru) We allow a `vnv` to have any AbstractVector type as its vals, but | ||
| # then here always revert to Vector. | ||
| VarNamedVector( | ||
| Dict{vn_type,Int}(vnv.varname_to_index), | ||
| Vector{vn_type}(vnv.varnames), | ||
| vnv.ranges, | ||
| Vector{val_type}(vnv.vals), | ||
| Vector{transform_type}(vnv.transforms), | ||
| vnv.is_unconstrained, | ||
| vnv.num_inactive; | ||
| check_consistency=false, | ||
| ) | ||
| end | ||
| end | ||
| end | ||
|
|
||
| """ | ||
| tighten_types(vnv::VarNamedVector) | ||
| tighten_types!!(vnv::VarNamedVector) | ||
|
|
||
| Return a copy of `vnv` with the most concrete types possible. | ||
| Return a `VarNamedVector` like `vnv` with the most concrete types possible. | ||
|
|
||
| This function either returns `vnv` itself or a new `VarNamedVector` with the same values in | |
| it, but with the element types of various containers made as concrete as possible. | ||
|
|
||
| For instance, if the transforms vector of `vnv` has eltype `Any`, but all the | |
| transforms are actually identity transformations, this function will return a new | |
| `VarNamedVector` whose transforms vector has eltype `typeof(identity)`. | |
|
|
||
| This is a lot like the reverse of [`loosen_types!!`](@ref), but with two notable | ||
| differences: Unlike `loosen_types!!`, this function does not mutate `vnv`; it also changes | ||
| not only the key and transform eltypes, but also the values eltype. | ||
| This is a lot like the reverse of [`loosen_types!!`](@ref). Like with `loosen_types!!`, the | ||
| return value may share some of its underlying storage with `vnv`, and thus mutating one may | ||
| affect the other. | ||
|
|
||
| # See also | ||
| [`loosen_types!!`](@ref) | ||
|
|
@@ -890,9 +915,9 @@ not only the key and transform eltypes, but also the values eltype. | |
| ```jldoctest varnamedvector-tighten-types | ||
| julia> using DynamicPPL: VarNamedVector, @varname, loosen_types!!, setindex_internal! | ||
|
|
||
| julia> vnv = VarNamedVector(); | ||
| julia> vnv = VarNamedVector(@varname(x) => Real[23], @varname(y) => randn(2,2)); | ||
|
|
||
| julia> setindex!(vnv, [23], @varname(x)) | ||
| julia> vnv = delete!(vnv, @varname(y)); | ||
|
|
||
| julia> eltype(vnv) | ||
| Real | ||
|
|
@@ -901,7 +926,7 @@ julia> vnv.transforms | |
| 1-element Vector{Any}: | ||
| identity (generic function with 1 method) | ||
|
|
||
| julia> vnv_tight = DynamicPPL.tighten_types(vnv); | ||
| julia> vnv_tight = DynamicPPL.tighten_types!!(vnv); | ||
|
|
||
| julia> eltype(vnv_tight) == Int | ||
| true | ||
|
|
@@ -911,17 +936,24 @@ julia> vnv_tight.transforms | |
| identity (generic function with 1 method) | ||
| ``` | ||
| """ | ||
| function tighten_types(vnv::VarNamedVector) | ||
| return VarNamedVector( | ||
| Dict(vnv.varname_to_index...), | ||
| map(identity, vnv.varnames), | ||
| copy(vnv.ranges), | ||
| map(identity, vnv.vals), | ||
| map(identity, vnv.transforms), | ||
| copy(vnv.is_unconstrained), | ||
| copy(vnv.num_inactive); | ||
| check_consistency=false, | ||
| ) | ||
| function tighten_types!!(vnv::VarNamedVector) | ||
| return if is_concretely_typed(vnv) | ||
| # There can not be anything to tighten, so short-circuit. | ||
| vnv | ||
| elseif isempty(vnv) | ||
| VarNamedVector() | ||
| else | ||
| VarNamedVector( | ||
| Dict(vnv.varname_to_index...), | ||
| [x for x in vnv.varnames], | ||
| vnv.ranges, | ||
| [x for x in vnv.vals], | ||
| [x for x in vnv.transforms], | ||
| vnv.is_unconstrained, | ||
| vnv.num_inactive; | ||
| check_consistency=false, | ||
| ) | ||
| end | ||
| end | ||
|
|
||
| function BangBang.setindex!!(vnv::VarNamedVector, val, vn::VarName) | ||
|
|
@@ -977,14 +1009,14 @@ function insert_internal!!(vnv::VarNamedVector, val, vn::VarName, transform=noth | |
| if transform === nothing | ||
| transform = identity | ||
| end | ||
| vnv = loosen_types!!(vnv, typeof(vn), typeof(transform)) | ||
| vnv = loosen_types!!(vnv, typeof(vn), eltype(val), typeof(transform)) | ||
| insert_internal!(vnv, val, vn, transform) | ||
| return vnv | ||
| end | ||
|
|
||
| function update_internal!!(vnv::VarNamedVector, val, vn::VarName, transform=nothing) | ||
| transform_resolved = transform === nothing ? gettransform(vnv, vn) : transform | ||
| vnv = loosen_types!!(vnv, typeof(vn), typeof(transform_resolved)) | ||
| vnv = loosen_types!!(vnv, typeof(vn), eltype(val), typeof(transform_resolved)) | ||
| update_internal!(vnv, val, vn, transform) | ||
| return vnv | ||
| end | ||
|
|
@@ -1219,7 +1251,6 @@ julia> subset(vnv, [@varname(x[2])]) == VarNamedVector(@varname(x[2]) => [2.0]) | |
| true | ||
| """ | ||
| function subset(vnv::VarNamedVector, vns_given::AbstractVector{<:VarName}) | ||
| # NOTE: This does not specialize types when possible. | ||
| vnv_new = similar(vnv) | ||
| # Return early if possible. | ||
| isempty(vnv) && return vnv_new | ||
|
|
@@ -1231,7 +1262,7 @@ function subset(vnv::VarNamedVector, vns_given::AbstractVector{<:VarName}) | |
| end | ||
| end | ||
|
|
||
| return vnv_new | ||
| return tighten_types!!(vnv_new) | ||
| end | ||
|
|
||
| """ | ||
|
|
@@ -1430,7 +1461,7 @@ true | |
| """ | ||
| function group_by_symbol(vnv::VarNamedVector) | ||
| symbols = unique(map(getsym, vnv.varnames)) | ||
| nt_vals = map(s -> tighten_types(subset(vnv, [VarName{s}()])), symbols) | ||
| nt_vals = map(s -> tighten_types!!(subset(vnv, [VarName{s}()])), symbols) | ||
| return OrderedDict(zip(symbols, nt_vals)) | ||
| end | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These two don't actually ever mutate their inputs. Hence the `!!` is questionable.
However, they sometimes return the original object, sometimes an object that shares memory with the original object, so you should use them like you use `!!` functions: always catch the return value and never assume that the return value is independent of the input.