@@ -208,6 +208,15 @@ function VarInfo(
208
208
end
209
209
VarInfo (model:: Model , args... ) = VarInfo (Random. default_rng (), model, args... )
210
210
211
+ """
212
+ vector_length(varinfo::VarInfo)
213
+
214
+ Return the length of the vector representation of `varinfo`.
215
+ """
216
+ vector_length (varinfo:: VarInfo ) = length (varinfo. metadata)
217
+ vector_length (varinfo:: TypedVarInfo ) = sum (length, varinfo. metadata)
218
+ vector_length (md:: Metadata ) = sum (length, md. ranges)
219
+
211
220
unflatten (vi:: VarInfo , x:: AbstractVector ) = unflatten (vi, SampleFromPrior (), x)
212
221
213
222
# TODO : deprecate.
@@ -632,7 +641,72 @@ setrange!(md::Metadata, vn::VarName, range) = md.ranges[getidx(md, vn)] = range
632
641
Return the indices of `vns` in the metadata of `vi` corresponding to `vn`.
633
642
"""
634
643
function getranges (vi:: VarInfo , vns:: Vector{<:VarName} )
635
- return mapreduce (vn -> getrange (vi, vn), vcat, vns; init= Int[])
644
+ return map (Base. Fix1 (getrange, vi), vns)
645
+ end
646
+
647
+ """
648
+ vector_getrange(varinfo::VarInfo, varname::VarName)
649
+
650
+ Return the range corresponding to `varname` in the vector representation of `varinfo`.
651
+ """
652
+ vector_getrange (vi:: VarInfo , vn:: VarName ) = getrange (vi. metadata, vn)
653
+ function vector_getrange (vi:: TypedVarInfo , vn:: VarName )
654
+ offset = 0
655
+ for md in values (vi. metadata)
656
+ # First, we need to check if `vn` is in `md`.
657
+ # In this case, we can just return the corresponding range + offset.
658
+ haskey (md, vn) && return getrange (md, vn) .+ offset
659
+ # Otherwise, we need to get the cumulative length of the ranges in `md`
660
+ # and add it to the offset.
661
+ offset += sum (length, md. ranges)
662
+ end
663
+ # If we reach this point, `vn` is not in `vi.metadata`.
664
+ throw (KeyError (vn))
665
+ end
666
+
667
+ """
668
+ vector_getranges(varinfo::VarInfo, varnames::Vector{<:VarName})
669
+
670
+ Return the range corresponding to `varname` in the vector representation of `varinfo`.
671
+ """
672
+ function vector_getranges (varinfo:: VarInfo , varname:: Vector{<:VarName} )
673
+ return map (Base. Fix1 (vector_getrange, varinfo), varname)
674
+ end
675
+ # Specialized version for `TypedVarInfo`.
676
+ function vector_getranges (varinfo:: TypedVarInfo , vns:: Vector{<:VarName} )
677
+ # TODO : Does it help if we _don't_ convert to a vector here?
678
+ metadatas = collect (values (varinfo. metadata))
679
+ # Extract the offsets.
680
+ offsets = cumsum (map (vector_length, metadatas))
681
+ # Extract the ranges from each metadata.
682
+ ranges = Vector {UnitRange{Int}} (undef, length (vns))
683
+ # Need to keep track of which ones we've seen.
684
+ not_seen = fill (true , length (vns))
685
+ for (i, metadata) in enumerate (metadatas)
686
+ vns_metadata = filter (Base. Fix1 (haskey, metadata), vns)
687
+ # If none of the variables exist in the metadata, we return an empty array.
688
+ isempty (vns_metadata) && continue
689
+ # Otherwise, we extract the ranges.
690
+ offset = i == 1 ? 0 : offsets[i - 1 ]
691
+ for vn in vns_metadata
692
+ r_vn = getrange (metadata, vn)
693
+ # Get the index, so we return in the same order as `vns`.
694
+ # NOTE: There might be duplicates in `vns`, so we need to handle that.
695
+ indices = findall (== (vn), vns)
696
+ for idx in indices
697
+ not_seen[idx] = false
698
+ ranges[idx] = r_vn .+ offset
699
+ end
700
+ end
701
+ end
702
+ # Raise key error if any of the variables were not found.
703
+ if any (not_seen)
704
+ inds = findall (not_seen)
705
+ # Just use a `convert` to get the same type as the input; don't want to confuse by overly
706
+ # specilizing the types in the error message.
707
+ throw (KeyError (convert (typeof (vns), vns[inds])))
708
+ end
709
+ return ranges
636
710
end
637
711
638
712
"""
@@ -1320,13 +1394,13 @@ end
1320
1394
1321
1395
function _inner_transform! (md:: Metadata , vi:: VarInfo , vn:: VarName , f)
1322
1396
# TODO : Use inplace versions to avoid allocations
1323
- yvec, logjac = with_logabsdet_jacobian (f, getindex_internal (vi , vn))
1397
+ yvec, logjac = with_logabsdet_jacobian (f, getindex_internal (md , vn))
1324
1398
# Determine the new range.
1325
- start = first (getrange (vi , vn))
1399
+ start = first (getrange (md , vn))
1326
1400
# NOTE: `length(yvec)` should never be longer than `getrange(vi, vn)`.
1327
- setrange! (vi , vn, start: (start + length (yvec) - 1 ))
1401
+ setrange! (md , vn, start: (start + length (yvec) - 1 ))
1328
1402
# Set the new value.
1329
- setval! (vi , yvec, vn)
1403
+ setval! (md , yvec, vn)
1330
1404
acclogp!! (vi, - logjac)
1331
1405
return vi
1332
1406
end
0 commit comments