Skip to content

Commit 916b8d5

Browse files
committed
Add tests
1 parent 9df4914 commit 916b8d5

File tree

2 files changed

+61
-13
lines changed

2 files changed

+61
-13
lines changed

test/ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,6 @@ module DynamicPPLMCMCChainsExtTests
33
using DynamicPPL, Distributions, MCMCChains, Test, AbstractMCMC
44

55
@testset "DynamicPPLMCMCChainsExt" begin
6-
@model demo() = x ~ Normal()
7-
model = demo()
8-
9-
chain = MCMCChains.Chains(
10-
randn(1000, 2, 1),
11-
[:x, :y],
12-
Dict(:internals => [:y]);
13-
info=(; varname_to_symbol=Dict(@varname(x) => :x)),
14-
)
15-
chain_generated = @test_nowarn returned(model, chain)
16-
@test size(chain_generated) == (1000, 1)
17-
@test mean(chain_generated) 0 atol = 0.1
18-
196
@testset "from_samples" begin
207
@model function f(z)
218
x ~ Normal()
@@ -61,6 +48,42 @@ using DynamicPPL, Distributions, MCMCChains, Test, AbstractMCMC
6148
@test new_p.stats == p.stats
6249
end
6350
end
51+
52+
@testset "returned (basic)" begin
53+
@model demo() = x ~ Normal()
54+
model = demo()
55+
56+
chain = MCMCChains.Chains(
57+
randn(1000, 2, 1),
58+
[:x, :y],
59+
Dict(:internals => [:y]);
60+
info=(; varname_to_symbol=Dict(@varname(x) => :x)),
61+
)
62+
chain_generated = @test_nowarn returned(model, chain)
63+
@test size(chain_generated) == (1000, 1)
64+
@test mean(chain_generated) 0 atol = 0.1
65+
end
66+
67+
@testset "returned: errors on missing variable" begin
68+
# Create a chain that only has `m`.
69+
@model function m_only()
70+
return m ~ Normal()
71+
end
72+
model_m_only = m_only()
73+
chain_m_only = AbstractMCMC.from_samples(
74+
MCMCChains.Chains,
75+
hcat([ParamsWithStats(VarInfo(model_m_only), model_m_only) for _ in 1:50]),
76+
)
77+
78+
# Define a model that needs both `m` and `s`.
79+
@model function f()
80+
m ~ Normal()
81+
s ~ Exponential()
82+
return y ~ Normal(m, s)
83+
end
84+
model = f() | (; y=1.0)
85+
@test_throws "No value was provided" returned(model, chain_m_only)
86+
end
6487
end
6588

6689
# test for `predict` is in `test/model.jl`

test/pointwise_logdensities.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,4 +94,29 @@ end
9494
@test logprior logprior_true
9595
@test loglikelihood loglikelihood_true
9696
end
97+
98+
@testset "errors when variables are missing" begin
99+
# Create a chain that only has `m`.
100+
@model function m_only()
101+
return m ~ Normal()
102+
end
103+
model_m_only = m_only()
104+
chain_m_only = AbstractMCMC.from_samples(
105+
MCMCChains.Chains,
106+
hcat([ParamsWithStats(VarInfo(model_m_only), model_m_only) for _ in 1:50]),
107+
)
108+
109+
# Define a model that needs both `m` and `s`.
110+
@model function f()
111+
m ~ Normal()
112+
s ~ Exponential()
113+
return y ~ Normal(m, s)
114+
end
115+
model = f() | (; y=1.0)
116+
@test_throws "No value was provided" pointwise_logdensities(model, chain_m_only)
117+
@test_throws "No value was provided" pointwise_loglikelihoods(model, chain_m_only)
118+
@test_throws "No value was provided" pointwise_prior_logdensities(
119+
model, chain_m_only
120+
)
121+
end
97122
end

0 commit comments

Comments
 (0)