Skip to content

Commit f90d451

Browse files
committed
Add tests, fix import order
1 parent 66423bf commit f90d451

File tree

3 files changed

+30
-3
lines changed

3 files changed

+30
-3
lines changed

src/DynamicPPL.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,14 +194,14 @@ include("logdensityfunction.jl")
194194
include("model_utils.jl")
195195
include("extract_priors.jl")
196196
include("values_as_in_model.jl")
197+
include("experimental.jl")
197198
include("chains.jl")
198199
include("bijector.jl")
199200

200201
include("debug_utils.jl")
201202
using .DebugUtils
202203
include("test_utils.jl")
203204

204-
include("experimental.jl")
205205
include("deprecated.jl")
206206

207207
if isdefined(Base.Experimental, :register_error_hint)

src/chains.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ function ParamsWithStats(
109109
else
110110
(DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq),)
111111
end
112-
return _, varinfo = DynamicPPL.Experimental.fast_evaluate!!(
112+
_, varinfo = DynamicPPL.Experimental.fast_evaluate!!(
113113
ldf.model, ctx, AccumulatorTuple(accs)
114114
)
115115
params = DynamicPPL.getacc(varinfo, Val(:ValuesAsInModel)).values

test/chains.jl

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using DynamicPPL
44
using Distributions
55
using Test
66

7-
@testset "ParamsWithStats" begin
7+
@testset "ParamsWithStats, from VarInfo" begin
88
@model function f(z)
99
x ~ Normal()
1010
y := x + 1
@@ -66,4 +66,31 @@ using Test
6666
end
6767
end
6868

69+
@testset "ParamsWithStats from FastLDF" begin
70+
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
71+
unlinked_vi = VarInfo(m)
72+
@testset "$islinked" for islinked in (false, true)
73+
vi = if islinked
74+
DynamicPPL.link!!(unlinked_vi, m)
75+
else
76+
unlinked_vi
77+
end
78+
nt_ranges, dict_ranges = DynamicPPL.Experimental.get_ranges_and_linked(vi)
79+
params = map(identity, vi[:])
80+
81+
# Get the ParamsWithStats using FastLDF
82+
fldf = DynamicPPL.Experimental.FastLDF(m, getlogjoint, vi)
83+
ps = ParamsWithStats(params, fldf)
84+
85+
# Check that length of parameters is as expected
86+
@test length(ps.params) == length(keys(vi))
87+
88+
# Iterate over all variables to check that their values match
89+
for vn in keys(vi)
90+
@test ps.params[vn] == vi[vn]
91+
end
92+
end
93+
end
94+
end
95+
6996
end # module

0 commit comments

Comments
 (0)