@@ -202,6 +202,11 @@ function VarInfo(
202
202
end
203
203
VarInfo (model:: Model , args... ) = VarInfo (Random. default_rng (), model, args... )
204
204
205
+
206
+ Base. length (varinfo:: VarInfo ) = length (varinfo. metadata)
207
+ Base. length (varinfo:: TypedVarInfo ) = sum (length, varinfo. metadata)
208
+ Base. length (md:: Metadata ) = sum (length, md. ranges)
209
+
205
210
unflatten (vi:: VarInfo , x:: AbstractVector ) = unflatten (vi, SampleFromPrior (), x)
206
211
207
212
# TODO : deprecate.
@@ -643,6 +648,29 @@ Return the indices of `vns` in the metadata of `vi` corresponding to `vn`.
643
648
function getranges (vi:: VarInfo , vns:: Vector{<:VarName} )
644
649
return mapreduce (Base. Fix1 (getrange, vi), vcat, vns; init= Int[])
645
650
end
651
+ # A more efficient version for `TypedVarInfo`.
652
+ function getranges (varinfo:: DynamicPPL.TypedVarInfo , vns:: Vector{<:DynamicPPL.VarName} )
653
+ # TODO : Does it help if we _don't_ convert to a vector here?
654
+ metadatas = collect (values (varinfo. metadata))
655
+ # Extract the offsets.
656
+ offsets = cumsum (map (length, metadatas))
657
+ # Extract the ranges from each metadata.
658
+ ranges = Vector {UnitRange{Int}} (undef, length (vns))
659
+ for (i, metadata) in enumerate (metadatas)
660
+ vns_metadata = filter (Base. Fix1 (haskey, metadata), vns)
661
+ # If none of the variables exist in the metadata, we return an empty array.
662
+ isempty (vns_metadata) && continue
663
+ # Otherwise, we extract the ranges.
664
+ offset = i == 1 ? 0 : offsets[i - 1 ]
665
+ for vn in vns_metadata
666
+ r_vn = getrange (metadata, vn)
667
+ # Get the index, so we return in the same order as `vns`.
668
+ idx = findfirst (== (vn), vns)
669
+ ranges[idx] = r_vn .+ offset
670
+ end
671
+ end
672
+ return ranges
673
+ end
646
674
647
675
"""
648
676
getdist(vi::VarInfo, vn::VarName)
0 commit comments