@@ -61,17 +61,17 @@ julia> mean([√1, √2, √3])
6161function mean (f, itr)
6262 y = iterate (itr)
6363 if y === nothing
64- return Base. mapreduce_empty_iter (f, Base . add_sum , itr,
64+ return Base. mapreduce_empty_iter (f, + , itr,
6565 Base. IteratorEltype (itr)) / 0
6666 end
6767 count = 1
6868 value, state = y
69- f_value = f (value)
70- total = Base. reduce_first (Base . add_sum , f_value)
69+ f_value = f (value)/ 1
70+ total = Base. reduce_first (+ , f_value)
7171 y = iterate (itr, state)
7272 while y != = nothing
7373 value, state = y
74- total += f (value)
74+ total += _mean_promote (total, f (value) )
7575 count += 1
7676 y = iterate (itr, state)
7777 end
@@ -103,9 +103,6 @@ julia> mean(√, [1 2 3; 4 5 6], dims=2)
103103"""
104104mean (f, A:: AbstractArray ; dims= :) = _mean (f, A, dims)
105105
106- _mean (f, A:: AbstractArray , :: Colon ) = sum (f, A) / length (A)
107- _mean (f, A:: AbstractArray , dims) = sum (f, A, dims= dims) / mapreduce (i -> size (A, i), * , unique (dims); init= 1 )
108-
109106"""
110107 mean!(r, v)
111108
@@ -164,10 +161,25 @@ julia> mean(A, dims=2)
164161 3.5
165162```
166163"""
167- mean (A:: AbstractArray ; dims= :) = _mean (A, dims)
164+ mean (A:: AbstractArray ; dims= :) = _mean (identity, A, dims)
165+
166+ _mean_promote (x:: T , y:: S ) where {T,S} = convert (promote_type (T, S), y)
168167
169- _mean (A:: AbstractArray{T} , region) where {T} = mean! (Base. reducedim_init (t -> t/ 2 , + , A, region), A)
170- _mean (A:: AbstractArray , :: Colon ) = sum (A) / length (A)
168+ function _mean (f, A:: AbstractArray , dims= :)
169+ isempty (A) && return sum (f, A, dims= dims)/ 0
170+ if dims === (:)
171+ n = length (A)
172+ else
173+ n = mapreduce (i -> size (A, i), * , unique (dims); init= 1 )
174+ end
175+ x1 = f (first (A)) / 1
176+ result = sum (x -> _mean_promote (x1, f (x)), A, dims= dims)
177+ if dims === (:)
178+ return result / n
179+ else
180+ return result ./= n
181+ end
182+ end
171183
172184function mean (r:: AbstractRange{<:Real} )
173185 isempty (r) && return oftype ((first (r) + last (r)) / 2 , NaN )
0 commit comments