Skip to content

Commit aef7422

Browse files
authored
Fix cov2cor and cor2cov with Hermitian and Symmetric matrices (#953)
* Fix `cov2cor` and `cor2cov` with `Hermitian` and `Symmetric` matrices * Fix types and write only to active triangle * Fix tests and docs * Revert workaround for Julia 1.0
1 parent cc97650 commit aef7422

File tree

5 files changed

+163
-60
lines changed

5 files changed

+163
-60
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "StatsBase"
22
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
33
authors = ["JuliaStats"]
4-
version = "0.34.4"
4+
version = "0.34.5"
55

66
[deps]
77
AliasTables = "66dad0bd-aa9a-41b7-9441-69ab47430ed8"

docs/src/cov.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@ std(::CovarianceEstimator, ::AbstractVector)
1313
cor
1414
mean_and_cov
1515
cov2cor
16+
StatsBase.cov2cor!
1617
cor2cov
18+
StatsBase.cor2cov!
1719
CovarianceEstimator
1820
SimpleCovariance
1921
```

src/StatsBase.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ using SparseArrays
1818
import Random: rand, rand!
1919
import LinearAlgebra: BlasReal, BlasFloat
2020
import Statistics: mean, mean!, var, varm, varm!, std, stdm, cov, covm,
21-
cor, corm, cov2cor!, unscaled_covzm, quantile, sqrt!,
21+
cor, corm, unscaled_covzm, quantile, sqrt!,
2222
median, middle
2323
using StatsAPI: StatisticalModel, RegressionModel
2424
import StatsAPI: pairwise, pairwise!, params, params!,

src/cov.jl

Lines changed: 100 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -128,41 +128,124 @@ end
128128
cov2cor(C::AbstractMatrix, [s::AbstractArray])
129129
130130
Compute the correlation matrix from the covariance matrix `C` and, optionally, a vector
131-
of standard deviations `s`. Use `StatsBase.cov2cor!` for an in-place version.
131+
of standard deviations `s`. Use [`StatsBase.cov2cor!`](@ref) for an in-place version.
132132
"""
133-
cov2cor(C::AbstractMatrix, s::AbstractArray=sqrt.(view(C, diagind(C)))) = cov2cor!(copy(C), s)
133+
function cov2cor(C::AbstractMatrix, s::AbstractArray = map(sqrt, view(C, diagind(C))))
134+
zs = zero(eltype(s))
135+
T = typeof(zero(eltype(C)) / (zs * zs))
136+
return cov2cor!(copyto!(similar(C, T), C), s)
137+
end
138+
139+
# Original implementation: https://github.com/JuliaStats/Statistics.jl/blob/22dee82f9824d6045e87aa4b97e1d64fe6f01d8d/src/Statistics.jl#L633-L657
140+
"""
141+
cov2cor!(C::AbstractMatrix, [s::AbstractArray])
142+
143+
Convert the covariance matrix `C` to a correlation matrix in-place, optionally using a vector of
144+
standard deviations `s`.
145+
"""
146+
function cov2cor!(C::AbstractMatrix, s::AbstractArray = map(sqrt, view(C, diagind(C))))
147+
Base.require_one_based_indexing(C, s)
148+
n = length(s)
149+
size(C) == (n, n) || throw(DimensionMismatch("inconsistent dimensions"))
150+
for j = 1:n
151+
sj = s[j]
152+
for i = 1:(j-1)
153+
C[i,j] = adjoint(C[j,i])
154+
end
155+
C[j,j] = oneunit(C[j,j])
156+
for i = (j+1):n
157+
C[i,j] = _clampcor(C[i,j] / (s[i] * sj))
158+
end
159+
end
160+
return C
161+
end
162+
_clampcor(x::Real) = clamp(x, -1, 1)
163+
_clampcor(x) = x
164+
165+
# Preserve structure of Symmetric and Hermitian covariance matrices
166+
function cov2cor!(C::Union{Symmetric{<:Real},Hermitian}, s::AbstractArray)
167+
n = length(s)
168+
size(C) == (n, n) || throw(DimensionMismatch("inconsistent dimensions"))
169+
A = parent(C)
170+
if C.uplo === 'U'
171+
for j = 1:n
172+
sj = s[j]
173+
for i = 1:(j-1)
174+
A[i,j] = _clampcor(A[i,j] / (s[i] * sj))
175+
end
176+
A[j,j] = oneunit(A[j,j])
177+
end
178+
else
179+
for j = 1:n
180+
sj = s[j]
181+
A[j,j] = oneunit(A[j,j])
182+
for i = (j+1):n
183+
A[i,j] = _clampcor(A[i,j] / (s[i] * sj))
184+
end
185+
end
186+
end
187+
return C
188+
end
134189

135190
"""
136191
cor2cov(C, s)
137192
138193
Compute the covariance matrix from the correlation matrix `C` and a vector of standard
139-
deviations `s`. Use `StatsBase.cor2cov!` for an in-place version.
194+
deviations `s`. Use [`StatsBase.cor2cov!`](@ref) for an in-place version.
140195
"""
141-
cor2cov(C::AbstractMatrix, s::AbstractArray) = cor2cov!(copy(C), s)
196+
function cor2cov(C::AbstractMatrix, s::AbstractArray)
197+
zs = zero(eltype(s))
198+
T = typeof(zero(eltype(C)) * (zs * zs))
199+
return cor2cov!(copyto!(similar(C, T), C), s)
200+
end
142201

143202
"""
144203
cor2cov!(C, s)
145204
146-
Converts the correlation matrix `C` to a covariance matrix in-place using a vector of
205+
Convert the correlation matrix `C` to a covariance matrix in-place using a vector of
147206
standard deviations `s`.
148207
"""
149208
function cor2cov!(C::AbstractMatrix, s::AbstractArray)
209+
Base.require_one_based_indexing(C, s)
150210
n = length(s)
151211
size(C) == (n, n) || throw(DimensionMismatch("inconsistent dimensions"))
152-
for i in CartesianIndices(size(C))
153-
@inbounds C[i] *= s[i[1]] * s[i[2]]
212+
for j in 1:n
213+
sj = s[j]
214+
for i in 1:(j-1)
215+
C[i,j] = adjoint(C[j,i])
216+
end
217+
C[j,j] = sj^2
218+
for i in (j+1):n
219+
C[i,j] *= s[i] * sj
220+
end
154221
end
155222
return C
156223
end
157224

158-
"""
159-
_cov2cor!(C)
160-
161-
Compute the correlation matrix from the covariance matrix `C`, in-place.
162-
163-
The leading diagonal is used to determine the standard deviations by which to normalise.
164-
"""
165-
_cov2cor!(C::AbstractMatrix) = cov2cor!(C, sqrt.(diag(C)))
225+
# Preserve structure of Symmetric and Hermitian correlation matrices
226+
function cor2cov!(C::Union{Symmetric{<:Real},Hermitian}, s::AbstractArray)
227+
n = length(s)
228+
size(C) == (n, n) || throw(DimensionMismatch("inconsistent dimensions"))
229+
A = parent(C)
230+
if C.uplo === 'U'
231+
for j in 1:n
232+
sj = s[j]
233+
for i in 1:(j-1)
234+
A[i,j] *= s[i] * sj
235+
end
236+
A[j,j] = sj^2
237+
end
238+
else
239+
for j in 1:n
240+
sj = s[j]
241+
A[j,j] = sj^2
242+
for i in (j+1):n
243+
A[i,j] *= s[i] * sj
244+
end
245+
end
246+
end
247+
return C
248+
end
166249

167250
"""
168251
CovarianceEstimator
@@ -255,10 +338,10 @@ The keyword argument `mean` can be:
255338
of size `(N,1)`.
256339
"""
257340
function cor(ce::CovarianceEstimator, X::AbstractMatrix; kwargs...)
258-
return _cov2cor!(cov(ce, X; kwargs...))
341+
return cov2cor!(cov(ce, X; kwargs...))
259342
end
260343
function cor(ce::CovarianceEstimator, X::AbstractMatrix, w::AbstractWeights; kwargs...)
261-
return _cov2cor!(cov(ce, X, w; kwargs...))
344+
return cov2cor!(cov(ce, X, w; kwargs...))
262345
end
263346

264347
"""

test/cov.jl

Lines changed: 59 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,27 @@ struct EmptyCovarianceEstimator <: CovarianceEstimator end
66
@testset "StatsBase.Covariance" begin
77
weight_funcs = (weights, aweights, fweights, pweights)
88

9+
function test_isapprox_preserves_symherm_structure(f::F, x::AbstractMatrix, y::AbstractMatrix, args...) where F
10+
for wrapper in (identity, x -> Symmetric(x, :U), x -> Symmetric(x, :L), x -> Hermitian(x, :U), x -> Hermitian(x, :L))
11+
A = wrapper(copy(x))
12+
fA = @inferred(f(A, args...))
13+
@test fA ≈ y
14+
if f === StatsBase.cov2cor! || f === StatsBase.cor2cov!
15+
@test fA === A
16+
if A isa Union{Symmetric,Hermitian}
17+
@test parent(fA) != fA # only active triangle is written to
18+
end
19+
else
20+
@test fA !== A
21+
if A isa Union{Symmetric,Hermitian}
22+
@test fA isa (A isa Symmetric ? Symmetric : Hermitian)
23+
@test fA.uplo == A.uplo
24+
@test parent(fA) != fA # only active triangle is written to
25+
end
26+
end
27+
end
28+
end
29+
930
@testset "$f" for f in weight_funcs
1031
X = randn(3, 8)
1132

@@ -120,18 +141,32 @@ weight_funcs = (weights, aweights, fweights, pweights)
120141
cor2 = cor(X, wv2, 2)
121142

122143
@testset "cov2cor" begin
123-
@test cov2cor(cov(X, dims = 1), std(X, dims = 1)) ≈ cor(X, dims = 1)
124-
@test cov2cor(cov(X, dims = 2), std(X, dims = 2)) ≈ cor(X, dims = 2)
125-
@test cov2cor(cov1) ≈ cor1
126-
@test cov2cor(cov2) ≈ cor2
127-
@test cov2cor(cov1, std1) ≈ cor1
128-
@test cov2cor(cov2, std2) ≈ cor2
144+
test_isapprox_preserves_symherm_structure(cov2cor, cov(X, dims = 1), cor(X, dims = 1), std(X, dims = 1))
145+
test_isapprox_preserves_symherm_structure(cov2cor, cov(X, dims = 2), cor(X, dims = 2), std(X, dims = 2))
146+
test_isapprox_preserves_symherm_structure(cov2cor, cov1, cor1)
147+
test_isapprox_preserves_symherm_structure(cov2cor, cov2, cor2)
148+
test_isapprox_preserves_symherm_structure(cov2cor, cov1, cor1, std1)
149+
test_isapprox_preserves_symherm_structure(cov2cor, cov2, cor2, std2)
150+
end
151+
@testset "StatsBase.cov2cor!" begin
152+
test_isapprox_preserves_symherm_structure(StatsBase.cov2cor!, cov(X, dims = 1), cor(X, dims = 1), std(X, dims = 1))
153+
test_isapprox_preserves_symherm_structure(StatsBase.cov2cor!, cov(X, dims = 2), cor(X, dims = 2), std(X, dims = 2))
154+
test_isapprox_preserves_symherm_structure(StatsBase.cov2cor!, cov1, cor1)
155+
test_isapprox_preserves_symherm_structure(StatsBase.cov2cor!, cov2, cor2)
156+
test_isapprox_preserves_symherm_structure(StatsBase.cov2cor!, cov1, cor1, std1)
157+
test_isapprox_preserves_symherm_structure(StatsBase.cov2cor!, cov2, cor2, std2)
129158
end
130159
@testset "cor2cov" begin
131-
@test cor2cov(cor(X, dims = 1), std(X, dims = 1)) ≈ cov(X, dims = 1)
132-
@test cor2cov(cor(X, dims = 2), std(X, dims = 2)) ≈ cov(X, dims = 2)
133-
@test cor2cov(cor1, std1) ≈ cov1
134-
@test cor2cov(cor2, std2) ≈ cov2
160+
test_isapprox_preserves_symherm_structure(cor2cov, cor(X, dims = 1), cov(X, dims = 1), std(X, dims = 1))
161+
test_isapprox_preserves_symherm_structure(cor2cov, cor(X, dims = 2), cov(X, dims = 2), std(X, dims = 2))
162+
test_isapprox_preserves_symherm_structure(cor2cov, cor1, cov1, std1)
163+
test_isapprox_preserves_symherm_structure(cor2cov, cor2, cov2, std2)
164+
end
165+
@testset "StatsBase.cor2cov!" begin
166+
test_isapprox_preserves_symherm_structure(StatsBase.cor2cov!, cor(X, dims = 1), cov(X, dims = 1), std(X, dims = 1))
167+
test_isapprox_preserves_symherm_structure(StatsBase.cor2cov!, cor(X, dims = 2), cov(X, dims = 2), std(X, dims = 2))
168+
test_isapprox_preserves_symherm_structure(StatsBase.cor2cov!, cor1, cov1, std1)
169+
test_isapprox_preserves_symherm_structure(StatsBase.cor2cov!, cor2, cov2, std2)
135170
end
136171
end
137172
end
@@ -198,41 +233,24 @@ weight_funcs = (weights, aweights, fweights, pweights)
198233
cor2 = cor(X, wv2, 2)
199234

200235
@testset "cov2cor" begin
201-
@test cov2cor(cov(X, dims = 1), std(X, dims = 1)) ≈ cor(X, dims = 1)
202-
@test cov2cor(cov(X, dims = 2), std(X, dims = 2)) ≈ cor(X, dims = 2)
203-
@test cov2cor(cov1, std1) ≈ cor1
204-
@test cov2cor(cov2, std2) ≈ cor2
236+
test_isapprox_preserves_symherm_structure(cov2cor, cov1, cor1)
237+
test_isapprox_preserves_symherm_structure(cov2cor, cov2, cor2)
238+
test_isapprox_preserves_symherm_structure(cov2cor, cov1, cor1, std1)
239+
test_isapprox_preserves_symherm_structure(cov2cor, cov2, cor2, std2)
205240
end
206-
207-
@testset "cov2cor!" begin
208-
tmp_cov1 = copy(cov1)
209-
@test !(tmp_cov1 ≈ cor1)
210-
StatsBase.cov2cor!(tmp_cov1, std1)
211-
@test tmp_cov1 ≈ cor1
212-
213-
tmp_cov2 = copy(cov2)
214-
@test !(tmp_cov2 ≈ cor2)
215-
StatsBase.cov2cor!(tmp_cov2, std2)
216-
@test tmp_cov2 ≈ cor2
241+
@testset "StatsBase.cov2cor!" begin
242+
test_isapprox_preserves_symherm_structure(StatsBase.cov2cor!, cov1, cor1)
243+
test_isapprox_preserves_symherm_structure(StatsBase.cov2cor!, cov2, cor2)
244+
test_isapprox_preserves_symherm_structure(StatsBase.cov2cor!, cov1, cor1, std1)
245+
test_isapprox_preserves_symherm_structure(StatsBase.cov2cor!, cov2, cor2, std2)
217246
end
218-
219247
@testset "cor2cov" begin
220-
@test cor2cov(cor(X, dims = 1), std(X, dims = 1)) ≈ cov(X, dims = 1)
221-
@test cor2cov(cor(X, dims = 2), std(X, dims = 2)) ≈ cov(X, dims = 2)
222-
@test cor2cov(cor1, std1) ≈ cov1
223-
@test cor2cov(cor2, std2) ≈ cov2
248+
test_isapprox_preserves_symherm_structure(cor2cov, cor1, cov1, std1)
249+
test_isapprox_preserves_symherm_structure(cor2cov, cor2, cov2, std2)
224250
end
225-
226-
@testset "cor2cov!" begin
227-
tmp_cor1 = copy(cor1)
228-
@test !(tmp_cor1 ≈ cov1)
229-
StatsBase.cor2cov!(tmp_cor1, std1)
230-
@test tmp_cor1 ≈ cov1
231-
232-
tmp_cor2 = copy(cor2)
233-
@test !(tmp_cor2 ≈ cov2)
234-
StatsBase.cor2cov!(tmp_cor2, std2)
235-
@test tmp_cor2 ≈ cov2
251+
@testset "StatsBase.cor2cov!" begin
252+
test_isapprox_preserves_symherm_structure(StatsBase.cor2cov!, cor1, cov1, std1)
253+
test_isapprox_preserves_symherm_structure(StatsBase.cor2cov!, cor2, cov2, std2)
236254
end
237255
end
238256
end

0 commit comments

Comments
 (0)