Skip to content
66 changes: 36 additions & 30 deletions src/varnamedvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,26 +56,26 @@ $(FIELDS)
The values for different variables are internally all stored in a single vector. For
instance,
```jldoctest varnamedvector-struct
julia> using DynamicPPL: ReshapeTransform, VarNamedVector, @varname, setindex!, update!, getindex_internal
julia> using DynamicPPL: ReshapeTransform, VarNamedVector, @varname, setindex!!, update!!, getindex_internal

julia> vnv = VarNamedVector();

julia> setindex!(vnv, [0.0, 0.0, 0.0, 0.0], @varname(x));
julia> vnv = setindex!!(vnv, [0.0, 0.0, 0.0, 0.0], @varname(x));

julia> setindex!(vnv, reshape(1:6, (2,3)), @varname(y));
julia> vnv = setindex!!(vnv, reshape(1:6, (2,3)), @varname(y));

julia> vnv.vals
10-element Vector{Real}:
10-element Vector{Float64}:
0.0
0.0
0.0
0.0
1
2
3
4
5
6
1.0
2.0
3.0
4.0
5.0
6.0
```

The `varnames`, `ranges`, and `varname_to_index` fields keep track of which value belongs to
Expand All @@ -91,20 +91,20 @@ If a variable is updated with a new value that is of a smaller dimension than th
value, rather than resizing `vnv.vals`, some elements in `vnv.vals` are marked as inactive.

```jldoctest varnamedvector-struct
julia> update!(vnv, [46.0, 48.0], @varname(x))
julia> vnv = update!!(vnv, [46.0, 48.0], @varname(x));

julia> vnv.vals
10-element Vector{Real}:
10-element Vector{Float64}:
46.0
48.0
0.0
0.0
1
2
3
4
5
6
1.0
2.0
3.0
4.0
5.0
6.0

julia> println(vnv.num_inactive);
Dict(1 => 2)
Expand Down Expand Up @@ -275,12 +275,7 @@ function VarNamedVector{K,V,T}() where {K,V,T}
)
end

# TODO(mhauru) I would like for this to be VarNamedVector(Union{}, Union{}). Simlarly the
# transform vector type above could then be Union{}[]. This would allow expanding the
# VarName and element types only as necessary, which would help keep them concrete. However,
# making that change here opens some other cans of worms related to how VarInfo uses
# BangBang, that I don't want to deal with right now.
VarNamedVector() = VarNamedVector{VarName,Real,An}()
VarNamedVector() = VarNamedVector{Union{},Union{},Union{}}()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is probably the most important and interesting change. By default VNVs are now created with an element type that allows storing nothing. push!! and setindex!! then expand that element type as needed. You should be able to push!! a value of any type into any VNV, and the container types will just flex. This is paving the way towards not having "untyped" VarInfos at all, but rather having them all be as typed as possible.

function VarNamedVector(xs::Pair...; check_consistency=CHECK_CONSISTENCY_DEFAULT)
return VarNamedVector(OrderedDict(xs...); check_consistency=check_consistency)
end
Expand All @@ -298,15 +293,16 @@ function VarNamedVector(
transforms=fill(identity, length(varnames));
check_consistency=CHECK_CONSISTENCY_DEFAULT,
)
if isempty(varnames) && isempty(orig_vals) && isempty(transforms)
return VarNamedVector{eltype(varnames),eltype(orig_vals),eltype(transforms)}()
end
# Convert `vals` into a vector of vectors.
vals_vecs = map(tovec, orig_vals)
transforms = map(
(t, val) -> _compose_no_identity(t, from_vec_transform(val)), transforms, orig_vals
)
# TODO: Is this really the way to do this?
if !(eltype(varnames) <: VarName)
varnames = convert(Vector{VarName}, varnames)
end
# Make `varnames` have as concrete an element type as possible.
varnames = [v for v in varnames]
varname_to_index = Dict{eltype(varnames),Int}(
vn => i for (i, vn) in enumerate(varnames)
)
Expand Down Expand Up @@ -1010,7 +1006,7 @@ function setindex_internal!!(
end
end

function insert_internal!!(vnv::VarNamedVector, val, vn::VarName, transform=nothing)
function insert_internal!!(vnv::VarNamedVector, val::AbstractVector, vn::VarName, transform=nothing)
if transform === nothing
transform = identity
end
Expand All @@ -1019,7 +1015,7 @@ function insert_internal!!(vnv::VarNamedVector, val, vn::VarName, transform=noth
return vnv
end

function update_internal!!(vnv::VarNamedVector, val, vn::VarName, transform=nothing)
function update_internal!!(vnv::VarNamedVector, val::AbstractVector, vn::VarName, transform=nothing)
transform_resolved = transform === nothing ? gettransform(vnv, vn) : transform
vnv = loosen_types!!(vnv, typeof(vn), eltype(val), typeof(transform_resolved))
update_internal!(vnv, val, vn, transform)
Expand Down Expand Up @@ -1544,6 +1540,16 @@ function Base.delete!(vnv::VarNamedVector, vn::VarName)
return vnv
end

"""
delete!!(vnv::VarNamedVector, vn::VarName)

Like `delete!!`, but tightens the element types of the returned `VarNamedVector`.

# See also:
[`tighten_types!!`](@ref)
"""
BangBang.delete!!(vnv::VarNamedVector, vn::VarName) = tighten_types!!(delete!(vnv, vn))

"""
values_as(vnv::VarNamedVector[, T])

Expand Down
2 changes: 1 addition & 1 deletion test/varnamedvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ end
# Empty.
vnv = DynamicPPL.VarNamedVector()
@test isempty(vnv)
@test eltype(vnv) == Real
@test eltype(vnv) == Union{}

# Empty with types.
vnv = DynamicPPL.VarNamedVector{VarName,Float64,typeof(identity)}()
Expand Down