Skip to content

Commit 4c659e3

Browse files
committed
Add more tests for dot tilde implementations (#247)
This PR extends the tests that were introduced in #245. Co-authored-by: David Widmann <[email protected]>
1 parent 0f7548d commit 4c659e3

File tree

3 files changed

+50
-16
lines changed

3 files changed

+50
-16
lines changed

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
77
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
88
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
99
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
10+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1011
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
1112
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1213
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"

test/context_implementations.jl

Lines changed: 48 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,26 +15,58 @@
1515
end
1616

1717
# https://github.com/TuringLang/DynamicPPL.jl/issues/28#issuecomment-829223577
18-
@testset "arrays of distributions" begin
19-
@model function test(x, y)
20-
y .~ Normal.(x)
18+
@testset "dot tilde: arrays of distributions" begin
19+
@testset "assume" begin
20+
@model function test(x, size)
21+
y = Array{Float64,length(size)}(undef, size...)
22+
y .~ Normal.(x)
23+
return y, getlogp(__varinfo__)
24+
end
25+
26+
for ysize in ((2,), (2, 3), (2, 3, 4))
27+
for x in (
28+
# scalar
29+
randn(),
30+
# drop trailing dimensions
31+
ntuple(i -> randn(ysize[1:i]), length(ysize))...,
32+
# singleton dimensions
33+
ntuple(
34+
i -> randn(ysize[1:(i-1)]..., 1, ysize[(i+1):end]...),
35+
length(ysize),
36+
)...,
37+
)
38+
model = test(x, ysize)
39+
y, lp = model()
40+
@test lp sum(logpdf.(Normal.(x), y))
41+
42+
ys = [first(model()) for _ in 1:10_000]
43+
@test norm(mean(ys) .- x, Inf) < 0.1
44+
@test norm(std(ys) .- 1, Inf) < 0.1
45+
end
46+
end
2147
end
2248

23-
for ysize in ((2,), (2, 3), (2, 3, 4))
24-
# drop trailing dimensions
25-
for xsize in ntuple(i -> ysize[1:i], length(ysize))
26-
x = randn(xsize)
27-
y = randn(ysize)
28-
z = logjoint(test(x, y), VarInfo())
29-
@test z sum(logpdf.(Normal.(x), y))
49+
@testset "observe" begin
50+
@model function test(x, y)
51+
y .~ Normal.(x)
3052
end
3153

32-
# singleton dimensions
33-
for xsize in ntuple(i -> (ysize[1:(i-1)]..., 1, ysize[(i+1):end]...), length(ysize))
34-
x = randn(xsize)
35-
y = randn(ysize)
36-
z = logjoint(test(x, y), VarInfo())
37-
@test z sum(logpdf.(Normal.(x), y))
54+
for ysize in ((2,), (2, 3), (2, 3, 4))
55+
for x in (
56+
# scalar
57+
randn(),
58+
# drop trailing dimensions
59+
ntuple(i -> randn(ysize[1:i]), length(ysize))...,
60+
# singleton dimensions
61+
ntuple(
62+
i -> randn(ysize[1:(i-1)]..., 1, ysize[(i+1):end]...),
63+
length(ysize),
64+
)...,
65+
)
66+
y = randn(ysize)
67+
z = logjoint(test(x, y), VarInfo())
68+
@test z sum(logpdf.(Normal.(x), y))
69+
end
3870
end
3971
end
4072
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ using Tracker
1212
using Zygote
1313

1414
using Distributed
15+
using LinearAlgebra
1516
using Pkg
1617
using Random
1718
using Serialization

0 commit comments

Comments
 (0)