Skip to content

Commit e666d74

Browse files
authored
Move test utilities to an extension (#1791)
* Move test utilities to an extension * Fix signature and docstring * Also qualify AbstractRNG * Fix Julia < 1.9 * Fix for 1.3? * Simplify the TestUtils stub
1 parent e407fa5 commit e666d74

File tree

4 files changed

+121
-94
lines changed

4 files changed

+121
-94
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Distributions"
22
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
33
authors = ["JuliaStats"]
4-
version = "0.25.102"
4+
version = "0.25.103"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -22,10 +22,12 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2222
[weakdeps]
2323
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
2424
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
25+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2526

2627
[extensions]
2728
DistributionsChainRulesCoreExt = "ChainRulesCore"
2829
DistributionsDensityInterfaceExt = "DensityInterface"
30+
DistributionsTestExt = "Test"
2931

3032
[compat]
3133
ChainRulesCore = "1"

ext/DistributionsTestExt.jl

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
module DistributionsTestExt
2+
3+
using Distributions
4+
using Distributions.LinearAlgebra
5+
using Distributions.Random
6+
using Test
7+
8+
__rand(::Nothing, args...) = rand(args...)
9+
__rand(rng::AbstractRNG, args...) = rand(rng, args...)
10+
11+
__rand!(::Nothing, args...) = rand!(args...)
12+
__rand!(rng::AbstractRNG, args...) = rand!(rng, args...)
13+
14+
"""
15+
test_mvnormal(
16+
g::AbstractMvNormal,
17+
n_tsamples::Int=10^6,
18+
rng::Union{Random.AbstractRNG, Nothing}=nothing,
19+
)
20+
21+
Test that `AbstractMvNormal` implements the expected API.
22+
23+
!!! Note
24+
On Julia >= 1.9, you have to load the `Test` standard library to be able to use
25+
this function.
26+
"""
27+
function Distributions.TestUtils.test_mvnormal(
28+
g::AbstractMvNormal, n_tsamples::Int=10^6, rng::Union{AbstractRNG, Nothing}=nothing
29+
)
30+
d = length(g)
31+
μ = mean(g)
32+
Σ = cov(g)
33+
@test length(μ) == d
34+
@test size(Σ) == (d, d)
35+
@test var(g) diag(Σ)
36+
@test entropy(g) 0.5 * logdet(2π ** Σ)
37+
ldcov = logdetcov(g)
38+
@test ldcov logdet(Σ)
39+
vs = diag(Σ)
40+
@test g == typeof(g)(params(g)...)
41+
@test g == deepcopy(g)
42+
@test minimum(g) == fill(-Inf, d)
43+
@test maximum(g) == fill(Inf, d)
44+
@test extrema(g) == (minimum(g), maximum(g))
45+
@test isless(extrema(g)...)
46+
47+
# test sampling for AbstractMatrix (here, a SubArray):
48+
subX = view(__rand(rng, d, 2d), :, 1:d)
49+
@test isa(__rand!(rng, g, subX), SubArray)
50+
51+
# sampling
52+
@test isa(__rand(rng, g), Vector{Float64})
53+
X = __rand(rng, g, n_tsamples)
54+
emp_mu = vec(mean(X, dims=2))
55+
Z = X .- emp_mu
56+
emp_cov = (Z * Z') * inv(n_tsamples)
57+
58+
mean_atols = 8 .* sqrt.(vs ./ n_tsamples)
59+
cov_atols = 10 .* sqrt.(vs .* vs') ./ sqrt.(n_tsamples)
60+
for i = 1:d
61+
@test isapprox(emp_mu[i], μ[i], atol=mean_atols[i])
62+
end
63+
for i = 1:d, j = 1:d
64+
@test isapprox(emp_cov[i,j], Σ[i,j], atol=cov_atols[i,j])
65+
end
66+
67+
X = rand(MersenneTwister(14), g, n_tsamples)
68+
Y = rand(MersenneTwister(14), g, n_tsamples)
69+
@test X == Y
70+
emp_mu = vec(mean(X, dims=2))
71+
Z = X .- emp_mu
72+
emp_cov = (Z * Z') * inv(n_tsamples)
73+
for i = 1:d
74+
@test isapprox(emp_mu[i] , μ[i] , atol=mean_atols[i])
75+
end
76+
for i = 1:d, j = 1:d
77+
@test isapprox(emp_cov[i,j], Σ[i,j], atol=cov_atols[i,j])
78+
end
79+
80+
81+
# evaluation of sqmahal & logpdf
82+
U = X .- μ
83+
sqm = vec(sum(U .*\ U), dims=1))
84+
for i = 1:min(100, n_tsamples)
85+
@test sqmahal(g, X[:,i]) sqm[i]
86+
end
87+
@test sqmahal(g, X) sqm
88+
89+
lp = -0.5 .* sqm .- 0.5 * (d * log(2.0 * pi) + ldcov)
90+
for i = 1:min(100, n_tsamples)
91+
@test logpdf(g, X[:,i]) lp[i]
92+
end
93+
@test logpdf(g, X) lp
94+
95+
# log likelihood
96+
@test loglikelihood(g, X) sum(i -> Distributions._logpdf(g, X[:,i]), 1:n_tsamples)
97+
@test loglikelihood(g, X[:, 1]) logpdf(g, X[:, 1])
98+
@test loglikelihood(g, [X[:, i] for i in axes(X, 2)]) loglikelihood(g, X)
99+
end
100+
101+
end # module

src/Distributions.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -316,15 +316,16 @@ include("mixtures/unigmm.jl")
316316
# Interface for StatsAPI
317317
include("statsapi.jl")
318318

319+
# Testing utilities for other packages which implement distributions.
320+
include("test_utils.jl")
321+
319322
# Extensions: Implementation of DensityInterface and ChainRulesCore API
320323
if !isdefined(Base, :get_extension)
321324
include("../ext/DistributionsChainRulesCoreExt/DistributionsChainRulesCoreExt.jl")
322325
include("../ext/DistributionsDensityInterfaceExt.jl")
326+
include("../ext/DistributionsTestExt.jl")
323327
end
324328

325-
# Testing utilities for other packages which implement distributions.
326-
include("test_utils.jl")
327-
328329
include("deprecates.jl")
329330

330331
"""

src/test_utils.jl

Lines changed: 13 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -1,96 +1,19 @@
11
module TestUtils
22

3-
using Distributions
4-
using LinearAlgebra
5-
using Random
6-
using Test
7-
8-
9-
__rand(::Nothing, args...) = rand(args...)
10-
__rand(rng::AbstractRNG, args...) = rand(rng, args...)
11-
12-
__rand!(::Nothing, args...) = rand!(args...)
13-
__rand!(rng::AbstractRNG, args...) = rand!(rng, args...)
14-
15-
"""
16-
test_mvnormal(
17-
g::AbstractMvNormal, n_tsamples::Int=10^6, rng::AbstractRNG=Random.default_rng()
18-
)
19-
20-
Test that `AbstractMvNormal` implements the expected API.
21-
"""
22-
function test_mvnormal(
23-
g::AbstractMvNormal, n_tsamples::Int=10^6, rng::Union{AbstractRNG, Nothing}=nothing
24-
)
25-
d = length(g)
26-
μ = mean(g)
27-
Σ = cov(g)
28-
@test length(μ) == d
29-
@test size(Σ) == (d, d)
30-
@test var(g) diag(Σ)
31-
@test entropy(g) 0.5 * logdet(2π ** Σ)
32-
ldcov = logdetcov(g)
33-
@test ldcov logdet(Σ)
34-
vs = diag(Σ)
35-
@test g == typeof(g)(params(g)...)
36-
@test g == deepcopy(g)
37-
@test minimum(g) == fill(-Inf, d)
38-
@test maximum(g) == fill(Inf, d)
39-
@test extrema(g) == (minimum(g), maximum(g))
40-
@test isless(extrema(g)...)
41-
42-
# test sampling for AbstractMatrix (here, a SubArray):
43-
subX = view(__rand(rng, d, 2d), :, 1:d)
44-
@test isa(__rand!(rng, g, subX), SubArray)
45-
46-
# sampling
47-
@test isa(__rand(rng, g), Vector{Float64})
48-
X = __rand(rng, g, n_tsamples)
49-
emp_mu = vec(mean(X, dims=2))
50-
Z = X .- emp_mu
51-
emp_cov = (Z * Z') * inv(n_tsamples)
52-
53-
mean_atols = 8 .* sqrt.(vs ./ n_tsamples)
54-
cov_atols = 10 .* sqrt.(vs .* vs') ./ sqrt.(n_tsamples)
55-
for i = 1:d
56-
@test isapprox(emp_mu[i], μ[i], atol=mean_atols[i])
57-
end
58-
for i = 1:d, j = 1:d
59-
@test isapprox(emp_cov[i,j], Σ[i,j], atol=cov_atols[i,j])
3+
import ..Distributions
4+
5+
function test_mvnormal end
6+
7+
if isdefined(Base, :get_extension) && isdefined(Base.Experimental, :register_error_hint)
8+
function __init__()
9+
# Better error message if users forget to load Test
10+
Base.Experimental.register_error_hint(MethodError) do io, exc, _, _
11+
if exc.f === test_mvnormal &&
12+
(Base.get_extension(Distributions, :DistributionsTestExt) === nothing)
13+
print(io, "\nDid you forget to load Test?")
14+
end
15+
end
6016
end
61-
62-
X = rand(MersenneTwister(14), g, n_tsamples)
63-
Y = rand(MersenneTwister(14), g, n_tsamples)
64-
@test X == Y
65-
emp_mu = vec(mean(X, dims=2))
66-
Z = X .- emp_mu
67-
emp_cov = (Z * Z') * inv(n_tsamples)
68-
for i = 1:d
69-
@test isapprox(emp_mu[i] , μ[i] , atol=mean_atols[i])
70-
end
71-
for i = 1:d, j = 1:d
72-
@test isapprox(emp_cov[i,j], Σ[i,j], atol=cov_atols[i,j])
73-
end
74-
75-
76-
# evaluation of sqmahal & logpdf
77-
U = X .- μ
78-
sqm = vec(sum(U .*\ U), dims=1))
79-
for i = 1:min(100, n_tsamples)
80-
@test sqmahal(g, X[:,i]) sqm[i]
81-
end
82-
@test sqmahal(g, X) sqm
83-
84-
lp = -0.5 .* sqm .- 0.5 * (d * log(2.0 * pi) + ldcov)
85-
for i = 1:min(100, n_tsamples)
86-
@test logpdf(g, X[:,i]) lp[i]
87-
end
88-
@test logpdf(g, X) lp
89-
90-
# log likelihood
91-
@test loglikelihood(g, X) sum(i -> Distributions._logpdf(g, X[:,i]), 1:n_tsamples)
92-
@test loglikelihood(g, X[:, 1]) logpdf(g, X[:, 1])
93-
@test loglikelihood(g, [X[:, i] for i in axes(X, 2)]) loglikelihood(g, X)
9417
end
9518

9619
end

0 commit comments

Comments
 (0)