Skip to content

Commit 4220d57

Browse files
committed
add (mostly broken, for now) tests
1 parent 9191d09 commit 4220d57

File tree

1 file changed

+144
-2
lines changed

1 file changed

+144
-2
lines changed

test/runtests.jl

Lines changed: 144 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,148 @@
11
using MeasureBase
22
using Test
3+
using StatsFuns
4+
using Base.Iterators: take
5+
using Random
6+
using LinearAlgebra
37

4-
@testset "MeasureBase.jl" begin
5-
# Write your tests here.
8+
function draw2(μ)
9+
x = rand(μ)
10+
y = rand(μ)
11+
while x == y
12+
y = rand(μ)
13+
end
14+
return (x,y)
15+
end
16+
17+
@testset "Parameterized Measures" begin
18+
@testset "Binomial" begin
19+
D = Binomial{(:n, :p)}
20+
par = merge((n=20,),transform(asparams(D, (n=20,)), randn(1)))
21+
d = D(par)
22+
(n,p) = (par.n, par.p)
23+
logitp = logit(p)
24+
probitp = norminvcdf(p)
25+
y = rand(d)
26+
27+
= logdensity(Binomial(;n, p), y)
28+
@test logdensity(Binomial(;n, logitp), y)
29+
@test logdensity(Binomial(;n, probitp), y)
30+
31+
@test_broken logdensity(Binomial(n,p), CountingMeasure(ℤ[0:n]), x) binomlogpdf(n,p,x)
32+
end
33+
34+
@testset "NegativeBinomial" begin
35+
D = NegativeBinomial{(:r, :p)}
36+
par = transform(asparams(D), randn(2))
37+
d = D(par)
38+
(r,p) = (par.r, par.p)
39+
logitp = logit(p)
40+
λ = p * r / (1 - p)
41+
y = rand(d)
42+
43+
= logdensity(NegativeBinomial(;r, p), y)
44+
@test logdensity(NegativeBinomial(;r, logitp), y)
45+
@test logdensity(NegativeBinomial(;r, λ), y)
46+
47+
@test_broken logdensity(Binomial(n,p), CountingMeasure(ℤ[0:n]), x) binomlogpdf(n,p,x)
48+
end
49+
50+
@testset "Normal" begin
51+
D = Normal{(,)}
52+
par = transform(asparams(D), randn(2))
53+
d = D(par)
54+
@test params(d) == par
55+
56+
μ = par.μ
57+
σ = par.σ
58+
σ² = σ^2
59+
τ = 1/σ²
60+
logσ = log(σ)
61+
y = rand(d)
62+
63+
= logdensity(Normal(;μ,σ), y)
64+
@test logdensity(Normal(;μ,σ²), y)
65+
@test logdensity(Normal(;μ,τ), y)
66+
@test logdensity(Normal(;μ,logσ), y)
67+
end
68+
end
69+
70+
@testset "Kernel" begin
71+
κ = MeasureTheory.kernel(identity, MeasureTheory.Dirac)
72+
@test rand(κ(1.1)) == 1.1
73+
end
74+
75+
@testset "SpikeMixture" begin
76+
@test rand(SpikeMixture(Dirac(0), 0.5)) == 0
77+
@test rand(SpikeMixture(Dirac(1), 1.0)) == 1
78+
w = 1/3
79+
m = SpikeMixture(Normal(), w)
80+
bm = basemeasure(m)
81+
@test (bm.s*bm.w)*bm.m == 1.0*basemeasure(Normal())
82+
@test density(m, 1.0)*(bm.s*bm.w) == w*density(Normal(),1.0)
83+
@test density(m, 0)*(bm.s*(1-bm.w)) (1-w)
84+
end
85+
86+
@testset "Dirac" begin
87+
@test rand(Dirac(0.2)) == 0.2
88+
@test logdensity(Dirac(0.3), 0.3) == 0.0
89+
@test logdensity(Dirac(0.3), 0.4) == -Inf
90+
end
91+
92+
@testset "For" begin
93+
FORDISTS = [
94+
For(1:10) do j Normal=j) end
95+
For(4,3) do μ,σ Normal(μ,σ) end
96+
For(1:4, 1:4) do μ,σ Normal(μ,σ) end
97+
For(eachrow(rand(4,2))) do x Normal(x[1], x[2]) end
98+
For(rand(4), rand(4)) do μ,σ Normal(μ,σ) end
99+
]
100+
101+
for d in FORDISTS
102+
@test logdensity(d, rand(d)) isa Float64
103+
end
104+
end
105+
106+
import MeasureTheory.:
107+
function ::Normal, kernel)
108+
m = kernel(μ)
109+
Normal= m.μ.μ, σ = sqrt(m.μ.σ^2 + m.σ^2))
110+
end
111+
112+
"""
113+
ConstantMap(β)
114+
Represents a function `f = ConstantMap(β)`
115+
such that `f(x) == β`.
116+
"""
117+
struct ConstantMap{T}
118+
x::T
119+
end
120+
(a::ConstantMap)(x) = a.x
121+
(a::ConstantMap)() = a.x
122+
123+
struct AffineMap{S,T}
124+
B::S
125+
β::T
126+
end
127+
(a::AffineMap)(x) = a.B*x + a.β
128+
(a::AffineMap)(p::Normal) = Normal= a.B*mean(p) + a.β, σ = sqrt(a.B*p.σ^2*a.B'))
129+
130+
@testset "DynamicFor" begin
131+
mc = Chain(Normal=0.0)) do x Normal=x) end
132+
r = rand(mc)
133+
134+
# Check that `r` is now deterministic
135+
@test logdensity(mc, take(r, 100)) == logdensity(mc, take(r, 100))
136+
137+
d2 = For(r) do x Normal=x) end
138+
139+
@test_broken let r2 = rand(d2)
140+
logdensity(d2, take(r2, 100)) == logdensity(d2, take(r2, 100))
141+
end
142+
end
143+
144+
@testset "Likelihood" begin
145+
d = Normal()
146+
= Likelihood(Normal{(,)}, 3.0)
147+
@test logdensity(d ℓ, 2.0) == logdensity(d, 2.0) + logdensity(ℓ, 2.0)
6148
end

0 commit comments

Comments
 (0)