Skip to content

Commit 77590ce

Browse files
authored
Let var, std & cor take a CovarianceEstimator (#815)
* Closes #734 * Undo drive-by tweak * Apparently needed to avoid duplicating docstrings * Bump version * Use private _cov2cor! for now
1 parent db4e40e commit 77590ce

File tree

4 files changed

+78
-2
lines changed

4 files changed

+78
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "StatsBase"
22
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
33
authors = ["JuliaStats"]
4-
version = "0.33.18"
4+
version = "0.33.19"
55

66
[deps]
77
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"

docs/src/cov.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ cov
88
cov(::CovarianceEstimator, ::AbstractVector)
99
cov(::CovarianceEstimator, ::AbstractVector, ::AbstractVector)
1010
cov(::CovarianceEstimator, ::AbstractMatrix)
11+
var(::CovarianceEstimator, ::AbstractVector)
12+
std(::CovarianceEstimator, ::AbstractVector)
1113
cor
1214
mean_and_cov
1315
cov2cor

src/cov.jl

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ function scattermat end
5050

5151

5252
"""
53-
cov(X, w::AbstractWeights, vardim=1; mean=nothing, corrected=false)
53+
cov(X, w::AbstractWeights, vardim=1; mean=nothing, corrected=false)
5454
5555
Compute the weighted covariance matrix. Similar to `var` and `std` the biased covariance
5656
matrix (`corrected=false`) is computed by multiplying `scattermat(X, w)` by
@@ -154,6 +154,15 @@ function cor2cov!(C::AbstractMatrix, s::AbstractArray)
154154
return C
155155
end
156156

157+
"""
    _cov2cor!(C)

Turn the covariance matrix `C` into the corresponding correlation matrix, in-place.

The standard deviations used for the normalisation are the square roots of the leading
diagonal of `C`.
"""
function _cov2cor!(C::AbstractMatrix)
    # The diagonal of a covariance matrix holds the variances.
    stddevs = sqrt.(diag(C))
    return cov2cor!(C, stddevs)
end
165+
157166
"""
158167
CovarianceEstimator
159168
@@ -198,6 +207,60 @@ cov(ce::CovarianceEstimator, X::AbstractMatrix; mean=nothing, dims::Int=1) =
198207
# Generic method for a matrix with weights: it always raises an error that reports the
# types of `ce`, `X` and `w`.
function cov(
    ce::CovarianceEstimator, X::AbstractMatrix, w::AbstractWeights;
    mean=nothing, dims::Int=1
)
    return error("cov is not defined for $(typeof(ce)), $(typeof(X)) and $(typeof(w))")
end
200209

210+
"""
    var(ce::CovarianceEstimator, x::AbstractVector; mean=nothing)

Return the variance of the vector `x` as estimated by `ce`.

Keyword arguments are forwarded unchanged to `cov(ce, x; ...)`.
"""
function var(ce::CovarianceEstimator, x::AbstractVector; kwargs...)
    return cov(ce, x; kwargs...)
end
216+
217+
"""
    std(ce::CovarianceEstimator, x::AbstractVector; mean=nothing)

Return the standard deviation of the vector `x` as estimated by `ce`, i.e. the square
root of `var(ce, x; ...)`.

Keyword arguments are forwarded unchanged to `var`.
"""
function std(ce::CovarianceEstimator, x::AbstractVector; kwargs...)
    variance = var(ce, x; kwargs...)
    return sqrt(variance)
end
223+
224+
"""
    cor(ce::CovarianceEstimator, x::AbstractVector, y::AbstractVector)

Return the correlation between the vectors `x` and `y` as estimated by `ce`.
"""
function cor(ce::CovarianceEstimator, x::AbstractVector, y::AbstractVector)
    # Hand `ce` both vectors at once, as the two columns of a matrix, so that it can
    # estimate the full correlation matrix; the wanted value is its off-diagonal entry.
    #
    # For some estimators the expression
    #     cov(ce, x, y) / (std(ce, x) * std(ce, y))
    # may be equivalent (and possibly cheaper), but this need not hold in general.
    joint = hcat(x, y)
    R = cor(ce, joint)
    return R[1, 2]
end
238+
239+
"""
    cor(
        ce::CovarianceEstimator, X::AbstractMatrix, [w::AbstractWeights];
        mean=nothing, dims::Int=1
    )

Estimate the correlation matrix of the matrix `X` along dimension `dims` with the
estimator `ce`. A weighting vector `w` may optionally be supplied.

The covariance matrix is obtained from `cov(ce, X, [w]; ...)`, to which all keyword
arguments are forwarded, and is then rescaled into a correlation matrix.

The keyword argument `mean` can be:

* `nothing` (default), in which case the mean is estimated and subtracted
  from the data `X`,
* a precalculated mean, in which case it is subtracted from the data `X`.
  Assuming `size(X)` is `(N,M)`, `mean` can either be:
  * when `dims=1`, an `AbstractMatrix` of size `(1,M)`,
  * when `dims=2`, an `AbstractVector` of length `N` or an `AbstractMatrix`
    of size `(N,1)`.
"""
function cor(ce::CovarianceEstimator, X::AbstractMatrix; kwargs...)
    C = cov(ce, X; kwargs...)
    return _cov2cor!(C)
end
function cor(ce::CovarianceEstimator, X::AbstractMatrix, w::AbstractWeights; kwargs...)
    C = cov(ce, X, w; kwargs...)
    return _cov2cor!(C)
end
263+
201264
"""
202265
SimpleCovariance(;corrected::Bool=false)
203266

test/cov.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,12 +288,23 @@ end
288288

289289
x = rand(8)
290290
y = rand(8)
291+
X = hcat(x, y)
291292

292293
for corrected in (false, true)
293294
@test_throws MethodError SimpleCovariance(corrected)
294295
scc = SimpleCovariance(corrected=corrected)
295296
@test cov(scc, x) ≈ cov(x; corrected=corrected)
296297
@test cov(scc, x, y) ≈ cov(x, y; corrected=corrected)
298+
@test cov(scc, X) ≈ cov(X; corrected=corrected)
299+
300+
@test var(scc, x) ≈ var(x; corrected=corrected)
301+
@test std(scc, x) ≈ std(x; corrected=corrected)
302+
303+
# NB That we should get the same correlation regardless of `corrected`, since it
304+
# only affects the overall scale of the covariance. This cancels out when turning
305+
# it into a correlation matrix.
306+
@test cor(scc, x, y) ≈ cor(x, y)
307+
@test cor(scc, X) ≈ cor(X)
297308
end
298309
end
299310
end # @testset "StatsBase.Covariance"

0 commit comments

Comments
 (0)