File tree Expand file tree Collapse file tree 2 files changed +63
-0
lines changed Expand file tree Collapse file tree 2 files changed +63
-0
lines changed Original file line number Diff line number Diff line change
1
+ module DynamicPPLPobserveMacroTests
2
+
3
+ using DynamicPPL, Distributions, Test
4
+
5
+ @testset verbose = true " pobserve_macro.jl" begin
6
+ @testset " loglikelihood is correctly accumulated" begin
7
+ @model function f (x)
8
+ @pobserve for i in eachindex (x)
9
+ x[i] ~ Normal ()
10
+ end
11
+ end
12
+ x = randn (3 )
13
+ expected_loglike = loglikelihood (Normal (), x)
14
+ vi = VarInfo (f (x))
15
+ @test isapprox (DynamicPPL. getloglikelihood (vi), expected_loglike)
16
+ end
17
+
18
+ @testset " return values are correct" begin
19
+ @testset " single expression at the end" begin
20
+ @model function f (x)
21
+ xplusone = @pobserve for i in eachindex (x)
22
+ x[i] ~ Normal ()
23
+ x[i] + 1
24
+ end
25
+ return xplusone
26
+ end
27
+ x = randn (3 )
28
+ @test f (x)() == x .+ 1
29
+
30
+ @testset " calculations are not repeated" begin
31
+ # Make sure that the final expression inside pobserve is not evaluated
32
+ # multiple times.
33
+ counter = 0
34
+ increment_and_return (y) = (counter += 1 ; y)
35
+ @model function g (x)
36
+ xs = @pobserve for i in eachindex (x)
37
+ x[i] ~ Normal ()
38
+ increment_and_return (x[i])
39
+ end
40
+ return xs
41
+ end
42
+ x = randn (3 )
43
+ @test g (x)() == x
44
+ @test counter == length (x)
45
+ end
46
+ end
47
+
48
+ @testset " tilde expression at the end" begin
49
+ @model function f (x)
50
+ xs = @pobserve for i in eachindex (x)
51
+ # This should behave as if it returns `x[i]`
52
+ x[i] ~ Normal ()
53
+ end
54
+ return xs
55
+ end
56
+ x = randn (3 )
57
+ @test f (x)() == x
58
+ end
59
+ end
60
+ end
61
+
62
+ end
Original file line number Diff line number Diff line change @@ -57,6 +57,7 @@ include("test_util.jl")
57
57
include (" utils.jl" )
58
58
include (" accumulators.jl" )
59
59
include (" compiler.jl" )
60
+ include (" pobserve_macro.jl" )
60
61
include (" varnamedvector.jl" )
61
62
include (" varinfo.jl" )
62
63
include (" simple_varinfo.jl" )
You can’t perform that action at this time.
0 commit comments