Skip to content

Commit 90aef0b

Browse files
committed
added proper testing for other VarInfo types
1 parent bdcc69f commit 90aef0b

File tree

1 file changed

+18
-10
lines changed

1 file changed

+18
-10
lines changed

test/varinfo.jl

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -814,18 +814,26 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
814814
end
815815
end
816816

817-
@testset "getranges" begin
817+
# NOTE: It is not yet clear if this is something we want from all varinfo types.
818+
# Hence, we only test the `VarInfo` types here.
819+
@testset "getranges for `VarInfo`" begin
818820
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
819821
vns = DynamicPPL.TestUtils.varnames(model)
820-
varinfo = DynamicPPL.typed_varinfo(model)
821-
x = values_as(varinfo, Vector)
822-
823-
# Let's just check all the subsets of `vns`.
824-
@testset "$(convert(Vector{Any},vns_subset))" for vns_subset in combinations(vns)
825-
ranges = DynamicPPL.getranges(varinfo, vns_subset)
826-
@test length(ranges) == length(vns_subset)
827-
for (r, vn) in zip(ranges, vns_subset)
828-
@test x[r] == DynamicPPL.tovec(varinfo[vn])
822+
nt = DynamicPPL.TestUtils.rand_prior_true(model)
823+
varinfos = DynamicPPL.TestUtils.setup_varinfos(model, nt, vns)
824+
# Only keep `VarInfo` types.
825+
varinfos = filter(Base.Fix2(isa, VarInfo), varinfos)
826+
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
827+
x = values_as(varinfo, Vector)
828+
829+
# Let's just check all the subsets of `vns`.
830+
@testset "$(convert(Vector{Any},vns_subset))" for vns_subset in
831+
combinations(vns)
832+
ranges = DynamicPPL.getranges(varinfo, vns_subset)
833+
@test length(ranges) == length(vns_subset)
834+
for (r, vn) in zip(ranges, vns_subset)
835+
@test x[r] == DynamicPPL.tovec(varinfo[vn])
836+
end
829837
end
830838
end
831839
end

0 commit comments

Comments
 (0)