Skip to content

Commit 756cc25

Browse files
committed
Replace Medata.flags with Metadata.trans
1 parent c08cfa5 commit 756cc25

File tree

6 files changed

+36
-117
lines changed

6 files changed

+36
-117
lines changed

HISTORY.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ The separation of these functions was primarily implemented to avoid performing
5454

5555
Previously `VarInfo` (or more correctly, the `Metadata` object within a `VarInfo`), had a flag called `"del"` for all variables. If it was set to `true` the variable was to be overwritten with a new value at the next evaluation. The new `InitContext` and related changes above make this flag unnecessary, and it has been removed.
5656

57+
The only other flag, other than `"del"`, that `Metadata` ever used was `"trans"`. Thus the generic functions `set_flag!`, `unset_flag!` and `is_flagged!` have also been removed. One can simply use `istrans` and a newly exported function called `settrans!!` instead.
58+
5759
**Other changes**
5860

5961
### Reimplementation of functions using `InitContext`

docs/src/api.md

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -329,9 +329,8 @@ The [Transformations section below](#Transformations) describes the methods used
329329
In the specific case of `VarInfo`, it keeps track of whether samples have been transformed by setting flags on them, using the following functions.
330330

331331
```@docs
332-
set_flag!
333-
unset_flag!
334-
is_flagged
332+
istrans
333+
settrans!!
335334
```
336335

337336
```@docs
@@ -423,8 +422,6 @@ DynamicPPL.StaticTransformation
423422
```
424423

425424
```@docs
426-
DynamicPPL.istrans
427-
DynamicPPL.settrans!!
428425
DynamicPPL.transformation
429426
DynamicPPL.link
430427
DynamicPPL.invlink

src/DynamicPPL.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,8 @@ export AbstractVarInfo,
7070
acclogjac!!,
7171
acclogprior!!,
7272
accloglikelihood!!,
73-
is_flagged,
74-
set_flag!,
75-
unset_flag!,
7673
istrans,
74+
settrans!!,
7775
link,
7876
link!!,
7977
invlink,

src/threadsafe.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -185,13 +185,6 @@ end
185185
values_as(vi::ThreadSafeVarInfo) = values_as(vi.varinfo)
186186
values_as(vi::ThreadSafeVarInfo, ::Type{T}) where {T} = values_as(vi.varinfo, T)
187187

188-
function unset_flag!(vi::ThreadSafeVarInfo, vn::VarName, flag::String)
189-
return unset_flag!(vi.varinfo, vn, flag)
190-
end
191-
function is_flagged(vi::ThreadSafeVarInfo, vn::VarName, flag::String)
192-
return is_flagged(vi.varinfo, vn, flag)
193-
end
194-
195188
function settrans!!(vi::ThreadSafeVarInfo, trans::Bool, vn::VarName)
196189
return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, trans, vn)
197190
end

src/varinfo.jl

Lines changed: 23 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@ not.
1515
Let `md` be an instance of `Metadata`:
1616
- `md.vns` is the vector of all `VarName` instances.
1717
- `md.idcs` is the dictionary that maps each `VarName` instance to its index in
18-
`md.vns`, `md.ranges` `md.dists`, and `md.flags`.
18+
`md.vns`, `md.ranges` `md.dists`, and `md.trans`.
1919
- `md.vns[md.idcs[vn]] == vn`.
2020
- `md.dists[md.idcs[vn]]` is the distribution of `vn`.
2121
- `md.ranges[md.idcs[vn]]` is the index range of `vn` in `md.vals`.
2222
- `md.vals[md.ranges[md.idcs[vn]]]` is the vector of values of corresponding to `vn`.
23-
- `md.flags` is a dictionary of true/false flags. `md.flags[flag][md.idcs[vn]]` is the
24-
value of `flag` corresponding to `vn`.
23+
- `md.trans` is a Bitvector of true/false flags for whether a variable has been transformed.
24+
`md.trans[md.idcs[vn]]` is the value of `trans` corresponding to `vn`.
2525
2626
To make `md::Metadata` type stable, all the `md.vns` must have the same symbol
2727
and distribution type. However, one can have a Julia variable, say `x`, that is a
@@ -56,8 +56,7 @@ struct Metadata{
5656
# Vector of distributions correpsonding to `vns`
5757
dists::TDists # AbstractVector{<:Distribution}
5858

59-
# Each `flag` has a `BitVector` `flags[flag]`, where `flags[flag][i]` is the true/false flag value corresonding to `vns[i]`
60-
flags::Dict{String,BitVector}
59+
trans::BitVector
6160
end
6261

6362
function Base.:(==)(md1::Metadata, md2::Metadata)
@@ -67,7 +66,7 @@ function Base.:(==)(md1::Metadata, md2::Metadata)
6766
md1.ranges == md2.ranges &&
6867
md1.vals == md2.vals &&
6968
md1.dists == md2.dists &&
70-
md1.flags == md2.flags
69+
md1.trans == md2.trans
7170
)
7271
end
7372

@@ -246,8 +245,8 @@ function typed_varinfo(vi::UntypedVarInfo)
246245
sym_idcs = Dict(a => i for (i, a) in enumerate(sym_vns))
247246
# New dists
248247
sym_dists = getindex.((meta.dists,), inds)
249-
# New flags
250-
sym_flags = Dict(a => meta.flags[a][inds] for a in keys(meta.flags))
248+
# New trans
249+
sym_trans = meta.trans[inds]
251250

252251
# Extract new ranges and vals
253252
_ranges = getindex.((meta.ranges,), inds)
@@ -263,7 +262,7 @@ function typed_varinfo(vi::UntypedVarInfo)
263262

264263
push!(
265264
new_metas,
266-
Metadata(sym_idcs, sym_vns, sym_ranges, sym_vals, sym_dists, sym_flags),
265+
Metadata(sym_idcs, sym_vns, sym_ranges, sym_vals, sym_dists, sym_trans),
267266
)
268267
end
269268
nt = NamedTuple{syms_tuple}(Tuple(new_metas))
@@ -406,7 +405,7 @@ end
406405
end
407406

408407
function unflatten_metadata(md::Metadata, x::AbstractVector)
409-
return Metadata(md.idcs, md.vns, md.ranges, x, md.dists, md.flags)
408+
return Metadata(md.idcs, md.vns, md.ranges, x, md.dists, md.trans)
410409
end
411410

412411
unflatten_metadata(vnv::VarNamedVector, x::AbstractVector) = unflatten(vnv, x)
@@ -422,16 +421,15 @@ Construct an empty type unstable instance of `Metadata`.
422421
"""
423422
function Metadata()
424423
vals = Vector{Real}()
425-
flags = Dict{String,BitVector}()
426-
flags["trans"] = BitVector()
424+
trans = BitVector()
427425

428426
return Metadata(
429427
Dict{VarName,Int}(),
430428
Vector{VarName}(),
431429
Vector{UnitRange{Int}}(),
432430
vals,
433431
Vector{Distribution}(),
434-
flags,
432+
trans,
435433
)
436434
end
437435

@@ -448,10 +446,7 @@ function empty!(meta::Metadata)
448446
empty!(meta.ranges)
449447
empty!(meta.vals)
450448
empty!(meta.dists)
451-
for k in keys(meta.flags)
452-
empty!(meta.flags[k])
453-
end
454-
449+
empty!(meta.trans)
455450
return meta
456451
end
457452

@@ -535,8 +530,8 @@ function subset(metadata::Metadata, vns_given::AbstractVector{VN}) where {VN<:Va
535530
offset = r[end]
536531
end
537532

538-
flags = Dict(k => v[indices_for_vns] for (k, v) in metadata.flags)
539-
return Metadata(indices, vns, ranges, vals, metadata.dists[indices_for_vns], flags)
533+
trans = trans[indices_for_vns]
534+
return Metadata(indices, vns, ranges, vals, metadata.dists[indices_for_vns], trans)
540535
end
541536

542537
function Base.merge(varinfo_left::VarInfo, varinfo_right::VarInfo)
@@ -607,11 +602,7 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
607602
ranges = Vector{UnitRange{Int}}()
608603
vals = T[]
609604
dists = D[]
610-
flags = Dict{String,BitVector}()
611-
# Initialize the `flags`.
612-
for k in union(keys(metadata_left.flags), keys(metadata_right.flags))
613-
flags[k] = BitVector()
614-
end
605+
trans = BitVector()
615606

616607
# Range offset.
617608
offset = 0
@@ -628,12 +619,10 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
628619
offset = r[end]
629620
dist = getdist(metadata_for_vn, vn)
630621
push!(dists, dist)
631-
for k in keys(flags)
632-
push!(flags[k], is_flagged(metadata_for_vn, vn, k))
633-
end
622+
push!(trans, is_trans(metadata_for_vn, vn))
634623
end
635624

636-
return Metadata(idcs, vns, ranges, vals, dists, flags)
625+
return Metadata(idcs, vns, ranges, vals, dists, trans)
637626
end
638627

639628
const VarView = Union{Int,UnitRange,Vector{Int}}
@@ -807,12 +796,7 @@ function settrans!!(vi::VarInfo, trans::Bool, vn::VarName)
807796
return vi
808797
end
809798
function settrans!!(metadata::Metadata, trans::Bool, vn::VarName)
810-
if trans
811-
set_flag!(metadata, vn, "trans")
812-
else
813-
unset_flag!(metadata, vn, "trans")
814-
end
815-
799+
metadata.trans[getidx(metadata, vn)] = trans
816800
return metadata
817801
end
818802

@@ -870,25 +854,6 @@ all_varnames_grouped_by_symbol(vi::NTVarInfo) = all_varnames_grouped_by_symbol(v
870854
return expr
871855
end
872856

873-
# TODO(mhauru) These set_flag! methods return the VarInfo. They should probably be called
874-
# set_flag!!.
875-
"""
876-
set_flag!(vi::VarInfo, vn::VarName, flag::String)
877-
878-
Set `vn`'s value for `flag` to `true` in `vi`.
879-
"""
880-
function set_flag!(vi::VarInfo, vn::VarName, flag::String)
881-
set_flag!(getmetadata(vi, vn), vn, flag)
882-
return vi
883-
end
884-
function set_flag!(md::Metadata, vn::VarName, flag::String)
885-
return md.flags[flag][getidx(md, vn)] = true
886-
end
887-
888-
function set_flag!(vnv::VarNamedVector, ::VarName, flag::String)
889-
throw(ErrorException("VarNamedVector does not support flags; Tried to set $(flag)."))
890-
end
891-
892857
####
893858
#### APIs for typed and untyped VarInfo
894859
####
@@ -927,7 +892,7 @@ Base.keys(vi::NTVarInfo{<:NamedTuple{()}}) = VarName[]
927892
end
928893

929894
istrans(vi::VarInfo, vn::VarName) = istrans(getmetadata(vi, vn), vn)
930-
istrans(md::Metadata, vn::VarName) = is_flagged(md, vn, "trans")
895+
istrans(md::Metadata, vn::VarName) = md.trans[getidx(md, vn)]
931896

932897
getaccs(vi::VarInfo) = vi.accs
933898
setaccs!!(vi::VarInfo, accs::AccumulatorTuple) = Accessors.@set vi.accs = accs
@@ -1300,7 +1265,7 @@ function _link_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_
13001265
ranges_new,
13011266
reduce(vcat, vals_new),
13021267
metadata.dists,
1303-
metadata.flags,
1268+
metadata.trans,
13041269
),
13051270
cumulative_logjac
13061271
end
@@ -1475,7 +1440,7 @@ function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, targ
14751440
ranges_new,
14761441
reduce(vcat, vals_new),
14771442
metadata.dists,
1478-
metadata.flags,
1443+
metadata.trans,
14791444
),
14801445
cumulative_inv_logjac
14811446
end
@@ -1624,7 +1589,7 @@ function Base.show(io::IO, ::MIME"text/plain", vi::UntypedVarInfo)
16241589
for accname in acckeys(vi)
16251590
push!(lines, (string(accname), getacc(vi, Val(accname))))
16261591
end
1627-
push!(lines, ("flags", vi.metadata.flags))
1592+
push!(lines, ("trans", vi.metadata.trans))
16281593
max_name_length = maximum(map(length first, lines))
16291594
fmt = Printf.Format("%-$(max_name_length)s")
16301595
vi_str = (
@@ -1738,7 +1703,7 @@ function Base.push!(meta::Metadata, vn, r, dist)
17381703
push!(meta.ranges, (l + 1):(l + n))
17391704
append!(meta.vals, val)
17401705
push!(meta.dists, dist)
1741-
push!(meta.flags["trans"], false)
1706+
push!(meta.trans, false)
17421707
return meta
17431708
end
17441709

@@ -1751,39 +1716,6 @@ end
17511716
# Rand & replaying method for VarInfo #
17521717
#######################################
17531718

1754-
"""
1755-
is_flagged(vi::VarInfo, vn::VarName, flag::String)
1756-
1757-
Check whether `vn` has a true value for `flag` in `vi`.
1758-
"""
1759-
function is_flagged(vi::VarInfo, vn::VarName, flag::String)
1760-
return is_flagged(getmetadata(vi, vn), vn, flag)
1761-
end
1762-
function is_flagged(metadata::Metadata, vn::VarName, flag::String)
1763-
return metadata.flags[flag][getidx(metadata, vn)]
1764-
end
1765-
function is_flagged(::VarNamedVector, ::VarName, flag::String)
1766-
throw(ErrorException("VarNamedVector does not support flags; Tried to read $(flag)."))
1767-
end
1768-
1769-
"""
1770-
unset_flag!(vi::VarInfo, vn::VarName, flag::String
1771-
1772-
Set `vn`'s value for `flag` to `false` in `vi`.
1773-
"""
1774-
function unset_flag!(vi::VarInfo, vn::VarName, flag::String)
1775-
unset_flag!(getmetadata(vi, vn), vn, flag)
1776-
return vi
1777-
end
1778-
function unset_flag!(metadata::Metadata, vn::VarName, flag::String)
1779-
metadata.flags[flag][getidx(metadata, vn)] = false
1780-
return metadata
1781-
end
1782-
1783-
function unset_flag!(vnv::VarNamedVector, ::VarName, flag::String)
1784-
throw(ErrorException("VarNamedVector does not support flags; Tried to unset $(flag)."))
1785-
end
1786-
17871719
# TODO: Maybe rename or something?
17881720
"""
17891721
_apply!(kernel!, vi::VarInfo, values, keys)

test/varinfo.jl

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,7 @@ end
4848
ind = meta.idcs[vn]
4949
tind = fmeta.idcs[vn]
5050
@test meta.dists[ind] == fmeta.dists[tind]
51-
for flag in keys(meta.flags)
52-
@test meta.flags[flag][ind] == fmeta.flags[flag][tind]
53-
end
51+
@test meta.trans[ind] == fmeta.trans[tind]
5452
range = meta.ranges[ind]
5553
trange = fmeta.ranges[tind]
5654
@test all(meta.vals[range] .== fmeta.vals[trange])
@@ -285,9 +283,8 @@ end
285283
@test all_accs_same(vi, vi_orig)
286284
end
287285

288-
@testset "flags" begin
289-
# Test flag setting:
290-
# is_flagged, set_flag!, unset_flag!
286+
@testset "trans flag" begin
287+
# Test istrans and settrans!!
291288
function test_varinfo!(vi)
292289
vn_x = @varname x
293290
dist = Normal(0, 1)
@@ -296,13 +293,13 @@ end
296293
push!!(vi, vn_x, r, dist)
297294

298295
# trans is set by default
299-
@test !is_flagged(vi, vn_x, "trans")
296+
@test !istrans(vi, vn_x)
300297

301-
set_flag!(vi, vn_x, "trans")
302-
@test is_flagged(vi, vn_x, "trans")
298+
vi = settrans!!(vi, vn_x, true)
299+
@test istrans(vi, vn_x)
303300

304-
unset_flag!(vi, vn_x, "trans")
305-
@test !is_flagged(vi, vn_x, "trans")
301+
vi = settrans!!!(vi, vn_x, false)
302+
@test !istrans(vi, vn_x)
306303
end
307304
vi = VarInfo()
308305
test_varinfo!(vi)

0 commit comments

Comments
 (0)