|
1 | 1 | using StatsBase |
2 | 2 | using LinearAlgebra, Random, Test |
3 | 3 |
|
| 4 | +struct EmptyCovarianceEstimator <: CovarianceEstimator end |
| 5 | + |
4 | 6 | @testset "StatsBase.Covariance" begin |
5 | 7 | weight_funcs = (weights, aweights, fweights, pweights) |
6 | 8 |
|
@@ -248,5 +250,50 @@ weight_funcs = (weights, aweights, fweights, pweights) |
248 | 250 | @test cor(X, wv1, 1) ≈ expected_cor1 |
249 | 251 | @test cor(X, wv2, 2) ≈ expected_cor2 |
250 | 252 | end |
| 253 | + |
| 254 | + @testset "Abstract covariance estimation" begin |
| 255 | + Xm1 = mean(X, dims=1) |
| 256 | + Xm2 = mean(X, dims=2) |
| 257 | + |
| 258 | + for corrected ∈ (false, true) |
| 259 | + scc = SimpleCovariance(corrected=corrected) |
| 260 | + @test_throws ArgumentError cov(scc, X, dims=0) |
| 261 | + @test_throws ArgumentError cov(scc, X, wv1, dims=0) |
| 262 | + @test cov(scc, X) ≈ cov(X, corrected=corrected) |
| 263 | + @test cov(scc, X, mean=Xm1) ≈ StatsBase.covm(X, Xm1, corrected=corrected) |
| 264 | + @test cov(scc, X, mean=Xm2, dims=2) ≈ StatsBase.covm(X, Xm2, 2, corrected=corrected) |
| 265 | + if f !== weights || corrected === false |
| 266 | + @test cov(scc, X, wv1, dims=1) ≈ cov(X, wv1, 1, corrected=corrected) |
| 267 | + @test cov(scc, X, wv2, dims=2) ≈ cov(X, wv2, 2, corrected=corrected) |
| 268 | + @test cov(scc, X, wv1, mean=Xm1) ≈ StatsBase.covm(X, Xm1, wv1, corrected=corrected) |
| 269 | + @test cov(scc, X, wv2, mean=Xm2, dims=2) ≈ StatsBase.covm(X, Xm2, wv2, 2, corrected=corrected) |
| 270 | + end |
| 271 | + end |
| 272 | + end |
| 273 | +end |
| 274 | + |
| 275 | +@testset "Abstract covariance estimation" begin |
| 276 | + est = EmptyCovarianceEstimator() |
| 277 | + wv = fweights(rand(2)) |
| 278 | + @test_throws ErrorException cov(est, [1.0 2.0; 3.0 4.0]) |
| 279 | + @test_throws ErrorException cov(est, [1.0 2.0; 3.0 4.0], wv) |
| 280 | + @test_throws ErrorException cov(est, [1.0 2.0; 3.0 4.0], dims = 2) |
| 281 | + @test_throws ErrorException cov(est, [1.0 2.0; 3.0 4.0], wv, dims = 2) |
| 282 | + @test_throws ErrorException cov(est, [1.0 2.0; 3.0 4.0], mean = nothing) |
| 283 | + @test_throws ErrorException cov(est, [1.0 2.0; 3.0 4.0], wv, mean = nothing) |
| 284 | + @test_throws ErrorException cov(est, [1.0 2.0; 3.0 4.0], dims = 2, mean = nothing) |
| 285 | + @test_throws ErrorException cov(est, [1.0 2.0; 3.0 4.0], wv, dims = 2, mean = nothing) |
| 286 | + @test_throws ErrorException cov(est, [1.0, 2.0], [3.0, 4.0]) |
| 287 | + @test_throws ErrorException cov(est, [1.0, 2.0]) |
| 288 | + |
| 289 | + x = rand(8) |
| 290 | + y = rand(8) |
| 291 | + |
| 292 | + for corrected ∈ (false, true) |
| 293 | + @test_throws MethodError SimpleCovariance(corrected) |
| 294 | + scc = SimpleCovariance(corrected=corrected) |
| 295 | + @test cov(scc, x) ≈ cov(x; corrected=corrected) |
| 296 | + @test cov(scc, x, y) ≈ cov(x, y; corrected=corrected) |
| 297 | + end |
251 | 298 | end |
252 | 299 | end # @testset "StatsBase.Covariance" |
0 commit comments