Skip to content

Commit 97c743d

Browse files
Solve the overflow in mean() on integers by promoting accumulator (#25)
1 parent 542f57e commit 97c743d

File tree

2 files changed

+40
-11
lines changed

2 files changed

+40
-11
lines changed

src/Statistics.jl

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -61,17 +61,17 @@ julia> mean([√1, √2, √3])
6161
function mean(f, itr)
6262
y = iterate(itr)
6363
if y === nothing
64-
return Base.mapreduce_empty_iter(f, Base.add_sum, itr,
64+
return Base.mapreduce_empty_iter(f, +, itr,
6565
Base.IteratorEltype(itr)) / 0
6666
end
6767
count = 1
6868
value, state = y
69-
f_value = f(value)
70-
total = Base.reduce_first(Base.add_sum, f_value)
69+
f_value = f(value)/1
70+
total = Base.reduce_first(+, f_value)
7171
y = iterate(itr, state)
7272
while y !== nothing
7373
value, state = y
74-
total += f(value)
74+
total += _mean_promote(total, f(value))
7575
count += 1
7676
y = iterate(itr, state)
7777
end
@@ -103,9 +103,6 @@ julia> mean(√, [1 2 3; 4 5 6], dims=2)
103103
"""
104104
mean(f, A::AbstractArray; dims=:) = _mean(f, A, dims)
105105

106-
_mean(f, A::AbstractArray, ::Colon) = sum(f, A) / length(A)
107-
_mean(f, A::AbstractArray, dims) = sum(f, A, dims=dims) / mapreduce(i -> size(A, i), *, unique(dims); init=1)
108-
109106
"""
110107
mean!(r, v)
111108
@@ -164,10 +161,25 @@ julia> mean(A, dims=2)
164161
3.5
165162
```
166163
"""
167-
mean(A::AbstractArray; dims=:) = _mean(A, dims)
164+
mean(A::AbstractArray; dims=:) = _mean(identity, A, dims)
165+
166+
_mean_promote(x::T, y::S) where {T,S} = convert(promote_type(T, S), y)
168167

169-
_mean(A::AbstractArray{T}, region) where {T} = mean!(Base.reducedim_init(t -> t/2, +, A, region), A)
170-
_mean(A::AbstractArray, ::Colon) = sum(A) / length(A)
168+
function _mean(f, A::AbstractArray, dims=:)
169+
isempty(A) && return sum(f, A, dims=dims)/0
170+
if dims === (:)
171+
n = length(A)
172+
else
173+
n = mapreduce(i -> size(A, i), *, unique(dims); init=1)
174+
end
175+
x1 = f(first(A)) / 1
176+
result = sum(x -> _mean_promote(x1, f(x)), A, dims=dims)
177+
if dims === (:)
178+
return result / n
179+
else
180+
return result ./= n
181+
end
182+
end
171183

172184
function mean(r::AbstractRange{<:Real})
173185
isempty(r) && return oftype((first(r) + last(r)) / 2, NaN)

test/runtests.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,23 @@ end
130130
@test mean(identity, x) == mean(identity, g) == typemax(T)
131131
@test mean(x, dims=2) == [typemax(T)]'
132132
end
133+
# Check that mean avoids integer overflow (#22)
134+
let x = fill(typemax(Int), 10), a = tuple(x...)
135+
@test (mean(x) == mean(x, dims=1)[] == mean(float, x)
136+
== mean(a) == mean(v for v in x) == mean(v for v in a)
137+
float(typemax(Int)))
138+
end
139+
let x = rand(10000) # mean should use sum's accurate pairwise algorithm
140+
@test mean(x) == sum(x) / length(x)
141+
end
142+
@test mean(Number[1, 1.5, 2+3im]) === 1.5+1im # mixed-type array
143+
@test mean(v for v in Number[1, 1.5, 2+3im]) === 1.5+1im
144+
@test (@inferred mean(Int[])) === 0/0
145+
@test (@inferred mean(Float32[])) === 0.f0/0
146+
@test (@inferred mean(Float64[])) === 0/0
147+
@test (@inferred mean(Iterators.filter(x -> true, Int[]))) === 0/0
148+
@test (@inferred mean(Iterators.filter(x -> true, Float32[]))) === 0.f0/0
149+
@test (@inferred mean(Iterators.filter(x -> true, Float64[]))) === 0/0
133150
end
134151

135152
@testset "mean/median for ranges" begin
@@ -710,7 +727,7 @@ end
710727
x = Any[1, 2, 4, 10]
711728
y = Any[1, 2, 4, 10//1]
712729
@test var(x) === 16.25
713-
@test var(y) === 65//4
730+
@test var(y) === 16.25
714731
@test std(x) === sqrt(16.25)
715732
@test quantile(x, 0.5) === 3.0
716733
@test quantile(x, 1//2) === 3//1

0 commit comments

Comments
 (0)