1515Let `md` be an instance of `Metadata`:
1616- `md.vns` is the vector of all `VarName` instances.
1717- `md.idcs` is the dictionary that maps each `VarName` instance to its index in
18- `md.vns`, `md.ranges` `md.dists`, and `md.flags `.
18+ `md.vns`, `md.ranges` `md.dists`, and `md.trans `.
1919- `md.vns[md.idcs[vn]] == vn`.
2020- `md.dists[md.idcs[vn]]` is the distribution of `vn`.
2121- `md.ranges[md.idcs[vn]]` is the index range of `vn` in `md.vals`.
2222- `md.vals[md.ranges[md.idcs[vn]]]` is the vector of values of corresponding to `vn`.
23- - `md.flags ` is a dictionary of true/false flags. `md.flags[flag][md.idcs[vn]]` is the
24- value of `flag ` corresponding to `vn`.
23+ - `md.trans ` is a Bitvector of true/false flags for whether a variable has been transformed.
24+ `md.trans[md.idcs[vn]]` is the value of `trans ` corresponding to `vn`.
2525
2626To make `md::Metadata` type stable, all the `md.vns` must have the same symbol
2727and distribution type. However, one can have a Julia variable, say `x`, that is a
@@ -56,8 +56,7 @@ struct Metadata{
5656 # Vector of distributions correpsonding to `vns`
5757 dists:: TDists # AbstractVector{<:Distribution}
5858
59- # Each `flag` has a `BitVector` `flags[flag]`, where `flags[flag][i]` is the true/false flag value corresonding to `vns[i]`
60- flags:: Dict{String,BitVector}
59+ trans:: BitVector
6160end
6261
6362function Base.:(== )(md1:: Metadata , md2:: Metadata )
@@ -67,7 +66,7 @@ function Base.:(==)(md1::Metadata, md2::Metadata)
6766 md1. ranges == md2. ranges &&
6867 md1. vals == md2. vals &&
6968 md1. dists == md2. dists &&
70- md1. flags == md2. flags
69+ md1. trans == md2. trans
7170 )
7271end
7372
@@ -246,8 +245,8 @@ function typed_varinfo(vi::UntypedVarInfo)
246245 sym_idcs = Dict (a => i for (i, a) in enumerate (sym_vns))
247246 # New dists
248247 sym_dists = getindex .((meta. dists,), inds)
249- # New flags
250- sym_flags = Dict (a => meta. flags[a][ inds] for a in keys (meta . flags))
248+ # New trans
249+ sym_trans = meta. trans[ inds]
251250
252251 # Extract new ranges and vals
253252 _ranges = getindex .((meta. ranges,), inds)
@@ -263,7 +262,7 @@ function typed_varinfo(vi::UntypedVarInfo)
263262
264263 push! (
265264 new_metas,
266- Metadata (sym_idcs, sym_vns, sym_ranges, sym_vals, sym_dists, sym_flags ),
265+ Metadata (sym_idcs, sym_vns, sym_ranges, sym_vals, sym_dists, sym_trans ),
267266 )
268267 end
269268 nt = NamedTuple {syms_tuple} (Tuple (new_metas))
406405end
407406
408407function unflatten_metadata (md:: Metadata , x:: AbstractVector )
409- return Metadata (md. idcs, md. vns, md. ranges, x, md. dists, md. flags )
408+ return Metadata (md. idcs, md. vns, md. ranges, x, md. dists, md. trans )
410409end
411410
412411unflatten_metadata (vnv:: VarNamedVector , x:: AbstractVector ) = unflatten (vnv, x)
@@ -422,16 +421,15 @@ Construct an empty type unstable instance of `Metadata`.
422421"""
423422function Metadata ()
424423 vals = Vector {Real} ()
425- flags = Dict {String,BitVector} ()
426- flags[" trans" ] = BitVector ()
424+ trans = BitVector ()
427425
428426 return Metadata (
429427 Dict {VarName,Int} (),
430428 Vector {VarName} (),
431429 Vector {UnitRange{Int}} (),
432430 vals,
433431 Vector {Distribution} (),
434- flags ,
432+ trans ,
435433 )
436434end
437435
@@ -448,10 +446,7 @@ function empty!(meta::Metadata)
448446 empty! (meta. ranges)
449447 empty! (meta. vals)
450448 empty! (meta. dists)
451- for k in keys (meta. flags)
452- empty! (meta. flags[k])
453- end
454-
449+ empty! (meta. trans)
455450 return meta
456451end
457452
@@ -535,8 +530,8 @@ function subset(metadata::Metadata, vns_given::AbstractVector{VN}) where {VN<:Va
535530 offset = r[end ]
536531 end
537532
538- flags = Dict (k => v [indices_for_vns] for (k, v) in metadata . flags)
539- return Metadata (indices, vns, ranges, vals, metadata. dists[indices_for_vns], flags )
533+ trans = trans [indices_for_vns]
534+ return Metadata (indices, vns, ranges, vals, metadata. dists[indices_for_vns], trans )
540535end
541536
542537function Base. merge (varinfo_left:: VarInfo , varinfo_right:: VarInfo )
@@ -607,11 +602,7 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
607602 ranges = Vector {UnitRange{Int}} ()
608603 vals = T[]
609604 dists = D[]
610- flags = Dict {String,BitVector} ()
611- # Initialize the `flags`.
612- for k in union (keys (metadata_left. flags), keys (metadata_right. flags))
613- flags[k] = BitVector ()
614- end
605+ trans = BitVector ()
615606
616607 # Range offset.
617608 offset = 0
@@ -628,12 +619,10 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
628619 offset = r[end ]
629620 dist = getdist (metadata_for_vn, vn)
630621 push! (dists, dist)
631- for k in keys (flags)
632- push! (flags[k], is_flagged (metadata_for_vn, vn, k))
633- end
622+ push! (trans, is_trans (metadata_for_vn, vn))
634623 end
635624
636- return Metadata (idcs, vns, ranges, vals, dists, flags )
625+ return Metadata (idcs, vns, ranges, vals, dists, trans )
637626end
638627
639628const VarView = Union{Int,UnitRange,Vector{Int}}
@@ -807,12 +796,7 @@ function settrans!!(vi::VarInfo, trans::Bool, vn::VarName)
807796 return vi
808797end
809798function settrans!! (metadata:: Metadata , trans:: Bool , vn:: VarName )
810- if trans
811- set_flag! (metadata, vn, " trans" )
812- else
813- unset_flag! (metadata, vn, " trans" )
814- end
815-
799+ metadata. trans[getidx (metadata, vn)] = trans
816800 return metadata
817801end
818802
@@ -870,25 +854,6 @@ all_varnames_grouped_by_symbol(vi::NTVarInfo) = all_varnames_grouped_by_symbol(v
870854 return expr
871855end
872856
873- # TODO (mhauru) These set_flag! methods return the VarInfo. They should probably be called
874- # set_flag!!.
875- """
876- set_flag!(vi::VarInfo, vn::VarName, flag::String)
877-
878- Set `vn`'s value for `flag` to `true` in `vi`.
879- """
880- function set_flag! (vi:: VarInfo , vn:: VarName , flag:: String )
881- set_flag! (getmetadata (vi, vn), vn, flag)
882- return vi
883- end
884- function set_flag! (md:: Metadata , vn:: VarName , flag:: String )
885- return md. flags[flag][getidx (md, vn)] = true
886- end
887-
888- function set_flag! (vnv:: VarNamedVector , :: VarName , flag:: String )
889- throw (ErrorException (" VarNamedVector does not support flags; Tried to set $(flag) ." ))
890- end
891-
892857# ###
893858# ### APIs for typed and untyped VarInfo
894859# ###
@@ -927,7 +892,7 @@ Base.keys(vi::NTVarInfo{<:NamedTuple{()}}) = VarName[]
927892end
928893
929894istrans (vi:: VarInfo , vn:: VarName ) = istrans (getmetadata (vi, vn), vn)
930- istrans (md:: Metadata , vn:: VarName ) = is_flagged (md, vn, " trans " )
895+ istrans (md:: Metadata , vn:: VarName ) = md . trans[ getidx (md, vn)]
931896
932897getaccs (vi:: VarInfo ) = vi. accs
933898setaccs!! (vi:: VarInfo , accs:: AccumulatorTuple ) = Accessors. @set vi. accs = accs
@@ -1300,7 +1265,7 @@ function _link_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_
13001265 ranges_new,
13011266 reduce (vcat, vals_new),
13021267 metadata. dists,
1303- metadata. flags ,
1268+ metadata. trans ,
13041269 ),
13051270 cumulative_logjac
13061271end
@@ -1475,7 +1440,7 @@ function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, targ
14751440 ranges_new,
14761441 reduce (vcat, vals_new),
14771442 metadata. dists,
1478- metadata. flags ,
1443+ metadata. trans ,
14791444 ),
14801445 cumulative_inv_logjac
14811446end
@@ -1624,7 +1589,7 @@ function Base.show(io::IO, ::MIME"text/plain", vi::UntypedVarInfo)
16241589 for accname in acckeys (vi)
16251590 push! (lines, (string (accname), getacc (vi, Val (accname))))
16261591 end
1627- push! (lines, (" flags " , vi. metadata. flags ))
1592+ push! (lines, (" trans " , vi. metadata. trans ))
16281593 max_name_length = maximum (map (length ∘ first, lines))
16291594 fmt = Printf. Format (" %-$(max_name_length) s" )
16301595 vi_str = (
@@ -1738,7 +1703,7 @@ function Base.push!(meta::Metadata, vn, r, dist)
17381703 push! (meta. ranges, (l + 1 ): (l + n))
17391704 append! (meta. vals, val)
17401705 push! (meta. dists, dist)
1741- push! (meta. flags[ " trans" ] , false )
1706+ push! (meta. trans, false )
17421707 return meta
17431708end
17441709
@@ -1751,39 +1716,6 @@ end
17511716# Rand & replaying method for VarInfo #
17521717# ######################################
17531718
1754- """
1755- is_flagged(vi::VarInfo, vn::VarName, flag::String)
1756-
1757- Check whether `vn` has a true value for `flag` in `vi`.
1758- """
1759- function is_flagged (vi:: VarInfo , vn:: VarName , flag:: String )
1760- return is_flagged (getmetadata (vi, vn), vn, flag)
1761- end
1762- function is_flagged (metadata:: Metadata , vn:: VarName , flag:: String )
1763- return metadata. flags[flag][getidx (metadata, vn)]
1764- end
1765- function is_flagged (:: VarNamedVector , :: VarName , flag:: String )
1766- throw (ErrorException (" VarNamedVector does not support flags; Tried to read $(flag) ." ))
1767- end
1768-
1769- """
1770- unset_flag!(vi::VarInfo, vn::VarName, flag::String
1771-
1772- Set `vn`'s value for `flag` to `false` in `vi`.
1773- """
1774- function unset_flag! (vi:: VarInfo , vn:: VarName , flag:: String )
1775- unset_flag! (getmetadata (vi, vn), vn, flag)
1776- return vi
1777- end
1778- function unset_flag! (metadata:: Metadata , vn:: VarName , flag:: String )
1779- metadata. flags[flag][getidx (metadata, vn)] = false
1780- return metadata
1781- end
1782-
1783- function unset_flag! (vnv:: VarNamedVector , :: VarName , flag:: String )
1784- throw (ErrorException (" VarNamedVector does not support flags; Tried to unset $(flag) ." ))
1785- end
1786-
17871719# TODO : Maybe rename or something?
17881720"""
17891721 _apply!(kernel!, vi::VarInfo, values, keys)
0 commit comments