@@ -101,6 +101,7 @@ struct VarInfo{Tmeta,Tlogp} <: AbstractVarInfo
101
101
logp:: Base.RefValue{Tlogp}
102
102
num_produce:: Base.RefValue{Int}
103
103
end
104
+ const VectorVarInfo = VarInfo{<: VarNameVector }
104
105
const UntypedVarInfo = VarInfo{<: Metadata }
105
106
const TypedVarInfo = VarInfo{<: NamedTuple }
106
107
const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{
@@ -125,6 +126,46 @@ function VarInfo(old_vi::TypedVarInfo, spl, x::AbstractVector)
125
126
)
126
127
end
127
128
129
+ # No-op if we're already working with a `VarNameVector`.
130
+ metadata_to_varnamevector (vnv:: VarNameVector ) = vnv
131
+ function metadata_to_varnamevector (md:: Metadata )
132
+ idcs = copy (md. idcs)
133
+ vns = copy (md. vns)
134
+ ranges = copy (md. ranges)
135
+ vals = copy (md. vals)
136
+ transforms = map (md. dists) do dist
137
+ # TODO : Handle linked distributions.
138
+ from_vec_transform (dist)
139
+ end
140
+
141
+ return VarNameVector (
142
+ OrderedDict {eltype(keys(idcs)),Int} (idcs), vns, ranges, vals, transforms
143
+ )
144
+ end
145
+
146
+ function VectorVarInfo (vi:: UntypedVarInfo )
147
+ md = metadata_to_varnamevector (vi. metadata)
148
+ lp = getlogp (vi)
149
+ return VarInfo (md, Base. RefValue {eltype(lp)} (lp), Ref (get_num_produce (vi)))
150
+ end
151
+
152
+ function VectorVarInfo (vi:: TypedVarInfo )
153
+ md = map (metadata_to_varnamevector, vi. metadata)
154
+ lp = getlogp (vi)
155
+ return VarInfo (md, Base. RefValue {eltype(lp)} (lp), Ref (get_num_produce (vi)))
156
+ end
157
+
158
+ """
159
+ has_varnamevector(varinfo::VarInfo)
160
+
161
+ Returns `true` if `varinfo` uses `VarNameVector` as metadata.
162
+ """
163
+ has_varnamevector (vi) = false
164
+ function has_varnamevector (vi:: VarInfo )
165
+ return vi. metadata isa VarNameVector ||
166
+ (vi isa TypedVarInfo && any (Base. Fix2 (isa, VarNameVector), values (vi. metadata)))
167
+ end
168
+
128
169
function untyped_varinfo (
129
170
rng:: Random.AbstractRNG ,
130
171
model:: Model ,
@@ -321,6 +362,10 @@ function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo)
321
362
)
322
363
end
323
364
365
+ function merge_metadata (vnv_left:: VarNameVector , vnv_right:: VarNameVector )
366
+ return merge (vnv_left, vnv_right)
367
+ end
368
+
324
369
@generated function merge_metadata (
325
370
metadata_left:: NamedTuple{names_left} , metadata_right:: NamedTuple{names_right}
326
371
) where {names_left,names_right}
@@ -513,9 +558,13 @@ Return the distribution from which `vn` was sampled in `vi`.
513
558
"""
514
559
getdist (vi:: VarInfo , vn:: VarName ) = getdist (getmetadata (vi, vn), vn)
515
560
getdist (md:: Metadata , vn:: VarName ) = md. dists[getidx (md, vn)]
561
+ # HACK: we shouldn't need this
562
+ getdist (:: VarNameVector , :: VarName ) = nothing
516
563
517
564
getindex_internal (vi:: VarInfo , vn:: VarName ) = getindex_internal (getmetadata (vi, vn), vn)
518
565
getindex_internal (md:: Metadata , vn:: VarName ) = view (md. vals, getrange (md, vn))
566
+ # HACK: We shouldn't need this
567
+ getindex_internal (vnv:: VarNameVector , vn:: VarName ) = view (vnv. vals, getrange (vnv, vn))
519
568
520
569
function getindex_internal (vi:: VarInfo , vns:: Vector{<:VarName} )
521
570
return mapreduce (Base. Fix1 (getindex_internal, vi), vcat, vns)
535
584
function setval! (md:: Metadata , val, vn:: VarName )
536
585
return md. vals[getrange (md, vn)] = vectorize (getdist (md, vn), val)
537
586
end
587
+ function setval! (vnv:: VarNameVector , val, vn:: VarName )
588
+ return setindex_raw! (vnv, tovec (val), vn)
589
+ end
538
590
539
591
"""
540
592
getall(vi::VarInfo)
@@ -552,6 +604,7 @@ function getall(md::Metadata)
552
604
Base. Fix1 (getindex_internal, md), vcat, md. vns; init= similar (md. vals, 0 )
553
605
)
554
606
end
607
+ getall (vnv:: VarNameVector ) = vnv. vals
555
608
556
609
"""
557
610
setall!(vi::VarInfo, val)
@@ -567,6 +620,12 @@ function _setall!(metadata::Metadata, val)
567
620
metadata. vals[r] .= val[r]
568
621
end
569
622
end
623
+ function _setall! (vnv:: VarNameVector , val)
624
+ # TODO : Do something more efficient here.
625
+ for i in 1 : length (vnv)
626
+ vnv[i] = val[i]
627
+ end
628
+ end
570
629
@generated function _setall! (metadata:: NamedTuple{names} , val) where {names}
571
630
expr = Expr (:block )
572
631
start = :(1 )
@@ -599,6 +658,10 @@ function settrans!!(metadata::Metadata, trans::Bool, vn::VarName)
599
658
600
659
return metadata
601
660
end
661
+ function settrans!! (vnv:: VarNameVector , trans:: Bool , vn:: VarName )
662
+ settrans! (vnv, trans, vn)
663
+ return vnv
664
+ end
602
665
603
666
function settrans!! (vi:: VarInfo , trans:: Bool )
604
667
for vn in keys (vi)
940
1003
941
1004
# X -> R for all variables associated with given sampler
942
1005
function link!! (t:: DynamicTransformation , vi:: VarInfo , spl:: AbstractSampler , model:: Model )
1006
+ # If we're working with a `VarNameVector`, we always use immutable.
1007
+ has_varnamevector (vi) && return link (t, vi, spl, model)
943
1008
# Call `_link!` instead of `link!` to avoid deprecation warning.
944
1009
_link! (vi, spl)
945
1010
return vi
@@ -1035,6 +1100,8 @@ end
1035
1100
function invlink!! (
1036
1101
t:: DynamicTransformation , vi:: VarInfo , spl:: AbstractSampler , model:: Model
1037
1102
)
1103
+ # If we're working with a `VarNameVector`, we always use immutable.
1104
+ has_varnamevector (vi) && return invlink (t, vi, spl, model)
1038
1105
# Call `_invlink!` instead of `invlink!` to avoid deprecation warning.
1039
1106
_invlink! (vi, spl)
1040
1107
return vi
@@ -1260,6 +1327,66 @@ function _link_metadata!(model::Model, varinfo::VarInfo, metadata::Metadata, tar
1260
1327
)
1261
1328
end
1262
1329
1330
+ function _link_metadata! (
1331
+ model:: Model , varinfo:: VarInfo , metadata:: VarNameVector , target_vns
1332
+ )
1333
+ # HACK: We ignore `target_vns` here.
1334
+ vns = keys (metadata)
1335
+ # Need to extract the priors from the model.
1336
+ dists = extract_priors (model, varinfo)
1337
+
1338
+ is_transformed = copy (metadata. is_transformed)
1339
+
1340
+ # Construct the linking transformations.
1341
+ link_transforms = map (vns) do vn
1342
+ # If `vn` is not part of `target_vns`, the `identity` transformation is used.
1343
+ if (target_vns != = nothing && vn ∉ target_vns)
1344
+ return identity
1345
+ end
1346
+
1347
+ # Otherwise, we derive the transformation from the distribution.
1348
+ is_transformed[getidx (metadata, vn)] = true
1349
+ internal_to_linked_internal_transform (varinfo, vn, dists[vn])
1350
+ end
1351
+ # Compute the transformed values.
1352
+ ys = map (vns, link_transforms) do vn, f
1353
+ # TODO : Do we need to handle scenarios where `vn` is not in `dists`?
1354
+ dist = dists[vn]
1355
+ x = getindex_internal (metadata, vn)
1356
+ y, logjac = with_logabsdet_jacobian (f, x)
1357
+ # Accumulate the log-abs-det jacobian correction.
1358
+ acclogp!! (varinfo, - logjac)
1359
+ # Return the transformed value.
1360
+ return y
1361
+ end
1362
+ # Extract the from-vec transformations.
1363
+ fromvec_transforms = map (from_vec_transform, ys)
1364
+ # Compose the transformations to form a full transformation from
1365
+ # unconstrained vector representation to constrained space.
1366
+ transforms = map (∘ , map (inverse, link_transforms), fromvec_transforms)
1367
+ # Convert to vector representation.
1368
+ yvecs = map (tovec, ys)
1369
+
1370
+ # Determine new ranges.
1371
+ ranges_new = similar (metadata. ranges)
1372
+ offset = 0
1373
+ for (i, v) in enumerate (yvecs)
1374
+ r_start, r_end = offset + 1 , length (v) + offset
1375
+ offset = r_end
1376
+ ranges_new[i] = r_start: r_end
1377
+ end
1378
+
1379
+ # Now we just create a new metadata with the new `vals` and `ranges`.
1380
+ return VarNameVector (
1381
+ metadata. varname_to_index,
1382
+ metadata. varnames,
1383
+ ranges_new,
1384
+ reduce (vcat, yvecs),
1385
+ transforms,
1386
+ is_transformed,
1387
+ )
1388
+ end
1389
+
1263
1390
function invlink (
1264
1391
:: DynamicTransformation , varinfo:: VarInfo , spl:: AbstractSampler , model:: Model
1265
1392
)
@@ -1360,6 +1487,55 @@ function _invlink_metadata!(::Model, varinfo::VarInfo, metadata::Metadata, targe
1360
1487
)
1361
1488
end
1362
1489
1490
+ function _invlink_metadata! (
1491
+ model:: Model , varinfo:: VarInfo , metadata:: VarNameVector , target_vns
1492
+ )
1493
+ # HACK: We ignore `target_vns` here.
1494
+ # TODO : Make use of `update!` to aovid copying values.
1495
+ # => Only need to allocate for transformations.
1496
+
1497
+ vns = keys (metadata)
1498
+ is_transformed = copy (metadata. is_transformed)
1499
+
1500
+ # Compute the transformed values.
1501
+ xs = map (vns) do vn
1502
+ f = gettransform (metadata, vn)
1503
+ y = getindex_internal (metadata, vn)
1504
+ # No need to use `with_reconstruct` as `f` will include this.
1505
+ x, logjac = with_logabsdet_jacobian (f, y)
1506
+ # Accumulate the log-abs-det jacobian correction.
1507
+ acclogp!! (varinfo, - logjac)
1508
+ # Mark as no longer transformed.
1509
+ is_transformed[getidx (metadata, vn)] = false
1510
+ # Return the transformed value.
1511
+ return x
1512
+ end
1513
+ # Compose the transformations to form a full transformation from
1514
+ # unconstrained vector representation to constrained space.
1515
+ transforms = map (from_vec_transform, xs)
1516
+ # Convert to vector representation.
1517
+ xvecs = map (tovec, xs)
1518
+
1519
+ # Determine new ranges.
1520
+ ranges_new = similar (metadata. ranges)
1521
+ offset = 0
1522
+ for (i, v) in enumerate (xvecs)
1523
+ r_start, r_end = offset + 1 , length (v) + offset
1524
+ offset = r_end
1525
+ ranges_new[i] = r_start: r_end
1526
+ end
1527
+
1528
+ # Now we just create a new metadata with the new `vals` and `ranges`.
1529
+ return VarNameVector (
1530
+ metadata. varname_to_index,
1531
+ metadata. varnames,
1532
+ ranges_new,
1533
+ reduce (vcat, xvecs),
1534
+ transforms,
1535
+ is_transformed,
1536
+ )
1537
+ end
1538
+
1363
1539
"""
1364
1540
islinked(vi::VarInfo, spl::Union{Sampler, SampleFromPrior})
1365
1541
@@ -1429,6 +1605,14 @@ function getindex(vi::VarInfo, vn::VarName, dist::Distribution)
1429
1605
val = getindex_internal (vi, vn)
1430
1606
return from_maybe_linked_internal (vi, vn, dist, val)
1431
1607
end
1608
+ # HACK: Allows us to also work with `VarNameVector` where `dist` is not used,
1609
+ # but we instead use a transformation stored with the variable.
1610
+ function getindex (vi:: VarInfo , vn:: VarName , :: Nothing )
1611
+ if ! haskey (vi, vn)
1612
+ throw (KeyError (vn))
1613
+ end
1614
+ return getmetadata (vi, vn)[vn]
1615
+ end
1432
1616
1433
1617
function getindex (vi:: VarInfo , vns:: Vector{<:VarName} )
1434
1618
vals_linked = mapreduce (vcat, vns) do vn
@@ -1621,6 +1805,11 @@ function Base.push!(meta::Metadata, vn, r, dist, gidset, num_produce)
1621
1805
return meta
1622
1806
end
1623
1807
1808
+ function Base. push! (vnv:: VarNameVector , vn, r, dist, gidset, num_produce)
1809
+ f = from_vec_transform (dist)
1810
+ return push! (vnv, vn, r, f)
1811
+ end
1812
+
1624
1813
"""
1625
1814
setorder!(vi::VarInfo, vn::VarName, index::Int)
1626
1815
@@ -1635,6 +1824,7 @@ function setorder!(metadata::Metadata, vn::VarName, index::Int)
1635
1824
metadata. orders[metadata. idcs[vn]] = index
1636
1825
return metadata
1637
1826
end
1827
+ setorder! (vnv:: VarNameVector , :: VarName , :: Int ) = vnv
1638
1828
1639
1829
"""
1640
1830
getorder(vi::VarInfo, vn::VarName)
@@ -1660,6 +1850,8 @@ end
1660
1850
function is_flagged (metadata:: Metadata , vn:: VarName , flag:: String )
1661
1851
return metadata. flags[flag][getidx (metadata, vn)]
1662
1852
end
1853
+ # HACK: This is bad. Should we always return `true` here?
1854
+ is_flagged (:: VarNameVector , :: VarName , flag:: String ) = flag == " del" ? true : false
1663
1855
1664
1856
"""
1665
1857
unset_flag!(vi::VarInfo, vn::VarName, flag::String)
@@ -1674,6 +1866,7 @@ function unset_flag!(metadata::Metadata, vn::VarName, flag::String)
1674
1866
metadata. flags[flag][getidx (metadata, vn)] = false
1675
1867
return metadata
1676
1868
end
1869
+ unset_flag! (vnv:: VarNameVector , :: VarName , :: String ) = vnv
1677
1870
1678
1871
"""
1679
1872
set_retained_vns_del_by_spl!(vi::VarInfo, spl::Sampler)
@@ -2027,6 +2220,8 @@ function values_from_metadata(md::Metadata)
2027
2220
)
2028
2221
end
2029
2222
2223
+ values_from_metadata (md:: VarNameVector ) = pairs (md)
2224
+
2030
2225
# Transforming from internal representation to distribution representation.
2031
2226
# Without `dist` argument: base on `dist` extracted from self.
2032
2227
function from_internal_transform (vi:: VarInfo , vn:: VarName )
@@ -2035,11 +2230,17 @@ end
2035
2230
function from_internal_transform (md:: Metadata , vn:: VarName )
2036
2231
return from_internal_transform (md, vn, getdist (md, vn))
2037
2232
end
2233
+ function from_internal_transform (md:: VarNameVector , vn:: VarName )
2234
+ return gettransform (md, vn)
2235
+ end
2038
2236
# With both `vn` and `dist` arguments: base on provided `dist`.
2039
2237
function from_internal_transform (vi:: VarInfo , vn:: VarName , dist)
2040
2238
return from_internal_transform (getmetadata (vi, vn), vn, dist)
2041
2239
end
2042
2240
from_internal_transform (:: Metadata , :: VarName , dist) = from_vec_transform (dist)
2241
+ function from_internal_transform (:: VarNameVector , :: VarName , dist)
2242
+ return from_vec_transform (dist)
2243
+ end
2043
2244
2044
2245
# Without `dist` argument: base on `dist` extracted from self.
2045
2246
function from_linked_internal_transform (vi:: VarInfo , vn:: VarName )
@@ -2048,6 +2249,9 @@ end
2048
2249
function from_linked_internal_transform (md:: Metadata , vn:: VarName )
2049
2250
return from_linked_internal_transform (md, vn, getdist (md, vn))
2050
2251
end
2252
+ function from_linked_internal_transform (md:: VarNameVector , vn:: VarName )
2253
+ return gettransform (md, vn)
2254
+ end
2051
2255
# With both `vn` and `dist` arguments: base on provided `dist`.
2052
2256
function from_linked_internal_transform (vi:: VarInfo , vn:: VarName , dist)
2053
2257
# Dispatch to metadata in case this alters the behavior.
@@ -2056,3 +2260,6 @@ end
2056
2260
function from_linked_internal_transform (:: Metadata , :: VarName , dist)
2057
2261
return from_linked_vec_transform (dist)
2058
2262
end
2263
+ function from_linked_internal_transform (:: VarNameVector , :: VarName , dist)
2264
+ return from_linked_vec_transform (dist)
2265
+ end
0 commit comments