Skip to content

Commit 99f5695

Browse files
committed
Add some tests
1 parent b3b97da commit 99f5695

File tree

2 files changed

+63
-0
lines changed

2 files changed

+63
-0
lines changed

test/pobserve_macro.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
module DynamicPPLPobserveMacroTests
2+
3+
using DynamicPPL, Distributions, Test
4+
5+
@testset verbose = true "pobserve_macro.jl" begin
6+
@testset "loglikelihood is correctly accumulated" begin
7+
@model function f(x)
8+
@pobserve for i in eachindex(x)
9+
x[i] ~ Normal()
10+
end
11+
end
12+
x = randn(3)
13+
expected_loglike = loglikelihood(Normal(), x)
14+
vi = VarInfo(f(x))
15+
@test isapprox(DynamicPPL.getloglikelihood(vi), expected_loglike)
16+
end
17+
18+
@testset "return values are correct" begin
19+
@testset "single expression at the end" begin
20+
@model function f(x)
21+
xplusone = @pobserve for i in eachindex(x)
22+
x[i] ~ Normal()
23+
x[i] + 1
24+
end
25+
return xplusone
26+
end
27+
x = randn(3)
28+
@test f(x)() == x .+ 1
29+
30+
@testset "calculations are not repeated" begin
31+
# Make sure that the final expression inside pobserve is not evaluated
32+
# multiple times.
33+
counter = 0
34+
increment_and_return(y) = (counter += 1; y)
35+
@model function g(x)
36+
xs = @pobserve for i in eachindex(x)
37+
x[i] ~ Normal()
38+
increment_and_return(x[i])
39+
end
40+
return xs
41+
end
42+
x = randn(3)
43+
@test g(x)() == x
44+
@test counter == length(x)
45+
end
46+
end
47+
48+
@testset "tilde expression at the end" begin
49+
@model function f(x)
50+
xs = @pobserve for i in eachindex(x)
51+
# This should behave as if it returns `x[i]`
52+
x[i] ~ Normal()
53+
end
54+
return xs
55+
end
56+
x = randn(3)
57+
@test f(x)() == x
58+
end
59+
end
60+
end
61+
62+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ include("test_util.jl")
5757
include("utils.jl")
5858
include("accumulators.jl")
5959
include("compiler.jl")
60+
include("pobserve_macro.jl")
6061
include("varnamedvector.jl")
6162
include("varinfo.jl")
6263
include("simple_varinfo.jl")

0 commit comments

Comments
 (0)