Skip to content

Commit 1542c56

Browse files
authored
VNV type stability improvements and tests (#1102)
* Use tighter element types in VNV * Add type tightness tests for VNV * Fix some uses of OrderedDict in VNV tests * Improvements to VNV loosen/tighten types * Run formatter * Don't recontiguify VNVs unnecessarily * contiguify VNV after (inv)linking * Negate with !, not ~
1 parent 8c34394 commit 1542c56

File tree

7 files changed

+137
-16
lines changed

7 files changed

+137
-16
lines changed

src/varinfo.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1297,6 +1297,10 @@ function _link_metadata!!(
12971297
metadata = setindex_internal!!(metadata, val_new, vn, transform_from_linked)
12981298
set_transformed!(metadata, true, vn)
12991299
end
1300+
# Linking can often change the sizes of variables, causing inactive elements. We don't
1301+
# want to keep them around, since typically linking is done once and then the VarInfo
1302+
# is evaluated multiple times. Hence we contiguify here.
1303+
metadata = contiguify!(metadata)
13001304
return metadata, cumulative_logjac
13011305
end
13021306

@@ -1465,6 +1469,10 @@ function _invlink_metadata!!(
14651469
metadata = setindex_internal!!(metadata, tovec(new_val), vn, new_transform)
14661470
set_transformed!(metadata, false, vn)
14671471
end
1472+
# Linking can often change the sizes of variables, causing inactive elements. We don't
1473+
# want to keep them around, since typically linking is done once and then the VarInfo
1474+
# is evaluated multiple times. Hence we contiguify here.
1475+
metadata = contiguify!(metadata)
14681476
return metadata, cumulative_inv_logjac
14691477
end
14701478

src/varnamedvector.jl

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -341,10 +341,13 @@ function ==(vnv_left::VarNamedVector, vnv_right::VarNamedVector)
341341
vnv_left.num_inactive == vnv_right.num_inactive
342342
end
343343

344-
function is_concretely_typed(vnv::VarNamedVector)
345-
return isconcretetype(eltype(vnv.varnames)) &&
346-
isconcretetype(eltype(vnv.vals)) &&
347-
isconcretetype(eltype(vnv.transforms))
344+
function is_tightly_typed(vnv::VarNamedVector)
345+
k = eltype(vnv.varnames)
346+
v = eltype(vnv.vals)
347+
t = eltype(vnv.transforms)
348+
return (isconcretetype(k) || k === Union{}) &&
349+
(isconcretetype(v) || v === Union{}) &&
350+
(isconcretetype(t) || t === Union{})
348351
end
349352

350353
getidx(vnv::VarNamedVector, vn::VarName) = vnv.varname_to_index[vn]
@@ -880,7 +883,16 @@ function loosen_types!!(
880883
return if vn_type == K && val_type == V && transform_type == T
881884
vnv
882885
elseif isempty(vnv)
883-
VarNamedVector(vn_type[], val_type[], transform_type[])
886+
VarNamedVector(
887+
Dict{vn_type,Int}(),
888+
Vector{vn_type}(),
889+
UnitRange{Int}[],
890+
Vector{val_type}(),
891+
Vector{transform_type}(),
892+
BitVector(),
893+
Dict{Int,Int}();
894+
check_consistency=false,
895+
)
884896
else
885897
# TODO(mhauru) We allow a `vnv` to have any AbstractVector type as its vals, but
886898
# then here always revert to Vector.
@@ -944,7 +956,7 @@ julia> vnv_tight.transforms
944956
```
945957
"""
946958
function tighten_types!!(vnv::VarNamedVector)
947-
return if is_concretely_typed(vnv)
959+
return if is_tightly_typed(vnv)
948960
# There can not be anything to tighten, so short-circuit.
949961
vnv
950962
elseif isempty(vnv)
@@ -1020,6 +1032,7 @@ function insert_internal!!(
10201032
end
10211033
vnv = loosen_types!!(vnv, typeof(vn), eltype(val), typeof(transform))
10221034
insert_internal!(vnv, val, vn, transform)
1035+
vnv = tighten_types!!(vnv)
10231036
return vnv
10241037
end
10251038

@@ -1029,6 +1042,7 @@ function update_internal!!(
10291042
transform_resolved = transform === nothing ? gettransform(vnv, vn) : transform
10301043
vnv = loosen_types!!(vnv, typeof(vn), eltype(val), typeof(transform_resolved))
10311044
update_internal!(vnv, val, vn, transform)
1045+
vnv = tighten_types!!(vnv)
10321046
return vnv
10331047
end
10341048

@@ -1104,6 +1118,9 @@ care about them.
11041118
11051119
This is in a sense the reverse operation of `vnv[:]`.
11061120
1121+
The return value may share memory with the input `vnv`, and thus one can not be mutated
1122+
safely without affecting the other.
1123+
11071124
Unflatten recontiguifies the internal storage, getting rid of any inactive entries.
11081125
11091126
# Examples
@@ -1125,15 +1142,20 @@ function unflatten(vnv::VarNamedVector, vals::AbstractVector)
11251142
),
11261143
)
11271144
end
1128-
new_ranges = deepcopy(vnv.ranges)
1129-
recontiguify_ranges!(new_ranges)
1145+
new_ranges = vnv.ranges
1146+
num_inactive = vnv.num_inactive
1147+
if has_inactive(vnv)
1148+
new_ranges = recontiguify_ranges!(new_ranges)
1149+
num_inactive = Dict{Int,Int}()
1150+
end
11301151
return VarNamedVector(
11311152
vnv.varname_to_index,
11321153
vnv.varnames,
11331154
new_ranges,
11341155
vals,
11351156
vnv.transforms,
1136-
vnv.is_unconstrained;
1157+
vnv.is_unconstrained,
1158+
num_inactive;
11371159
check_consistency=false,
11381160
)
11391161
end
@@ -1428,6 +1450,9 @@ julia> vnv[@varname(x)] # All the values are still there.
14281450
```
14291451
"""
14301452
function contiguify!(vnv::VarNamedVector)
1453+
if !has_inactive(vnv)
1454+
return vnv
1455+
end
14311456
# Extract the re-contiguified values.
14321457
# NOTE: We need to do this before we update the ranges.
14331458
old_vals = copy(vnv.vals)

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
44
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
55
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
66
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
7+
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
78
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
89
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
910
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
@@ -34,6 +35,7 @@ AbstractMCMC = "5"
3435
AbstractPPL = "0.13"
3536
Accessors = "0.1"
3637
Aqua = "0.8"
38+
BangBang = "0.4"
3739
Bijectors = "0.15.1"
3840
Combinatorics = "1"
3941
DifferentiationInterface = "0.6.41, 0.7"

test/accumulators.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ using DynamicPPL:
117117
@test at_all64[:LogLikelihood] == ll_f64
118118

119119
@test haskey(AccumulatorTuple(lp_f64), Val(:LogPrior))
120-
@test ~haskey(AccumulatorTuple(lp_f64), Val(:LogLikelihood))
120+
@test !haskey(AccumulatorTuple(lp_f64), Val(:LogLikelihood))
121121
@test length(AccumulatorTuple(lp_f64, ll_f64)) == 2
122122
@test keys(at_all64) == (:LogPrior, :LogLikelihood)
123123
@test collect(at_all64) == [lp_f64, ll_f64]

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using ADTypes
33
using DynamicPPL
44
using AbstractMCMC
55
using AbstractPPL
6+
using BangBang: delete!!, setindex!!
67
using Bijectors
78
using DifferentiationInterface
89
using Distributions

test/varinfo.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,10 @@ end
7171
r = rand(dist)
7272

7373
@test isempty(vi)
74-
@test ~haskey(vi, vn)
74+
@test !haskey(vi, vn)
7575
@test !(vn in keys(vi))
7676
vi = push!!(vi, vn, r, dist)
77-
@test ~isempty(vi)
77+
@test !isempty(vi)
7878
@test haskey(vi, vn)
7979
@test vn in keys(vi)
8080

@@ -95,7 +95,7 @@ end
9595
vi = empty!!(vi)
9696
@test isempty(vi)
9797
vi = push!!(vi, vn, r, dist)
98-
@test ~isempty(vi)
98+
@test !isempty(vi)
9999
end
100100

101101
test_base(VarInfo())

test/varnamedvector.jl

Lines changed: 88 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ function relax_container_types(vnv::DynamicPPL.VarNamedVector, vn::VarName, val)
7979
end
8080
function relax_container_types(vnv::DynamicPPL.VarNamedVector, vns, vals)
8181
if need_varnames_relaxation(vnv, vns, vals)
82-
varname_to_index_new = convert(OrderedDict{VarName,Int}, vnv.varname_to_index)
82+
varname_to_index_new = convert(Dict{VarName,Int}, vnv.varname_to_index)
8383
varnames_new = convert(Vector{VarName}, vnv.varnames)
8484
else
8585
varname_to_index_new = vnv.varname_to_index
@@ -517,7 +517,7 @@ end
517517
@testset "deterministic" begin
518518
n = 5
519519
vn = @varname(x)
520-
vnv = DynamicPPL.VarNamedVector(OrderedDict(vn => [true]))
520+
vnv = DynamicPPL.VarNamedVector(Dict(vn => [true]))
521521
@test !DynamicPPL.has_inactive(vnv)
522522
# Growing should not create inactive ranges.
523523
for i in 1:n
@@ -543,7 +543,7 @@ end
543543
@testset "random" begin
544544
n = 5
545545
vn = @varname(x)
546-
vnv = DynamicPPL.VarNamedVector(OrderedDict(vn => [true]))
546+
vnv = DynamicPPL.VarNamedVector(Dict(vn => [true]))
547547
@test !DynamicPPL.has_inactive(vnv)
548548

549549
# Insert a bunch of random-length vectors.
@@ -579,6 +579,91 @@ end
579579
@test is_transformed(vnv, @varname(t[1]))
580580
@test subset(vnv, vns) == vnv
581581
end
582+
583+
@testset "loosen and tighten types" begin
584+
"""
585+
test_tightenability(vnv::VarNamedVector)
586+
587+
Test that tighten_types!! is a no-op on `vnv`.
588+
"""
589+
function test_tightenability(vnv::DynamicPPL.VarNamedVector)
590+
@test vnv == DynamicPPL.tighten_types!!(deepcopy(vnv))
591+
# TODO(mhauru) We would like to check something more stringent here, namely that
592+
# the operation is compiled to a direct no-op, with no instructions at all. I
593+
# don't know how to do that though, so for now we just check that it doesn't
594+
# allocate.
595+
@allocations(DynamicPPL.tighten_types!!(vnv)) == 0
596+
return nothing
597+
end
598+
599+
vn = @varname(a[1])
600+
# Test that tighten_types!! is a no-op on an empty VarNamedVector.
601+
vnv = DynamicPPL.VarNamedVector()
602+
@test DynamicPPL.is_tightly_typed(vnv)
603+
test_tightenability(vnv)
604+
# Also check that it literally returns the same object, and both tighten and loosen
605+
# are type stable.
606+
@test vnv === DynamicPPL.tighten_types!!(vnv)
607+
@inferred DynamicPPL.tighten_types!!(vnv)
608+
@inferred DynamicPPL.loosen_types!!(vnv, VarName, Any, Any)
609+
# Likewise for a VarNamedVector with something pushed into it.
610+
vnv = DynamicPPL.VarNamedVector()
611+
vnv = setindex!!(vnv, 1.0, vn)
612+
@test DynamicPPL.is_tightly_typed(vnv)
613+
test_tightenability(vnv)
614+
@test vnv === DynamicPPL.tighten_types!!(vnv)
615+
@inferred DynamicPPL.tighten_types!!(vnv)
616+
@inferred DynamicPPL.loosen_types!!(vnv, VarName, Any, Any)
617+
# Likewise for a VarNamedVector with abstract element-types, when that is needed for
618+
# the current contents because mixed types have been pushed into it. However, this
619+
# time, since the types are only as tight as they can be, but not actually concrete,
620+
# tighten_types!! can't be type stable.
621+
vnv = DynamicPPL.VarNamedVector()
622+
vnv = setindex!!(vnv, 1.0, vn)
623+
vnv = setindex!!(vnv, 2, @varname(b))
624+
@test !DynamicPPL.is_tightly_typed(vnv)
625+
test_tightenability(vnv)
626+
@inferred DynamicPPL.loosen_types!!(vnv, VarName, Any, Any)
627+
# Likewise when first mixed types are pushed, but then deleted.
628+
vnv = DynamicPPL.VarNamedVector()
629+
vnv = setindex!!(vnv, 1.0, vn)
630+
vnv = setindex!!(vnv, 2, @varname(b))
631+
@test !DynamicPPL.is_tightly_typed(vnv)
632+
vnv = delete!!(vnv, vn)
633+
@test DynamicPPL.is_tightly_typed(vnv)
634+
test_tightenability(vnv)
635+
@test vnv === DynamicPPL.tighten_types!!(vnv)
636+
@inferred DynamicPPL.tighten_types!!(vnv)
637+
@inferred DynamicPPL.loosen_types!!(vnv, VarName, Any, Any)
638+
639+
# Test that loosen_types!! does really loosen them and that tighten_types!! reverts
640+
# that.
641+
vnv = DynamicPPL.VarNamedVector()
642+
vnv = setindex!!(vnv, 1.0, vn)
643+
@test DynamicPPL.is_tightly_typed(vnv)
644+
k = eltype(vnv.varnames)
645+
e = eltype(vnv.vals)
646+
t = eltype(vnv.transforms)
647+
# Loosen key type.
648+
vnv = @inferred DynamicPPL.loosen_types!!(vnv, VarName, e, t)
649+
@test !DynamicPPL.is_tightly_typed(vnv)
650+
vnv = DynamicPPL.tighten_types!!(vnv)
651+
@test DynamicPPL.is_tightly_typed(vnv)
652+
# Loosen element type
653+
vnv = @inferred DynamicPPL.loosen_types!!(vnv, k, Real, t)
654+
@test !DynamicPPL.is_tightly_typed(vnv)
655+
vnv = DynamicPPL.tighten_types!!(vnv)
656+
@test DynamicPPL.is_tightly_typed(vnv)
657+
# Loosen transformation type
658+
vnv = @inferred DynamicPPL.loosen_types!!(vnv, k, e, Function)
659+
@test !DynamicPPL.is_tightly_typed(vnv)
660+
vnv = DynamicPPL.tighten_types!!(vnv)
661+
@test DynamicPPL.is_tightly_typed(vnv)
662+
# Loosening to the same types as currently should do nothing.
663+
vnv = @inferred DynamicPPL.loosen_types!!(vnv, k, e, t)
664+
@test DynamicPPL.is_tightly_typed(vnv)
665+
@allocations(DynamicPPL.loosen_types!!(vnv, k, e, t)) == 0
666+
end
582667
end
583668

584669
@testset "VarInfo + VarNamedVector" begin

0 commit comments

Comments
 (0)