Skip to content

Commit 2709d8f

Browse files
committed
updating tests
1 parent e34c6b1 commit 2709d8f

File tree

3 files changed

+143
-59
lines changed

3 files changed

+143
-59
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,17 @@ Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775"
1717
ConcreteStructs = "0.2"
1818
ConstructionBase = "1.3"
1919
FillArrays = "0.12"
20+
KeywordCalls = "0.2"
2021
MLStyle = "0.4"
2122
MappedArrays = "0.4"
2223
Tricks = "0.1"
2324
julia = "1.3"
2425

2526
[extras]
27+
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
2628
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2729
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2830
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2931

3032
[targets]
31-
test = ["Test", "LinearAlgebra", "Statistics"]
33+
test = ["Test", "Aqua", "LinearAlgebra", "Statistics"]

src/parameterized.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ function Base.show(io::IO, μ::ParameterizedMeasure{N}) where {N}
2323
print(io, getfield(μ,:par))
2424
end
2525

26-
export asparams
2726

2827
# Allow things like
2928
#

test/runtests.jl

Lines changed: 140 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
using MeasureBase
21
using Test
32
using Base.Iterators: take
43
using Random
54
using LinearAlgebra
6-
using KeywordCalls
7-
using Statistics
85

6+
using MeasureBase
7+
8+
using Aqua
9+
Aqua.test_all(MeasureBase; ambiguities=false, unbound_args=false)
910

1011
function draw2(μ)
1112
x = rand(μ)
@@ -16,51 +17,54 @@ function draw2(μ)
1617
return (x,y)
1718
end
1819

19-
const sqrt2π = sqrt(2π)
20-
21-
@testset "Parameterized Measures" begin
22-
@measure Normal(μ,σ)
23-
@kwstruct Normal(μ)
24-
@kwstruct Normal()
25-
26-
MeasureBase.basemeasure(::Normal)= (1/sqrt2π) * Lebesgue(ℝ)
27-
MeasureBase.logdensity(d::Normal{(:μ,:σ)}, x) = -log(d.σ) - (x - d.μ)^2 / (2 * d.σ^2)
28-
MeasureBase.logdensity(d::Normal{(:μ,)}, x) = - (x - d.μ)^2 / 2
29-
MeasureBase.logdensity(d::Normal{()}, x) = - x^2 / 2
30-
31-
Base.rand(rng::Random.AbstractRNG, T::Type, d::Normal{(:μ,:σ)}) = d.μ + d.σ * randn(rng, T)
32-
Base.rand(rng::Random.AbstractRNG, T::Type, d::Normal{(:μ,)}) = d.μ + randn(rng, T)
33-
Base.rand(rng::Random.AbstractRNG, T::Type, d::Normal{()}) = randn(rng, T)
34-
35-
MeasureBase.representative(d::Normal{(:μ,:σ)}) = d.σ > 0.0 ? Lebesgue(ℝ) : Dirac(d.μ)
36-
MeasureBase.representative(d::Normal{(:μ,)}) = Lebesgue(ℝ)
37-
38-
# Leave this undefined to test fallback inference algorithm
39-
# MeasureBase.representative(::Normal) = Lebesgue(ℝ)
40-
41-
@test Normal(2,4) == Normal=2, σ=4)
42-
@test Normal=4, μ=2) == Normal=2, σ=4)
43-
@test logdensity(Normal(), 3) == logdensity(Normal(0,1), 3)
44-
45-
x = randn()
46-
@test_broken logdensity(Normal(3,2), Lebesgue(ℝ), x) logdensity(Normal(3,2), Normal(), x ) + logdensity(Normal(), Lebesgue(ℝ),x)
47-
@test_broken 𝒹(Normal(3,2), Normal())(x) logdensity(Normal(3,2), Normal(), x)
20+
function test_measure(μ)
21+
logdensity(μ, testvalue(μ)) isa AbstractFloat
4822
end
4923

50-
@testset "Density" begin
51-
x = randn()
52-
f(x) = -x^2
53-
μ = Normal()
54-
ν = Lebesgue(ℝ)
55-
@test_broken 𝒹((f, μ), μ)(x) f(x)
56-
@test_broken logdensity((𝒹(μ, ν), ν), x) logdensity(μ, x)
24+
test_measures = [
25+
# Chain(x -> Normal(μ=x), Normal(μ=0.0))
26+
For(3) do j Normal=j) end
27+
For(2,3) do i,j Normal(i,j) end
28+
Lebesgue(ℝ) ^ 3
29+
Lebesgue(ℝ) ^ (2,3)
30+
3 * Lebesgue(ℝ)
31+
Dirac(π)
32+
Lebesgue(ℝ)
33+
# Normal() ⊙ Cauchy()
34+
]
35+
36+
testbroken_measures = [
37+
Pushforward(as𝕀, Normal())
38+
SpikeMixture(Normal(), 2)
39+
# InverseGamma(2) # Not defined yet
40+
# MvNormal(I(3)) # Entirely broken for now
41+
CountingMeasure(Float64)
42+
Likelihood
43+
Dirac(0.0) + Lebesgue(ℝ)
44+
45+
TrivialMeasure()
46+
]
47+
48+
@testset "testvalue" begin
49+
for μ in test_measures
50+
@test test_measure(μ)
51+
end
52+
53+
for μ in testbroken_measures
54+
@test_broken test_measure(μ)
55+
end
56+
57+
@testset "testvalue(::Chain)" begin
58+
mc = Chain(x -> Normal=x), Normal=0.0))
59+
r = testvalue(mc)
60+
@test logdensity(mc, Iterators.take(r, 10)) isa AbstractFloat
61+
end
5762
end
5863

5964

6065
@testset "Kernel" begin
61-
κ = kernel(identity, Dirac)
66+
κ = MeasureBase.kernel(MeasureBase.Dirac, identity)
6267
@test rand(κ(1.1)) == 1.1
63-
@test kernelize(Normal(0,1)) == (Kernel{Normal, UnionAll}(NamedTuple{(, ), T} where T<:Tuple), (0, 1))
6468
end
6569

6670
@testset "SpikeMixture" begin
@@ -74,6 +78,12 @@ end
7478
@test density(m, 0)*(bm.s*(1-bm.w)) (1-w)
7579
end
7680

81+
@testset "Dirac" begin
82+
@test rand(Dirac(0.2)) == 0.2
83+
@test logdensity(Dirac(0.3), 0.3) == 0.0
84+
@test logdensity(Dirac(0.3), 0.4) == -Inf
85+
end
86+
7787
@testset "For" begin
7888
FORDISTS = [
7989
For(1:10) do j Normal=j) end
@@ -88,22 +98,11 @@ end
8898
end
8999
end
90100

101+
import MeasureBase.:
91102
function ::Normal, kernel)
92103
m = kernel(μ)
93104
Normal= m.μ.μ, σ = sqrt(m.μ.σ^2 + m.σ^2))
94105
end
95-
96-
"""
97-
ConstantMap(β)
98-
Represents a function `f = ConstantMap(β)`
99-
such that `f(x) == β`.
100-
"""
101-
struct ConstantMap{T}
102-
x::T
103-
end
104-
(a::ConstantMap)(x) = a.x
105-
(a::ConstantMap)() = a.x
106-
107106
struct AffineMap{S,T}
108107
B::S
109108
β::T
@@ -125,8 +124,92 @@ end
125124
end
126125
end
127126

128-
@testset "LogLikelihood" begin
129-
d = Normal()
130-
= LogLikelihood(Normal{(,)}, 3.0)
131-
@test logdensity(d ℓ, 2.0) == logdensity(d, 2.0) + logdensity(ℓ, 2.0)
127+
@testset "Univariate chain" begin
128+
ξ0 = 1.
129+
x = 1.2
130+
P0 = 1.0
131+
132+
Φ = 0.8
133+
β = 0.1
134+
Q = 0.2
135+
136+
μ = Normal=ξ0, σ=sqrt(P0))
137+
kernel = MeasureBase.kernel(Normal; μ=AffineMap(Φ, β), σ=Const(Q))
138+
139+
@test kernel).μ == Normal= 0.9, σ = 0.824621).μ
140+
141+
chain = Chain(kernel, μ)
142+
143+
144+
dyniterate(iter::TimeLift, ::Nothing) = dyniterate(iter, 0=>nothing)
145+
tr1 = trace(TimeLift(chain), nothing, u -> u[1] > 15)
146+
tr2 = trace(TimeLift(rand(Random.GLOBAL_RNG, chain)), nothing, u -> u[1] > 15)
147+
collect(Iterators.take(chain, 10))
148+
collect(Iterators.take(rand(Random.GLOBAL_RNG, chain), 10))
149+
end
150+
151+
@testset "Likelihood" begin
152+
dps = [
153+
(Normal() , 2.0 )
154+
# (Pushforward(as((μ=asℝ,)), Normal()^1), (μ=2.0,))
155+
]
156+
157+
ℓs = [
158+
Likelihood(Normal{(,)}, 3.0)
159+
Likelihood(kernel(Normal, x ->=x, σ=2.0)), 3.0)
160+
]
161+
162+
for (d,p) in dps
163+
forin ℓs
164+
@test logdensity(d ℓ, p) == logdensity(d, p) + logdensity(ℓ, p)
165+
end
166+
end
167+
end
168+
169+
170+
@testset "ProductMeasure" begin
171+
d = For(1:10) do j Poisson(exp(j)) end
172+
x = Vector{Int16}(undef, 10)
173+
@test rand!(d,x) isa Vector
174+
@test rand(d) isa Vector
175+
176+
@testset "Indexed by Generator" begin
177+
d = For((j^2 for j in 1:10)) do i Poisson(i) end
178+
x = Vector{Int16}(undef, 10)
179+
@test rand!(d,x) isa Vector
180+
@test_broken rand(d) isa Base.Generator
181+
end
182+
183+
@testset "Indexed by multiple Ints" begin
184+
d = For(2,3) do μ,σ Normal(μ,σ) end
185+
x = Matrix{Float16}(undef, 2, 3)
186+
@test rand!(d, x) isa Matrix
187+
@test_broken rand(d) isa Matrix{Float16}
188+
end
189+
end
190+
191+
@testset "Show methods" begin
192+
@testset "PowerMeasure" begin
193+
@test repr(Lebesgue(ℝ) ^ 5) == "Lebesgue(ℝ) ^ 5"
194+
@test repr(Lebesgue(ℝ) ^ (3, 2)) == "Lebesgue(ℝ) ^ (3, 2)"
195+
end
196+
end
197+
198+
@testset "Density measures and Radon-Nikodym" begin
199+
x = randn()
200+
let d = (𝒹(Cauchy(), Normal()), Normal())
201+
@test logdensity(d, x) logdensity(Cauchy(), x)
202+
end
203+
204+
let f = 𝒹((x -> x^2, Normal()), Normal())
205+
@test f(x) x^2
206+
end
207+
208+
let d = ∫exp(log𝒹(Cauchy(), Normal()), Normal())
209+
@test logdensity(d, x) logdensity(Cauchy(), x)
210+
end
211+
212+
let f = log𝒹(∫exp(x -> x^2, Normal()), Normal())
213+
@test f(x) x^2
214+
end
132215
end

0 commit comments

Comments
 (0)