203
203
VarInfo (model:: Model , args... ) = VarInfo (Random. default_rng (), model, args... )
204
204
205
205
206
+ """
207
+ vector_length(varinfo::VarInfo)
208
+
209
+ Return the length of the vector representation of `varinfo`.
210
+ """
206
211
vector_length (varinfo:: VarInfo ) = length (varinfo. metadata)
207
212
vector_length (varinfo:: TypedVarInfo ) = sum (length, varinfo. metadata)
208
213
vector_length (md:: Metadata ) = sum (length, md. ranges)
@@ -615,21 +620,6 @@ getidx(md::Metadata, vn::VarName) = md.idcs[vn]
615
620
Return the index range of `vn` in the metadata of `vi`.
616
621
"""
617
622
getrange (vi:: VarInfo , vn:: VarName ) = getrange (getmetadata (vi, vn), vn)
618
- # For `TypedVarInfo` it's more difficult since we need to keep track of the offset.
619
- # TOOD: Should we unroll this using `@generated`?
620
- function getrange (vi:: TypedVarInfo , vn:: VarName )
621
- offset = 0
622
- for md in values (vi. metadata)
623
- # First, we need to check if `vn` is in `md`.
624
- # In this case, we can just return the corresponding range + offset.
625
- haskey (md, vn) && return getrange (md, vn) .+ offset
626
- # Otherwise, we need to get the cumulative length of the ranges in `md`
627
- # and add it to the offset.
628
- offset += sum (length, md. ranges)
629
- end
630
- # If we reach this point, `vn` is not in `vi.metadata`.
631
- throw (KeyError (vn))
632
- end
633
623
getrange (md:: Metadata , vn:: VarName ) = md. ranges[getidx (md, vn)]
634
624
635
625
"""
@@ -648,8 +638,38 @@ Return the indices of `vns` in the metadata of `vi` corresponding to `vn`.
648
638
function getranges (vi:: VarInfo , vns:: Vector{<:VarName} )
649
639
return map (Base. Fix1 (getrange, vi), vns)
650
640
end
651
- # A more efficient version for `TypedVarInfo`.
652
- function getranges (varinfo:: DynamicPPL.TypedVarInfo , vns:: Vector{<:DynamicPPL.VarName} )
641
+
642
+ """
643
+ vector_getrange(varinfo::VarInfo, varname::VarName)
644
+
645
+ Return the range corresponding to `varname` in the vector representation of `varinfo`.
646
+ """
647
+ vector_getrange (vi:: VarInfo , vn:: VarName ) = getrange (getmetadata (vi, vn), vn)
648
+ function vector_getrange (vi:: TypedVarInfo , vn:: VarName )
649
+ offset = 0
650
+ for md in values (vi. metadata)
651
+ # First, we need to check if `vn` is in `md`.
652
+ # In this case, we can just return the corresponding range + offset.
653
+ haskey (md, vn) && return vector_getrange (md, vn) .+ offset
654
+ # Otherwise, we need to get the cumulative length of the ranges in `md`
655
+ # and add it to the offset.
656
+ offset += sum (length, md. ranges)
657
+ end
658
+ # If we reach this point, `vn` is not in `vi.metadata`.
659
+ throw (KeyError (vn))
660
+ end
661
+ vector_getrange (md:: Metadata , vn:: VarName ) = getrange (md, vn)
662
+
663
+ """
664
+ vector_getranges(varinfo::VarInfo, varnames::Vector{<:VarName})
665
+
666
+ Return the range corresponding to `varname` in the vector representation of `varinfo`.
667
+ """
668
+ function vector_getranges (varinfo:: VarInfo , varname:: Vector{<:VarName} )
669
+ return map (Base. Fix1 (vector_getrange, varinfo), varname)
670
+ end
671
+ # Specialized version for `TypedVarInfo`.
672
+ function vector_getranges (varinfo:: TypedVarInfo , vns:: Vector{<:VarName} )
653
673
# TODO : Does it help if we _don't_ convert to a vector here?
654
674
metadatas = collect (values (varinfo. metadata))
655
675
# Extract the offsets.
@@ -672,6 +692,7 @@ function getranges(varinfo::DynamicPPL.TypedVarInfo, vns::Vector{<:DynamicPPL.Va
672
692
return ranges
673
693
end
674
694
695
+
675
696
"""
676
697
getdist(vi::VarInfo, vn::VarName)
677
698
0 commit comments