Skip to content

Commit cd78d24

Browse files
committed
added handling of missing indices + tests for these cases
1 parent 2734070 commit cd78d24

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

src/varinfo.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,13 @@ function vector_getranges(varinfo::TypedVarInfo, vns::Vector{<:VarName})
687687
ranges[idx] = r_vn .+ offset
688688
end
689689
end
690+
# Raise key error if any of the variables were not found.
691+
if any(!isassigned, ranges)
692+
inds = findall(!isassigned, ranges)
693+
# Just use a `convert` to get the same type as the input; don't want to confuse by overly
694+
# specilizing the types in the error message.
695+
throw(KeyError(convert(typeof(vns), vns[inds])))
696+
end
690697
return ranges
691698
end
692699

test/varinfo.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -820,9 +820,13 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
820820
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
821821
vns = DynamicPPL.TestUtils.varnames(model)
822822
nt = DynamicPPL.TestUtils.rand_prior_true(model)
823-
varinfos = DynamicPPL.TestUtils.setup_varinfos(model, nt, vns)
823+
varinfos = DynamicPPL.TestUtils.setup_varinfos(
824+
model, nt, vns; include_threadsafe=true
825+
)
824826
# Only keep `VarInfo` types.
825-
varinfos = filter(Base.Fix2(isa, VarInfo), varinfos)
827+
varinfos = filter(
828+
Base.Fix2(isa, DynamicPPL.VarInfoOrThreadSafeVarInfo), varinfos
829+
)
826830
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
827831
x = values_as(varinfo, Vector)
828832

@@ -835,6 +839,16 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
835839
@test x[r] == DynamicPPL.tovec(varinfo[vn])
836840
end
837841
end
842+
843+
# Let's try some failure cases.
844+
@test DynamicPPL.vector_getranges(varinfo, VarName[]) == UnitRange{Int}[]
845+
# Non-existent variables.
846+
@test_throws KeyError DynamicPPL.vector_getranges(
847+
varinfo, [VarName{gensym("vn")}()]
848+
)
849+
@test_throws KeyError DynamicPPL.vector_getranges(
850+
varinfo, [VarName{gensym("vn")}(), VarName{gensym("vn")}()]
851+
)
838852
end
839853
end
840854
end

0 commit comments

Comments
 (0)