Skip to content

Commit 992569f

Browse files
committed
Fix tests
1 parent 6c776e9 commit 992569f

File tree

2 files changed

+2
-57
lines changed

2 files changed

+2
-57
lines changed

test/lkj.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ _lkj_atol = 0.05
2121

2222
@testset "Sample from x ~ LKJ(2, 1)" begin
2323
model = lkj_prior_demo()
24-
for init_strategy in [PriorInit(), UniformInit()]
24+
for init_strategy in [InitFromPrior(), InitFromUniform()]
2525
samples = [
2626
last(DynamicPPL.init!!(model, VarInfo(), init_strategy)) for _ in 1:n_samples
2727
]
@@ -32,8 +32,7 @@ end
3232

3333
@testset "Sample from x ~ LKJCholesky(2, 1, $(uplo))" for uplo in ['U', 'L']
3434
model = lkj_chol_prior_demo(uplo)
35-
# `SampleFromPrior` will sample in unconstrained space.
36-
for init_strategy in [PriorInit(), UniformInit()]
35+
for init_strategy in [InitFromPrior(), InitFromUniform()]
3736
samples = [
3837
last(DynamicPPL.init!!(model, VarInfo(), init_strategy)) for _ in 1:n_samples
3938
]

test/sampler.jl

Lines changed: 0 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -113,60 +113,6 @@
113113
end
114114
end
115115

116-
@testset "SampleFromPrior and SampleUniform" begin
117-
@model function gdemo(x, y)
118-
s ~ InverseGamma(2, 3)
119-
m ~ Normal(2.0, sqrt(s))
120-
x ~ Normal(m, sqrt(s))
121-
return y ~ Normal(m, sqrt(s))
122-
end
123-
124-
model = gdemo(1.0, 2.0)
125-
N = 1_000
126-
127-
chains = sample(model, SampleFromPrior(), N; progress=false)
128-
@test chains isa Vector{<:VarInfo}
129-
@test length(chains) == N
130-
131-
# Expected value of ``X`` where ``X ~ N(2, ...)`` is 2.
132-
@test mean(vi[@varname(m)] for vi in chains) 2 atol = 0.15
133-
134-
# Expected value of ``X`` where ``X ~ IG(2, 3)`` is 3.
135-
@test mean(vi[@varname(s)] for vi in chains) 3 atol = 0.2
136-
137-
chains = sample(model, SampleFromUniform(), N; progress=false)
138-
@test chains isa Vector{<:VarInfo}
139-
@test length(chains) == N
140-
141-
# `m` is Gaussian, i.e. no transformation is used, so it
142-
# will be drawn from U[-2, 2] and its mean should be 0.
143-
@test mean(vi[@varname(m)] for vi in chains) 0.0 atol = 0.1
144-
145-
# Expected value of ``exp(X)`` where ``X ~ U[-2, 2]`` is ≈ 1.8.
146-
@test mean(vi[@varname(s)] for vi in chains) 1.8 atol = 0.1
147-
end
148-
149-
@testset "init" begin
150-
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
151-
N = 1000
152-
chain_init = sample(model, SampleFromUniform(), N; progress=false)
153-
154-
for vn in keys(first(chain_init))
155-
if AbstractPPL.subsumes(@varname(s), vn)
156-
# `s ~ InverseGamma(2, 3)` and its unconstrained value will be sampled from Unif[-2,2].
157-
dist = InverseGamma(2, 3)
158-
b = DynamicPPL.link_transform(dist)
159-
@test mean(mean(b(vi[vn])) for vi in chain_init) 0 atol = 0.11
160-
elseif AbstractPPL.subsumes(@varname(m), vn)
161-
# `m ~ Normal(0, sqrt(s))` and its constrained value is the same.
162-
@test mean(mean(vi[vn]) for vi in chain_init) 0 atol = 0.11
163-
else
164-
error("Unknown variable name: $vn")
165-
end
166-
end
167-
end
168-
end
169-
170116
@testset "Initial parameters" begin
171117
# dummy algorithm that just returns initial value and does not perform any sampling
172118
abstract type OnlyInitAlg end

0 commit comments

Comments
 (0)