Skip to content

Commit a87c385

Browse files
authored
Add missing shape checks for the means argument to var[m] and std[m] (#32)
We use `@inbounds`, but the shape of the `means` argument was never checked for the general `AbstractArray` method. With an incorrect shape, invalid results or crashes would happen. To avoid breaking existing code which was working, allow trailing singleton dimensions. Sync the `SparseMatrixCSV` method, which makes it less strict.
1 parent 97c743d commit a87c385

File tree

2 files changed

+44
-15
lines changed

2 files changed

+44
-15
lines changed

src/Statistics.jl

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,11 @@ centralize_sumabs2(A::AbstractArray, m, ifirst::Int, ilast::Int) =
247247
function centralize_sumabs2!(R::AbstractArray{S}, A::AbstractArray, means::AbstractArray) where S
248248
# following the implementation of _mapreducedim! at base/reducedim.jl
249249
lsiz = Base.check_reducedims(R,A)
250+
for i in 1:max(ndims(R), ndims(means))
251+
if axes(means, i) != axes(R, i)
252+
throw(DimensionMismatch("dimension $i of `mean` should have indices $(axes(R, i)), but got $(axes(means, i))"))
253+
end
254+
end
250255
isempty(R) || fill!(R, zero(S))
251256
isempty(A) && return R
252257

@@ -295,9 +300,9 @@ function varm!(R::AbstractArray{S}, A::AbstractArray, m::AbstractArray; correcte
295300
end
296301

297302
"""
298-
varm(itr, m; dims, corrected::Bool=true)
303+
varm(itr, mean; dims, corrected::Bool=true)
299304
300-
Compute the sample variance of collection `itr`, with known mean(s) `m`.
305+
Compute the sample variance of collection `itr`, with known mean(s) `mean`.
301306
302307
The algorithm returns an estimator of the generative distribution's variance
303308
under the assumption that each entry of `itr` is an IID drawn from that generative
@@ -308,7 +313,8 @@ whereas the sum is scaled with `n` if `corrected` is
308313
`false` with `n` the number of elements in `itr`.
309314
310315
If `itr` is an `AbstractArray`, `dims` can be provided to compute the variance
311-
over dimensions, and `m` may contain means for each dimension of `itr`.
316+
over dimensions. In that case, `mean` must be an array with the same shape as
317+
`mean(itr, dims=dims)` (additional trailing singleton dimensions are allowed).
312318
313319
!!! note
314320
If array contains `NaN` or [`missing`](@ref) values, the result is also
@@ -331,7 +337,7 @@ end
331337

332338

333339
"""
334-
var(itr; dims, corrected::Bool=true, mean=nothing)
340+
var(itr; corrected::Bool=true, mean=nothing[, dims])
335341
336342
Compute the sample variance of collection `itr`.
337343
@@ -343,10 +349,12 @@ If `corrected` is `true`, then the sum is scaled with `n-1`,
343349
whereas the sum is scaled with `n` if `corrected` is
344350
`false` with `n` the number of elements in `itr`.
345351
346-
A pre-computed `mean` may be provided.
347-
348352
If `itr` is an `AbstractArray`, `dims` can be provided to compute the variance
349-
over dimensions, and `mean` may contain means for each dimension of `itr`.
353+
over dimensions.
354+
355+
A pre-computed `mean` may be provided. When `dims` is specified, `mean` must be
356+
an array with the same shape as `mean(itr, dims=dims)` (additional trailing
357+
singleton dimensions are allowed).
350358
351359
!!! note
352360
If array contains `NaN` or [`missing`](@ref) values, the result is also
@@ -416,11 +424,13 @@ If `corrected` is `true`, then the sum is scaled with `n-1`,
416424
whereas the sum is scaled with `n` if `corrected` is
417425
`false` with `n` the number of elements in `itr`.
418426
419-
A pre-computed `mean` may be provided.
420-
421427
If `itr` is an `AbstractArray`, `dims` can be provided to compute the standard deviation
422428
over dimensions, and `means` may contain means for each dimension of `itr`.
423429
430+
A pre-computed `mean` may be provided. When `dims` is specified, `mean` must be
431+
an array with the same shape as `mean(itr, dims=dims)` (additional trailing
432+
singleton dimensions are allowed).
433+
424434
!!! note
425435
If array contains `NaN` or [`missing`](@ref) values, the result is also
426436
`NaN` or `missing` (`missing` takes precedence if array contains both).
@@ -445,9 +455,9 @@ std(iterable; corrected::Bool=true, mean=nothing) =
445455
sqrt(var(iterable, corrected=corrected, mean=mean))
446456

447457
"""
448-
stdm(itr, m; corrected::Bool=true)
458+
stdm(itr, mean; corrected::Bool=true)
449459
450-
Compute the sample standard deviation of collection `itr`, with known mean(s) `m`.
460+
Compute the sample standard deviation of collection `itr`, with known mean(s) `mean`.
451461
452462
The algorithm returns an estimator of the generative distribution's standard
453463
deviation under the assumption that each entry of `itr` is an IID drawn from that generative
@@ -457,10 +467,9 @@ If `corrected` is `true`, then the sum is scaled with `n-1`,
457467
whereas the sum is scaled with `n` if `corrected` is
458468
`false` with `n` the number of elements in `itr`.
459469
460-
A pre-computed `mean` may be provided.
461-
462470
If `itr` is an `AbstractArray`, `dims` can be provided to compute the standard deviation
463-
over dimensions, and `m` may contain means for each dimension of `itr`.
471+
over dimensions. In that case, `mean` must be an array with the same shape as
472+
`mean(itr, dims=dims)` (additional trailing singleton dimensions are allowed).
464473
465474
!!! note
466475
If array contains `NaN` or [`missing`](@ref) values, the result is also
@@ -1065,7 +1074,11 @@ end
10651074
function centralize_sumabs2!(R::AbstractArray{S}, A::SparseMatrixCSC{Tv,Ti}, means::AbstractArray) where {S,Tv,Ti}
10661075
require_one_based_indexing(R, A, means)
10671076
lsiz = Base.check_reducedims(R,A)
1068-
size(means) == size(R) || error("size of means must match size of R")
1077+
for i in 1:max(ndims(R), ndims(means))
1078+
if axes(means, i) != axes(R, i)
1079+
throw(DimensionMismatch("dimension $i of `mean` should have indices $(axes(R, i)), but got $(axes(means, i))"))
1080+
end
1081+
end
10691082
isempty(R) || fill!(R, zero(S))
10701083
isempty(A) && return R
10711084

test/runtests.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,22 @@ end
311311
@test var(Int[]) isa Float64
312312
@test isequal(var(skipmissing(Int[])), NaN)
313313
@test var(skipmissing(Int[])) isa Float64
314+
315+
# over dimensions with provided means
316+
for x in ([1 2 3; 4 5 6], sparse([1 2 3; 4 5 6]))
317+
@test var(x, dims=1, mean=mean(x, dims=1)) == var(x, dims=1)
318+
@test var(x, dims=1, mean=reshape(mean(x, dims=1), 1, :, 1)) == var(x, dims=1)
319+
@test var(x, dims=2, mean=mean(x, dims=2)) == var(x, dims=2)
320+
@test var(x, dims=2, mean=reshape(mean(x, dims=2), :)) == var(x, dims=2)
321+
@test var(x, dims=2, mean=reshape(mean(x, dims=2), :, 1, 1)) == var(x, dims=2)
322+
@test_throws DimensionMismatch var(x, dims=1, mean=ones(size(x, 1)))
323+
@test_throws DimensionMismatch var(x, dims=1, mean=ones(size(x, 1), 1))
324+
@test_throws DimensionMismatch var(x, dims=2, mean=ones(1, size(x, 2)))
325+
@test_throws DimensionMismatch var(x, dims=1, mean=ones(1, 1, size(x, 2)))
326+
@test_throws DimensionMismatch var(x, dims=2, mean=ones(1, size(x, 2), 1))
327+
@test_throws DimensionMismatch var(x, dims=2, mean=ones(size(x, 1), 1, 5))
328+
@test_throws DimensionMismatch var(x, dims=1, mean=ones(1, size(x, 2), 5))
329+
end
314330
end
315331

316332
function safe_cov(x, y, zm::Bool, cr::Bool)

0 commit comments

Comments
 (0)