Skip to content

Commit 54792f4

Browse files
committed
Revert "temporarily removed VarNameVector completely"
This reverts commit 95dc8e3.
1 parent 3d823ac commit 54792f4

File tree

10 files changed

+1760
-2
lines changed

10 files changed

+1760
-2
lines changed

docs/make.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ makedocs(;
2020
"Home" => "index.md",
2121
"API" => "api.md",
2222
"Tutorials" => ["tutorials/prob-interface.md"],
23-
"Internals" => ["internals/transformations.md"],
23+
"Internals" => ["internals/varinfo.md", "internals/transformations.md"],
2424
],
2525
checkdocs=:exports,
2626
doctest=false,

docs/src/internals/varinfo.md

Lines changed: 309 additions & 0 deletions
Large diffs are not rendered by default.

src/DynamicPPL.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ export AbstractVarInfo,
4444
UntypedVarInfo,
4545
TypedVarInfo,
4646
SimpleVarInfo,
47+
VarNameVector,
4748
push!!,
4849
empty!!,
4950
subset,
@@ -158,6 +159,7 @@ include("sampler.jl")
158159
include("varname.jl")
159160
include("distribution_wrappers.jl")
160161
include("contexts.jl")
162+
include("varnamevector.jl")
161163
include("abstract_varinfo.jl")
162164
include("threadsafe.jl")
163165
include("varinfo.jl")

src/threadsafe.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ function setlogp!!(vi::ThreadSafeVarInfoWithRef, logp)
5555
return ThreadSafeVarInfo(setlogp!!(vi.varinfo, logp), vi.logps)
5656
end
5757

58+
has_varnamevector(vi::DynamicPPL.ThreadSafeVarInfo) = has_varnamevector(vi.varinfo)
59+
5860
function BangBang.push!!(
5961
vi::ThreadSafeVarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector}
6062
)

src/varinfo.jl

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ struct VarInfo{Tmeta,Tlogp} <: AbstractVarInfo
101101
logp::Base.RefValue{Tlogp}
102102
num_produce::Base.RefValue{Int}
103103
end
104+
const VectorVarInfo = VarInfo{<:VarNameVector}
104105
const UntypedVarInfo = VarInfo{<:Metadata}
105106
const TypedVarInfo = VarInfo{<:NamedTuple}
106107
const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{
@@ -125,6 +126,46 @@ function VarInfo(old_vi::TypedVarInfo, spl, x::AbstractVector)
125126
)
126127
end
127128

129+
# No-op if we're already working with a `VarNameVector`.
130+
metadata_to_varnamevector(vnv::VarNameVector) = vnv
131+
function metadata_to_varnamevector(md::Metadata)
132+
idcs = copy(md.idcs)
133+
vns = copy(md.vns)
134+
ranges = copy(md.ranges)
135+
vals = copy(md.vals)
136+
transforms = map(md.dists) do dist
137+
# TODO: Handle linked distributions.
138+
from_vec_transform(dist)
139+
end
140+
141+
return VarNameVector(
142+
OrderedDict{eltype(keys(idcs)),Int}(idcs), vns, ranges, vals, transforms
143+
)
144+
end
145+
146+
function VectorVarInfo(vi::UntypedVarInfo)
147+
md = metadata_to_varnamevector(vi.metadata)
148+
lp = getlogp(vi)
149+
return VarInfo(md, Base.RefValue{eltype(lp)}(lp), Ref(get_num_produce(vi)))
150+
end
151+
152+
function VectorVarInfo(vi::TypedVarInfo)
153+
md = map(metadata_to_varnamevector, vi.metadata)
154+
lp = getlogp(vi)
155+
return VarInfo(md, Base.RefValue{eltype(lp)}(lp), Ref(get_num_produce(vi)))
156+
end
157+
158+
"""
159+
has_varnamevector(varinfo::VarInfo)
160+
161+
Returns `true` if `varinfo` uses `VarNameVector` as metadata.
162+
"""
163+
has_varnamevector(vi) = false
164+
function has_varnamevector(vi::VarInfo)
165+
return vi.metadata isa VarNameVector ||
166+
(vi isa TypedVarInfo && any(Base.Fix2(isa, VarNameVector), values(vi.metadata)))
167+
end
168+
128169
function untyped_varinfo(
129170
rng::Random.AbstractRNG,
130171
model::Model,
@@ -321,6 +362,10 @@ function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo)
321362
)
322363
end
323364

365+
function merge_metadata(vnv_left::VarNameVector, vnv_right::VarNameVector)
366+
return merge(vnv_left, vnv_right)
367+
end
368+
324369
@generated function merge_metadata(
325370
metadata_left::NamedTuple{names_left}, metadata_right::NamedTuple{names_right}
326371
) where {names_left,names_right}
@@ -513,9 +558,13 @@ Return the distribution from which `vn` was sampled in `vi`.
513558
"""
514559
getdist(vi::VarInfo, vn::VarName) = getdist(getmetadata(vi, vn), vn)
515560
getdist(md::Metadata, vn::VarName) = md.dists[getidx(md, vn)]
561+
# HACK: we shouldn't need this
562+
getdist(::VarNameVector, ::VarName) = nothing
516563

517564
getindex_internal(vi::VarInfo, vn::VarName) = getindex_internal(getmetadata(vi, vn), vn)
518565
getindex_internal(md::Metadata, vn::VarName) = view(md.vals, getrange(md, vn))
566+
# HACK: We shouldn't need this
567+
getindex_internal(vnv::VarNameVector, vn::VarName) = view(vnv.vals, getrange(vnv, vn))
519568

520569
function getindex_internal(vi::VarInfo, vns::Vector{<:VarName})
521570
return mapreduce(Base.Fix1(getindex_internal, vi), vcat, vns)
@@ -535,6 +584,9 @@ end
535584
function setval!(md::Metadata, val, vn::VarName)
536585
return md.vals[getrange(md, vn)] = vectorize(getdist(md, vn), val)
537586
end
587+
function setval!(vnv::VarNameVector, val, vn::VarName)
588+
return setindex_raw!(vnv, tovec(val), vn)
589+
end
538590

539591
"""
540592
getall(vi::VarInfo)
@@ -552,6 +604,7 @@ function getall(md::Metadata)
552604
Base.Fix1(getindex_internal, md), vcat, md.vns; init=similar(md.vals, 0)
553605
)
554606
end
607+
getall(vnv::VarNameVector) = vnv.vals
555608

556609
"""
557610
setall!(vi::VarInfo, val)
@@ -567,6 +620,12 @@ function _setall!(metadata::Metadata, val)
567620
metadata.vals[r] .= val[r]
568621
end
569622
end
623+
function _setall!(vnv::VarNameVector, val)
624+
# TODO: Do something more efficient here.
625+
for i in 1:length(vnv)
626+
vnv[i] = val[i]
627+
end
628+
end
570629
@generated function _setall!(metadata::NamedTuple{names}, val) where {names}
571630
expr = Expr(:block)
572631
start = :(1)
@@ -599,6 +658,10 @@ function settrans!!(metadata::Metadata, trans::Bool, vn::VarName)
599658

600659
return metadata
601660
end
661+
function settrans!!(vnv::VarNameVector, trans::Bool, vn::VarName)
662+
settrans!(vnv, trans, vn)
663+
return vnv
664+
end
602665

603666
function settrans!!(vi::VarInfo, trans::Bool)
604667
for vn in keys(vi)
@@ -940,6 +1003,8 @@ end
9401003

9411004
# X -> R for all variables associated with given sampler
9421005
function link!!(t::DynamicTransformation, vi::VarInfo, spl::AbstractSampler, model::Model)
1006+
# If we're working with a `VarNameVector`, we always use immutable.
1007+
has_varnamevector(vi) && return link(t, vi, spl, model)
9431008
# Call `_link!` instead of `link!` to avoid deprecation warning.
9441009
_link!(vi, spl)
9451010
return vi
@@ -1035,6 +1100,8 @@ end
10351100
function invlink!!(
10361101
t::DynamicTransformation, vi::VarInfo, spl::AbstractSampler, model::Model
10371102
)
1103+
# If we're working with a `VarNameVector`, we always use immutable.
1104+
has_varnamevector(vi) && return invlink(t, vi, spl, model)
10381105
# Call `_invlink!` instead of `invlink!` to avoid deprecation warning.
10391106
_invlink!(vi, spl)
10401107
return vi
@@ -1260,6 +1327,66 @@ function _link_metadata!(model::Model, varinfo::VarInfo, metadata::Metadata, tar
12601327
)
12611328
end
12621329

1330+
function _link_metadata!(
1331+
model::Model, varinfo::VarInfo, metadata::VarNameVector, target_vns
1332+
)
1333+
# HACK: We ignore `target_vns` here.
1334+
vns = keys(metadata)
1335+
# Need to extract the priors from the model.
1336+
dists = extract_priors(model, varinfo)
1337+
1338+
is_transformed = copy(metadata.is_transformed)
1339+
1340+
# Construct the linking transformations.
1341+
link_transforms = map(vns) do vn
1342+
# If `vn` is not part of `target_vns`, the `identity` transformation is used.
1343+
if (target_vns !== nothing && vn target_vns)
1344+
return identity
1345+
end
1346+
1347+
# Otherwise, we derive the transformation from the distribution.
1348+
is_transformed[getidx(metadata, vn)] = true
1349+
internal_to_linked_internal_transform(varinfo, vn, dists[vn])
1350+
end
1351+
# Compute the transformed values.
1352+
ys = map(vns, link_transforms) do vn, f
1353+
# TODO: Do we need to handle scenarios where `vn` is not in `dists`?
1354+
dist = dists[vn]
1355+
x = getindex_internal(metadata, vn)
1356+
y, logjac = with_logabsdet_jacobian(f, x)
1357+
# Accumulate the log-abs-det jacobian correction.
1358+
acclogp!!(varinfo, -logjac)
1359+
# Return the transformed value.
1360+
return y
1361+
end
1362+
# Extract the from-vec transformations.
1363+
fromvec_transforms = map(from_vec_transform, ys)
1364+
# Compose the transformations to form a full transformation from
1365+
# unconstrained vector representation to constrained space.
1366+
transforms = map(, map(inverse, link_transforms), fromvec_transforms)
1367+
# Convert to vector representation.
1368+
yvecs = map(tovec, ys)
1369+
1370+
# Determine new ranges.
1371+
ranges_new = similar(metadata.ranges)
1372+
offset = 0
1373+
for (i, v) in enumerate(yvecs)
1374+
r_start, r_end = offset + 1, length(v) + offset
1375+
offset = r_end
1376+
ranges_new[i] = r_start:r_end
1377+
end
1378+
1379+
# Now we just create a new metadata with the new `vals` and `ranges`.
1380+
return VarNameVector(
1381+
metadata.varname_to_index,
1382+
metadata.varnames,
1383+
ranges_new,
1384+
reduce(vcat, yvecs),
1385+
transforms,
1386+
is_transformed,
1387+
)
1388+
end
1389+
12631390
function invlink(
12641391
::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model
12651392
)
@@ -1360,6 +1487,55 @@ function _invlink_metadata!(::Model, varinfo::VarInfo, metadata::Metadata, targe
13601487
)
13611488
end
13621489

1490+
function _invlink_metadata!(
1491+
model::Model, varinfo::VarInfo, metadata::VarNameVector, target_vns
1492+
)
1493+
# HACK: We ignore `target_vns` here.
1494+
# TODO: Make use of `update!` to aovid copying values.
1495+
# => Only need to allocate for transformations.
1496+
1497+
vns = keys(metadata)
1498+
is_transformed = copy(metadata.is_transformed)
1499+
1500+
# Compute the transformed values.
1501+
xs = map(vns) do vn
1502+
f = gettransform(metadata, vn)
1503+
y = getindex_internal(metadata, vn)
1504+
# No need to use `with_reconstruct` as `f` will include this.
1505+
x, logjac = with_logabsdet_jacobian(f, y)
1506+
# Accumulate the log-abs-det jacobian correction.
1507+
acclogp!!(varinfo, -logjac)
1508+
# Mark as no longer transformed.
1509+
is_transformed[getidx(metadata, vn)] = false
1510+
# Return the transformed value.
1511+
return x
1512+
end
1513+
# Compose the transformations to form a full transformation from
1514+
# unconstrained vector representation to constrained space.
1515+
transforms = map(from_vec_transform, xs)
1516+
# Convert to vector representation.
1517+
xvecs = map(tovec, xs)
1518+
1519+
# Determine new ranges.
1520+
ranges_new = similar(metadata.ranges)
1521+
offset = 0
1522+
for (i, v) in enumerate(xvecs)
1523+
r_start, r_end = offset + 1, length(v) + offset
1524+
offset = r_end
1525+
ranges_new[i] = r_start:r_end
1526+
end
1527+
1528+
# Now we just create a new metadata with the new `vals` and `ranges`.
1529+
return VarNameVector(
1530+
metadata.varname_to_index,
1531+
metadata.varnames,
1532+
ranges_new,
1533+
reduce(vcat, xvecs),
1534+
transforms,
1535+
is_transformed,
1536+
)
1537+
end
1538+
13631539
"""
13641540
islinked(vi::VarInfo, spl::Union{Sampler, SampleFromPrior})
13651541
@@ -1429,6 +1605,14 @@ function getindex(vi::VarInfo, vn::VarName, dist::Distribution)
14291605
val = getindex_internal(vi, vn)
14301606
return from_maybe_linked_internal(vi, vn, dist, val)
14311607
end
1608+
# HACK: Allows us to also work with `VarNameVector` where `dist` is not used,
1609+
# but we instead use a transformation stored with the variable.
1610+
function getindex(vi::VarInfo, vn::VarName, ::Nothing)
1611+
if !haskey(vi, vn)
1612+
throw(KeyError(vn))
1613+
end
1614+
return getmetadata(vi, vn)[vn]
1615+
end
14321616

14331617
function getindex(vi::VarInfo, vns::Vector{<:VarName})
14341618
vals_linked = mapreduce(vcat, vns) do vn
@@ -1621,6 +1805,11 @@ function Base.push!(meta::Metadata, vn, r, dist, gidset, num_produce)
16211805
return meta
16221806
end
16231807

1808+
function Base.push!(vnv::VarNameVector, vn, r, dist, gidset, num_produce)
1809+
f = from_vec_transform(dist)
1810+
return push!(vnv, vn, r, f)
1811+
end
1812+
16241813
"""
16251814
setorder!(vi::VarInfo, vn::VarName, index::Int)
16261815
@@ -1635,6 +1824,7 @@ function setorder!(metadata::Metadata, vn::VarName, index::Int)
16351824
metadata.orders[metadata.idcs[vn]] = index
16361825
return metadata
16371826
end
1827+
setorder!(vnv::VarNameVector, ::VarName, ::Int) = vnv
16381828

16391829
"""
16401830
getorder(vi::VarInfo, vn::VarName)
@@ -1660,6 +1850,8 @@ end
16601850
function is_flagged(metadata::Metadata, vn::VarName, flag::String)
16611851
return metadata.flags[flag][getidx(metadata, vn)]
16621852
end
1853+
# HACK: This is bad. Should we always return `true` here?
1854+
is_flagged(::VarNameVector, ::VarName, flag::String) = flag == "del" ? true : false
16631855

16641856
"""
16651857
unset_flag!(vi::VarInfo, vn::VarName, flag::String)
@@ -1674,6 +1866,7 @@ function unset_flag!(metadata::Metadata, vn::VarName, flag::String)
16741866
metadata.flags[flag][getidx(metadata, vn)] = false
16751867
return metadata
16761868
end
1869+
unset_flag!(vnv::VarNameVector, ::VarName, ::String) = vnv
16771870

16781871
"""
16791872
set_retained_vns_del_by_spl!(vi::VarInfo, spl::Sampler)
@@ -2027,6 +2220,8 @@ function values_from_metadata(md::Metadata)
20272220
)
20282221
end
20292222

2223+
values_from_metadata(md::VarNameVector) = pairs(md)
2224+
20302225
# Transforming from internal representation to distribution representation.
20312226
# Without `dist` argument: base on `dist` extracted from self.
20322227
function from_internal_transform(vi::VarInfo, vn::VarName)
@@ -2035,11 +2230,17 @@ end
20352230
function from_internal_transform(md::Metadata, vn::VarName)
20362231
return from_internal_transform(md, vn, getdist(md, vn))
20372232
end
2233+
function from_internal_transform(md::VarNameVector, vn::VarName)
2234+
return gettransform(md, vn)
2235+
end
20382236
# With both `vn` and `dist` arguments: base on provided `dist`.
20392237
function from_internal_transform(vi::VarInfo, vn::VarName, dist)
20402238
return from_internal_transform(getmetadata(vi, vn), vn, dist)
20412239
end
20422240
from_internal_transform(::Metadata, ::VarName, dist) = from_vec_transform(dist)
2241+
function from_internal_transform(::VarNameVector, ::VarName, dist)
2242+
return from_vec_transform(dist)
2243+
end
20432244

20442245
# Without `dist` argument: base on `dist` extracted from self.
20452246
function from_linked_internal_transform(vi::VarInfo, vn::VarName)
@@ -2048,6 +2249,9 @@ end
20482249
function from_linked_internal_transform(md::Metadata, vn::VarName)
20492250
return from_linked_internal_transform(md, vn, getdist(md, vn))
20502251
end
2252+
function from_linked_internal_transform(md::VarNameVector, vn::VarName)
2253+
return gettransform(md, vn)
2254+
end
20512255
# With both `vn` and `dist` arguments: base on provided `dist`.
20522256
function from_linked_internal_transform(vi::VarInfo, vn::VarName, dist)
20532257
# Dispatch to metadata in case this alters the behavior.
@@ -2056,3 +2260,6 @@ end
20562260
function from_linked_internal_transform(::Metadata, ::VarName, dist)
20572261
return from_linked_vec_transform(dist)
20582262
end
2263+
function from_linked_internal_transform(::VarNameVector, ::VarName, dist)
2264+
return from_linked_vec_transform(dist)
2265+
end

0 commit comments

Comments
 (0)