Skip to content

Commit 8afe681

Browse files
committed
separated the getrange version which returns the range of the vecto
representaiton rather than the internal representaiton into `vector_getrange` to make its function explicit
1 parent f500c23 commit 8afe681

File tree

3 files changed

+44
-19
lines changed

3 files changed

+44
-19
lines changed

src/threadsafe.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,10 @@ function BangBang.setindex!!(vi::ThreadSafeVarInfo, vals, vns::AbstractVector{<:
178178
return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, vals, vns)
179179
end
180180

181+
vector_length(vi::ThreadSafeVarInfo) = vector_length(vi.varinfo)
182+
vector_getrange(vi::ThreadSafeVarInfo) = vector_getrange(vi.varinfo)
183+
vector_getranges(vi::ThreadSafeVarInfo) = vector_getranges(vi.varinfo)
184+
181185
function set_retained_vns_del_by_spl!(vi::ThreadSafeVarInfo, spl::Sampler)
182186
return set_retained_vns_del_by_spl!(vi.varinfo, spl)
183187
end

src/varinfo.jl

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,11 @@ end
203203
VarInfo(model::Model, args...) = VarInfo(Random.default_rng(), model, args...)
204204

205205

206+
"""
207+
vector_length(varinfo::VarInfo)
208+
209+
Return the length of the vector representation of `varinfo`.
210+
"""
206211
vector_length(varinfo::VarInfo) = length(varinfo.metadata)
207212
vector_length(varinfo::TypedVarInfo) = sum(length, varinfo.metadata)
208213
vector_length(md::Metadata) = sum(length, md.ranges)
@@ -615,21 +620,6 @@ getidx(md::Metadata, vn::VarName) = md.idcs[vn]
615620
Return the index range of `vn` in the metadata of `vi`.
616621
"""
617622
getrange(vi::VarInfo, vn::VarName) = getrange(getmetadata(vi, vn), vn)
618-
# For `TypedVarInfo` it's more difficult since we need to keep track of the offset.
619-
# TOOD: Should we unroll this using `@generated`?
620-
function getrange(vi::TypedVarInfo, vn::VarName)
621-
offset = 0
622-
for md in values(vi.metadata)
623-
# First, we need to check if `vn` is in `md`.
624-
# In this case, we can just return the corresponding range + offset.
625-
haskey(md, vn) && return getrange(md, vn) .+ offset
626-
# Otherwise, we need to get the cumulative length of the ranges in `md`
627-
# and add it to the offset.
628-
offset += sum(length, md.ranges)
629-
end
630-
# If we reach this point, `vn` is not in `vi.metadata`.
631-
throw(KeyError(vn))
632-
end
633623
getrange(md::Metadata, vn::VarName) = md.ranges[getidx(md, vn)]
634624

635625
"""
@@ -648,8 +638,38 @@ Return the indices of `vns` in the metadata of `vi` corresponding to `vn`.
648638
function getranges(vi::VarInfo, vns::Vector{<:VarName})
649639
return map(Base.Fix1(getrange, vi), vns)
650640
end
651-
# A more efficient version for `TypedVarInfo`.
652-
function getranges(varinfo::DynamicPPL.TypedVarInfo, vns::Vector{<:DynamicPPL.VarName})
641+
642+
"""
643+
vector_getrange(varinfo::VarInfo, varname::VarName)
644+
645+
Return the range corresponding to `varname` in the vector representation of `varinfo`.
646+
"""
647+
vector_getrange(vi::VarInfo, vn::VarName) = getrange(getmetadata(vi, vn), vn)
648+
function vector_getrange(vi::TypedVarInfo, vn::VarName)
649+
offset = 0
650+
for md in values(vi.metadata)
651+
# First, we need to check if `vn` is in `md`.
652+
# In this case, we can just return the corresponding range + offset.
653+
haskey(md, vn) && return vector_getrange(md, vn) .+ offset
654+
# Otherwise, we need to get the cumulative length of the ranges in `md`
655+
# and add it to the offset.
656+
offset += sum(length, md.ranges)
657+
end
658+
# If we reach this point, `vn` is not in `vi.metadata`.
659+
throw(KeyError(vn))
660+
end
661+
vector_getrange(md::Metadata, vn::VarName) = getrange(md, vn)
662+
663+
"""
664+
vector_getranges(varinfo::VarInfo, varnames::Vector{<:VarName})
665+
666+
Return the range corresponding to `varname` in the vector representation of `varinfo`.
667+
"""
668+
function vector_getranges(varinfo::VarInfo, varname::Vector{<:VarName})
669+
return map(Base.Fix1(vector_getrange, varinfo), varname)
670+
end
671+
# Specialized version for `TypedVarInfo`.
672+
function vector_getranges(varinfo::TypedVarInfo, vns::Vector{<:VarName})
653673
# TODO: Does it help if we _don't_ convert to a vector here?
654674
metadatas = collect(values(varinfo.metadata))
655675
# Extract the offsets.
@@ -672,6 +692,7 @@ function getranges(varinfo::DynamicPPL.TypedVarInfo, vns::Vector{<:DynamicPPL.Va
672692
return ranges
673693
end
674694

695+
675696
"""
676697
getdist(vi::VarInfo, vn::VarName)
677698

test/varinfo.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -816,7 +816,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
816816

817817
# NOTE: It is not yet clear if this is something we want from all varinfo types.
818818
# Hence, we only test the `VarInfo` types here.
819-
@testset "getranges for `VarInfo`" begin
819+
@testset "vector_getranges for `VarInfo`" begin
820820
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
821821
vns = DynamicPPL.TestUtils.varnames(model)
822822
nt = DynamicPPL.TestUtils.rand_prior_true(model)
@@ -829,7 +829,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
829829
# Let's just check all the subsets of `vns`.
830830
@testset "$(convert(Vector{Any},vns_subset))" for vns_subset in
831831
combinations(vns)
832-
ranges = DynamicPPL.getranges(varinfo, vns_subset)
832+
ranges = DynamicPPL.vector_getranges(varinfo, vns_subset)
833833
@test length(ranges) == length(vns_subset)
834834
for (r, vn) in zip(ranges, vns_subset)
835835
@test x[r] == DynamicPPL.tovec(varinfo[vn])

0 commit comments

Comments
 (0)