Skip to content

Commit 23d561f

Browse files
committed
added length implementation for VarInfo and Metadata
1 parent b1b8a00 commit 23d561f

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

src/varinfo.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,11 @@ function VarInfo(
202202
end
203203
VarInfo(model::Model, args...) = VarInfo(Random.default_rng(), model, args...)
204204

205+
206+
Base.length(varinfo::VarInfo) = length(varinfo.metadata)
207+
Base.length(varinfo::TypedVarInfo) = sum(length, varinfo.metadata)
208+
Base.length(md::Metadata) = sum(length, md.ranges)
209+
205210
unflatten(vi::VarInfo, x::AbstractVector) = unflatten(vi, SampleFromPrior(), x)
206211

207212
# TODO: deprecate.
@@ -643,6 +648,29 @@ Return the indices of `vns` in the metadata of `vi` corresponding to `vn`.
643648
function getranges(vi::VarInfo, vns::Vector{<:VarName})
644649
return mapreduce(Base.Fix1(getrange, vi), vcat, vns; init=Int[])
645650
end
651+
# A more efficient version for `TypedVarInfo`.
652+
function getranges(varinfo::DynamicPPL.TypedVarInfo, vns::Vector{<:DynamicPPL.VarName})
653+
# TODO: Does it help if we _don't_ convert to a vector here?
654+
metadatas = collect(values(varinfo.metadata))
655+
# Extract the offsets.
656+
offsets = cumsum(map(length, metadatas))
657+
# Extract the ranges from each metadata.
658+
ranges = Vector{UnitRange{Int}}(undef, length(vns))
659+
for (i, metadata) in enumerate(metadatas)
660+
vns_metadata = filter(Base.Fix1(haskey, metadata), vns)
661+
# If none of the variables exist in the metadata, we return an empty array.
662+
isempty(vns_metadata) && continue
663+
# Otherwise, we extract the ranges.
664+
offset = i == 1 ? 0 : offsets[i - 1]
665+
for vn in vns_metadata
666+
r_vn = getrange(metadata, vn)
667+
# Get the index, so we return in the same order as `vns`.
668+
idx = findfirst(==(vn), vns)
669+
ranges[idx] = r_vn .+ offset
670+
end
671+
end
672+
return ranges
673+
end
646674

647675
"""
648676
getdist(vi::VarInfo, vn::VarName)

0 commit comments

Comments
 (0)