Skip to content

Commit 5d46994

Browse files
committed
Overhaul of support for statistics functions
1 parent 82a7b75 commit 5d46994

File tree

4 files changed

+136
-67
lines changed

4 files changed

+136
-67
lines changed

src/array_of_similar_arrays.jl

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,8 @@ Base.convert(R::Type{VectorOfSimilarVectors}, A::AbstractVector{<:AbstractVector
393393
Compute the sum of the elements vectors of `X`. Equivalent to `sum` of
394394
`flatview(X)` along the last dimension.
395395
"""
396-
Base.sum(X::AbstractVectorOfSimilarArrays{T,M}) where {T,M} = sum(flatview(X); dims = M + 1)
396+
Base.sum(X::AbstractVectorOfSimilarArrays{T,M}) where {T,M} =
397+
sum(flatview(X); dims = M + 1)[_ncolons(Val{M}())...]
397398

398399

399400
"""
@@ -404,7 +405,7 @@ Compute the mean of the elements vectors of `X`. Equivalent to `mean` of
404405
`flatview(X)` along the last dimension.
405406
"""
406407
Statistics.mean(X::AbstractVectorOfSimilarArrays{T,M}) where {T,M} =
407-
mean(flatview(X); dims = M + 1)
408+
mean(flatview(X); dims = M + 1)[_ncolons(Val{M}())...]
408409

409410

410411
"""
@@ -415,7 +416,19 @@ Compute the sample variance of the elements vectors of `X`. Equivalent to
415416
`var` of `flatview(X)` along the last dimension.
416417
"""
417418
Statistics.var(X::AbstractVectorOfSimilarArrays{T,M}; corrected::Bool = true) where {T,M} =
418-
var(flatview(X); dims = M + 1, corrected = corrected)
419+
var(flatview(X); dims = M + 1, corrected = corrected)[_ncolons(Val{M}())...]
420+
421+
422+
"""
423+
var(X::AbstractVectorOfSimilarArrays; corrected::Bool = true)
424+
var(X::AbstractVectorOfSimilarArrays, w::StatsBase.AbstractWeights; corrected::Bool = true)
425+
426+
Compute the sample standard deviation of the elements vectors of `X`.
427+
Compute the sample variance of the elements vectors of `X`. Equivalent to
428+
`std` of `flatview(X)` along the last dimension.
429+
"""
430+
Statistics.std(X::AbstractVectorOfSimilarArrays{T,M}; corrected::Bool = true) where {T,M} =
431+
std(flatview(X); dims = M + 1, corrected = corrected)[_ncolons(Val{M}())...]
419432

420433

421434
"""

src/statsbase_support.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,17 @@
11
# This file is a part of ArraysOfArrays.jl, licensed under the MIT License (MIT).
22

33

4-
Base.sum(X::AbstractVectorOfSimilarArrays{T,M}, w::StatsBase.AbstractWeights) where {T,M} = sum(flatview(X), w, M + 1)
4+
Base.sum(X::AbstractVectorOfSimilarArrays{T,M}, w::StatsBase.AbstractWeights) where {T,M} =
5+
sum(flatview(X), w, dims = M + 1)[_ncolons(Val{M}())...]
56

67
Statistics.mean(X::AbstractVectorOfSimilarArrays{T,M}, w::StatsBase.AbstractWeights) where {T,M} =
7-
vec(mean(flatview(X), w, dims = M + 1))
8+
mean(flatview(X), w, dims = M + 1)[_ncolons(Val{M}())...]
89

910
Statistics.var(X::AbstractVectorOfSimilarArrays{T,M}, w::StatsBase.AbstractWeights; corrected::Bool = true) where {T,M} =
10-
vec(var(flatview(X), w, M + 1; corrected = corrected))
11+
var(flatview(X), w, M + 1; corrected = corrected)[_ncolons(Val{M}())...]
1112

1213
Statistics.std(X::AbstractVectorOfSimilarArrays{T,M}, w::StatsBase.AbstractWeights; corrected::Bool = true) where {T,M} =
13-
vec(std(flatview(X), w, M+1; corrected = corrected))
14-
15-
Statistics.std(X::AbstractVectorOfSimilarArrays{T,M}; corrected::Bool = true) where {T,M} =
16-
std(flatview(X); corrected = corrected)
14+
std(flatview(X), w, M + 1; corrected = corrected)[_ncolons(Val{M}())...]
1715

1816
Statistics.cov(X::AbstractVectorOfSimilarVectors, w::StatsBase.AbstractWeights; corrected::Bool = true) =
1917
cov(flatview(X), w, 2; corrected = corrected)

test/array_of_similar_arrays.jl

Lines changed: 29 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using ElasticArrays
77
using UnsafeArrays
88

99
using Statistics
10+
using StatsBase: cov2cor
1011

1112
@testset "array_of_similar_arrays" begin
1213
function rand_flat_array(Val_N::Val{N}) where {N}
@@ -237,60 +238,43 @@ using Statistics
237238
end
238239

239240
@testset "stats" begin
240-
a1 = rand(1,5); a2 = rand(1,5); a3 = rand(1,5)
241-
mu_a1 = mean(a1); mu_a2 = mean(a2); mu_a3 = mean(a3)
242-
243-
VA = VectorOfSimilarArrays([a1, a2, a3])
244-
v1 = [1,2,3,4]
245-
v2 = v1.*2
246-
v2 = v2.+1
247-
VV = VectorOfSimilarVectors([v1,v2])
248-
249-
@testset "sum" begin
250-
VA_sum = @inferred(sum(VA))
251-
for i in 1:length(VA[1])
252-
@test @inferred(VA_sum[i]) == a1[i] + a2[i] + a3[i]
253-
end
254-
end
241+
VV = [rand(3) for i in 1:10]
242+
VV_aosa = ArrayOfSimilarArrays(VV)
255243

256-
@testset "mean" begin
257-
VA_mean = @inferred(mean(VA))
258-
for i in 1:length(VA[1])
259-
diff = @inferred(VA_mean[i]) - @inferred((a1[i]+a2[i]+a3[i])/3)
260-
@test isless(diff, eps(Float64))
261-
end
262-
end
244+
VA = [rand(2,3,3) for i in 1:10]
245+
VA_aosa = ArrayOfSimilarArrays(VA)
246+
247+
array_cmp(A, B) = (A B) && (size(A) == size(B))
248+
249+
function test_statistics_op(op::Function)
250+
@testset "$op" begin
251+
@test array_cmp(@inferred(op(VV_aosa)), op(VV))
263252

264-
@testset "var" begin
265-
VA_var = @inferred(var(VA))
266-
for i in 1:length(VA[1])
267-
diff = @inferred(var([a1[i], a2[i], a3[i]])) - VA_var[i]
268-
@test @inferred(isless(diff, eps(Float64)))
253+
if op in (var, cov)
254+
@test array_cmp(@inferred(op(VV_aosa, corrected = false)), op(VV, corrected = false))
255+
end
256+
257+
if (op in (sum, mean, var))
258+
@test array_cmp(@inferred(op(VA_aosa)), op(VA))
259+
end
269260
end
270261
end
271262

272-
@testset "cov" begin
273-
VV_cov = @inferred(cov(VV))
274-
VV_var = @inferred(var(VV))
275-
diff = VV_cov[1] + VV_cov[6] + VV_cov[11] + VV_cov[16] - sum(VV_var)
276-
@test @inferred(isless(diff, eps(Float64)))
277-
@test VV_cov == VV_cov'
278-
end
263+
test_statistics_op(sum)
264+
test_statistics_op(mean)
265+
test_statistics_op(var)
266+
test_statistics_op(std)
267+
test_statistics_op(cov)
279268

280269
@testset "cor" begin
281-
VV_cor = @inferred(cor(VV))
282-
diff = sum(VV_cor - (zeros(size(VV_cor)).+1))
283-
@test VV_cor' == VV_cor
284-
@test @inferred(isless(diff, eps(Float64)))
285-
end
286-
287-
a1 = a1 .- mu_a1; a2 = a2 .- mu_a2; a3 = a3 .- mu_a3
288-
@testset "centered" begin
289-
@test isapprox(mean(a1), 0, atol=eps(Float64))
290-
@test isapprox(mean(a2), 0, atol=eps(Float64))
291-
@test isapprox(mean(a3), 0, atol=eps(Float64))
270+
# Statistics.cor currently results in an error for Vector{Vector},
271+
# this should be considered a bug, though, since Statistics.cov
272+
# works fine.
273+
@test array_cmp(@inferred(cor(VV_aosa)), cov2cor(cov(VV), std(VV)))
292274
end
293275
end
276+
277+
294278
@testset "examples" begin
295279
A_flat = rand(2,3,4,5,6)
296280
A_nested = nestedview(A_flat, 2)

test/test_statsbase_support.jl

Lines changed: 86 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,90 @@ using StatsBase
66
using Statistics
77

88
@testset "StatsBase support" begin
9-
r = rand(1,100)
10-
V = VectorOfSimilarVectors{Float64}(r)
11-
w = FrequencyWeights(rand(100))
12-
w0 = FrequencyWeights(vec(ones(100,1)))
13-
@test isapprox(@inferred(sum(V, w))[1], sum(r, w))
14-
@test isapprox(@inferred(mean(V,w))[1], mean(r, w))
15-
@test isapprox(@inferred(var(V,w, corrected=true))[1], var(r, w, corrected=true))
16-
@test isapprox(@inferred(std(V,w, corrected=true))[1], sqrt(var(V,w, corrected=true)[1]))
17-
@test isapprox(@inferred(std(V, corrected=true))[1], sqrt(var(V, w0, corrected=true)[1]))
18-
@test isapprox(@inferred(std(V, corrected=true))[1], std(V, w0, corrected=true)[1])
19-
@test isapprox(@inferred(cov(V, w))[1], var(r, w, corrected=true))
20-
@test isapprox(@inferred(cor(V,w))[1], 1.0)
9+
VV = [rand(3) for i in 1:10]
10+
VV_aosa = ArrayOfSimilarArrays(VV)
11+
12+
VA = [rand(2,3,3) for i in 1:10]
13+
VA_aosa = ArrayOfSimilarArrays(VA)
14+
15+
w = FrequencyWeights(rand(10))
16+
17+
array_cmp(A, B) = (A B) && (size(A) == size(B))
18+
19+
20+
# sum and mean for Vector{Vector} with weights currently fail with
21+
# the implementations in StatsBase. This should be considered a
22+
# bug in StatsBase, since Base and Statistics support sum and mean
23+
# for Vector{Vector} without weights. Also, adding products of vectors
24+
# and weights is perfectly natural, mathematically.
25+
26+
_sum(A::AbstractVector{<:AbstractArray}, w::AbstractWeights) =
27+
sum(A .* w)
28+
29+
_mean(A::AbstractVector{<:AbstractArray}, w::AbstractWeights) =
30+
_sum(A, w) ./ sum(w)
31+
32+
@testset "sum and mean" begin
33+
34+
@test array_cmp(@inferred(sum(VV_aosa, w)), _sum(VV, w))
35+
@test array_cmp(@inferred(sum(VA_aosa, w)), _sum(VA, w))
36+
37+
@test array_cmp(@inferred(mean(VV_aosa, w)), _mean(VV, w))
38+
@test array_cmp(@inferred(mean(VA_aosa, w)), _mean(VA, w))
39+
end
40+
41+
42+
# Weighted var and std are currently not supported for Vector{Vector} by
43+
# StatsBase. This should be considered a bug in StatsBase, since
44+
# unweighted var and std for Vector{Vector} are supported by Statistics.
45+
46+
function _var(A::AbstractVector{<:AbstractArray}, w::FrequencyWeights; corrected = true)
47+
wmean_A = _mean(A, w)
48+
wsum = sum(w)
49+
wsum_corr = corrected ? -1 : 0
50+
sum([(x .- wmean_A).^2 for x in A] .* w) ./ (wsum + wsum_corr)
51+
end
52+
53+
_std(A::AbstractVector{<:AbstractArray}, w::AbstractWeights; corrected = true) =
54+
sqrt.(_var(A, w, corrected = corrected))
55+
56+
@testset "var and std" begin
57+
@test array_cmp(@inferred(var(VV_aosa, w)), _var(VV_aosa, w))
58+
@test array_cmp(@inferred(var(VV_aosa, w, corrected = false)), _var(VV_aosa, w, corrected = false))
59+
@test array_cmp(@inferred(var(VA_aosa, w)), _var(VA_aosa, w))
60+
@test array_cmp(@inferred(var(VA_aosa, w, corrected = false)), _var(VA_aosa, w, corrected = false))
61+
62+
@test array_cmp(@inferred(std(VV_aosa, w)), _std(VV_aosa, w))
63+
@test array_cmp(@inferred(std(VV_aosa, w, corrected = false)), _std(VV_aosa, w, corrected = false))
64+
@test array_cmp(@inferred(std(VA_aosa, w)), _std(VA_aosa, w))
65+
@test array_cmp(@inferred(std(VA_aosa, w, corrected = false)), _std(VA_aosa, w, corrected = false))
66+
end
67+
68+
69+
# For weighted cov of Vector{Vector}, StatsBase currently returns a vector
70+
# instead of a matrix, with `cov(VV, fill(1, 10)) != cov(VV)`.
71+
# This should be considered a bug in StatsBase.
72+
73+
function _cov(A::AbstractVector{<:AbstractVector}, w::FrequencyWeights; corrected = true)
74+
wmean_A = _mean(A, w)
75+
wsum = sum(w)
76+
wsum_corr = corrected ? -1 : 0
77+
sum([[(A[i][j] - wmean_A[j]) * (A[i][k] - wmean_A[k]) * w[i] for j in eachindex(A[i]), k in eachindex(A[i])] for i in eachindex(A)]) ./ (wsum + wsum_corr)
78+
end
79+
80+
@testset "cov" begin
81+
@test array_cmp(@inferred(cov(VV_aosa, w)), _cov(VV_aosa, w))
82+
@test array_cmp(@inferred(cov(VV_aosa, w, corrected = false)), _cov(VV_aosa, w, corrected = false))
83+
end
84+
85+
86+
# Weighted cor is currently not supported for Vector{Vector} by StatsBase.
87+
# This should be considered a bug in StatsBase, since unweighted cor
88+
# for Vector{Vector} is supported by Statistics.
89+
90+
_cor(A::AbstractVector{<:AbstractVector}, w::AbstractWeights) = cov2cor(_cov(A, w), _std(A, w))
91+
92+
@testset "cor" begin
93+
@test array_cmp(@inferred(cor(VV_aosa, w)), _cor(VV, w))
94+
end
2195
end

0 commit comments

Comments
 (0)