Skip to content

Commit 24a80ca

Browse files
mateuszbaran authored and nalimilan committed
Basic abstract covariance estimation interface (#460)
1 parent 5c9ca27 commit 24a80ca

File tree

4 files changed

+134
-0
lines changed

4 files changed

+134
-0
lines changed

docs/src/cov.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,13 @@ This package implements functions for computing scatter matrix, as well as weigh
55
```@docs
66
scattermat
77
cov
8+
cov(::CovarianceEstimator, ::AbstractVector)
9+
cov(::CovarianceEstimator, ::AbstractVector, ::AbstractVector)
10+
cov(::CovarianceEstimator, ::AbstractMatrix)
811
cor
912
mean_and_cov
1013
cov2cor
1114
cor2cov
15+
CovarianceEstimator
16+
SimpleCovariance
1217
```

src/StatsBase.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ export
9898
scattermat, # scatter matrix (i.e. unnormalized covariance)
9999
cov2cor, # converts a covariance matrix to a correlation matrix
100100
cor2cov, # converts a correlation matrix to a covariance matrix
101+
CovarianceEstimator, # abstract type for covariance estimators
102+
SimpleCovariance, # simple covariance estimator
101103

102104
## counts
103105
addcounts!, # add counts to an accumulating array or map

src/cov.jl

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,3 +153,83 @@ function cor2cov!(C::AbstractMatrix, s::AbstractArray)
153153
end
154154
return C
155155
end
156+
157+
"""
158+
CovarianceEstimator
159+
160+
Abstract type for covariance estimators.
161+
"""
162+
abstract type CovarianceEstimator end
163+
164+
"""
165+
cov(ce::CovarianceEstimator, x::AbstractVector; mean=nothing)
166+
167+
Compute a variance estimate from the observation vector `x` using the estimator `ce`.
168+
"""
169+
cov(ce::CovarianceEstimator, x::AbstractVector; mean=nothing) =
170+
error("cov is not defined for $(typeof(ce)) and $(typeof(x))")
171+
172+
"""
173+
cov(ce::CovarianceEstimator, x::AbstractVector, y::AbstractVector)
174+
175+
Compute the covariance of the vectors `x` and `y` using estimator `ce`.
176+
"""
177+
cov(ce::CovarianceEstimator, x::AbstractVector, y::AbstractVector) =
178+
error("cov is not defined for $(typeof(ce)), $(typeof(x)) and $(typeof(y))")
179+
180+
"""
181+
cov(ce::CovarianceEstimator, X::AbstractMatrix, [w::AbstractWeights]; mean=nothing, dims::Int=1)
182+
183+
Compute the covariance matrix of the matrix `X` along dimension `dims`
184+
using estimator `ce`. A weighting vector `w` can be specified.
185+
The keyword argument `mean` can be:
186+
187+
* `nothing` (default) in which case the mean is estimated and subtracted
188+
from the data `X`,
189+
* a precalculated mean in which case it is subtracted from the data `X`.
190+
Assuming `size(X)` is `(N,M)`, `mean` can either be:
191+
* when `dims=1`, an `AbstractMatrix` of size `(1,M)`,
192+
* when `dims=2`, an `AbstractVector` of length `N` or an `AbstractMatrix`
193+
of size `(N,1)`.
194+
"""
195+
cov(ce::CovarianceEstimator, X::AbstractMatrix; mean=nothing, dims::Int=1) =
196+
error("cov is not defined for $(typeof(ce)) and $(typeof(X))")
197+
198+
cov(ce::CovarianceEstimator, X::AbstractMatrix, w::AbstractWeights; mean=nothing, dims::Int=1) =
199+
error("cov is not defined for $(typeof(ce)), $(typeof(X)) and $(typeof(w))")
200+
201+
"""
202+
SimpleCovariance(;corrected::Bool=false)
203+
204+
Simple covariance estimator. Estimation calls `cov(x; corrected=corrected)`,
205+
`cov(x, y; corrected=corrected)` or `cov(X, w, dims; corrected=corrected)`
206+
where `x`, `y` are vectors, `X` is a matrix and `w` is a weighting vector.
207+
"""
208+
struct SimpleCovariance <: CovarianceEstimator
209+
corrected::Bool
210+
SimpleCovariance(;corrected::Bool=false) = new(corrected)
211+
end
212+
213+
cov(sc::SimpleCovariance, x::AbstractVector) =
214+
cov(x; corrected=sc.corrected)
215+
216+
cov(sc::SimpleCovariance, x::AbstractVector, y::AbstractVector) =
217+
cov(x, y; corrected=sc.corrected)
218+
219+
function cov(sc::SimpleCovariance, X::AbstractMatrix; dims::Int=1, mean=nothing)
220+
dims ∈ (1, 2) || throw(ArgumentError("Argument dims can only be 1 or 2 (given: $dims)"))
221+
if mean === nothing
222+
return cov(X; dims=dims, corrected=sc.corrected)
223+
else
224+
return covm(X, mean, dims, corrected=sc.corrected)
225+
end
226+
end
227+
228+
function cov(sc::SimpleCovariance, X::AbstractMatrix, w::AbstractWeights; dims::Int=1, mean=nothing)
229+
dims ∈ (1, 2) || throw(ArgumentError("Argument dims can only be 1 or 2 (given: $dims)"))
230+
if mean === nothing
231+
return cov(X, w, dims, corrected=sc.corrected)
232+
else
233+
return covm(X, mean, w, dims, corrected=sc.corrected)
234+
end
235+
end

test/cov.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
using StatsBase
22
using LinearAlgebra, Random, Test
33

4+
struct EmptyCovarianceEstimator <: CovarianceEstimator end
5+
46
@testset "StatsBase.Covariance" begin
57
weight_funcs = (weights, aweights, fweights, pweights)
68

@@ -248,5 +250,50 @@ weight_funcs = (weights, aweights, fweights, pweights)
248250
@test cor(X, wv1, 1) ≈ expected_cor1
249251
@test cor(X, wv2, 2) ≈ expected_cor2
250252
end
253+
254+
@testset "Abstract covariance estimation" begin
255+
Xm1 = mean(X, dims=1)
256+
Xm2 = mean(X, dims=2)
257+
258+
for corrected ∈ (false, true)
259+
scc = SimpleCovariance(corrected=corrected)
260+
@test_throws ArgumentError cov(scc, X, dims=0)
261+
@test_throws ArgumentError cov(scc, X, wv1, dims=0)
262+
@test cov(scc, X) ≈ cov(X, corrected=corrected)
263+
@test cov(scc, X, mean=Xm1) ≈ StatsBase.covm(X, Xm1, corrected=corrected)
264+
@test cov(scc, X, mean=Xm2, dims=2) ≈ StatsBase.covm(X, Xm2, 2, corrected=corrected)
265+
if f !== weights || corrected === false
266+
@test cov(scc, X, wv1, dims=1) ≈ cov(X, wv1, 1, corrected=corrected)
267+
@test cov(scc, X, wv2, dims=2) ≈ cov(X, wv2, 2, corrected=corrected)
268+
@test cov(scc, X, wv1, mean=Xm1) ≈ StatsBase.covm(X, Xm1, wv1, corrected=corrected)
269+
@test cov(scc, X, wv2, mean=Xm2, dims=2) ≈ StatsBase.covm(X, Xm2, wv2, 2, corrected=corrected)
270+
end
271+
end
272+
end
273+
end
274+
275+
@testset "Abstract covariance estimation" begin
276+
est = EmptyCovarianceEstimator()
277+
wv = fweights(rand(2))
278+
@test_throws ErrorException cov(est, [1.0 2.0; 3.0 4.0])
279+
@test_throws ErrorException cov(est, [1.0 2.0; 3.0 4.0], wv)
280+
@test_throws ErrorException cov(est, [1.0 2.0; 3.0 4.0], dims = 2)
281+
@test_throws ErrorException cov(est, [1.0 2.0; 3.0 4.0], wv, dims = 2)
282+
@test_throws ErrorException cov(est, [1.0 2.0; 3.0 4.0], mean = nothing)
283+
@test_throws ErrorException cov(est, [1.0 2.0; 3.0 4.0], wv, mean = nothing)
284+
@test_throws ErrorException cov(est, [1.0 2.0; 3.0 4.0], dims = 2, mean = nothing)
285+
@test_throws ErrorException cov(est, [1.0 2.0; 3.0 4.0], wv, dims = 2, mean = nothing)
286+
@test_throws ErrorException cov(est, [1.0, 2.0], [3.0, 4.0])
287+
@test_throws ErrorException cov(est, [1.0, 2.0])
288+
289+
x = rand(8)
290+
y = rand(8)
291+
292+
for corrected ∈ (false, true)
293+
@test_throws MethodError SimpleCovariance(corrected)
294+
scc = SimpleCovariance(corrected=corrected)
295+
@test cov(scc, x) ≈ cov(x; corrected=corrected)
296+
@test cov(scc, x, y) ≈ cov(x, y; corrected=corrected)
297+
end
251298
end
252299
end # @testset "StatsBase.Covariance"

0 commit comments

Comments
 (0)