Skip to content

Commit 1823086

Browse files
authored
Allow weighted means with arbitrary eltype (#476)
PR 442 introduced a restriction of the input element type to be <:Number, which means that the method doesn't get chosen by dispatch for e.g. inputs with missing values. This leads to bizarre behavior, as it dispatches to a nonsensical method where the array is treated as a function.
1 parent 24a80ca commit 1823086

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

src/weights.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -466,11 +466,11 @@ w = rand(n)
466466
mean(x, weights(w))
467467
```
468468
"""
469-
mean(A::AbstractArray{T}, w::AbstractWeights{W};
470-
dims::Union{Nothing,Int}=nothing) where {T<:Number,W<:Real} = _mean(A, w, dims)
471-
_mean(A::AbstractArray{T}, w::AbstractWeights{W}, dims::Nothing) where {T<:Number,W<:Real} =
469+
mean(A::AbstractArray, w::AbstractWeights; dims::Union{Nothing,Int}=nothing) =
470+
_mean(A, w, dims)
471+
_mean(A::AbstractArray, w::AbstractWeights, dims::Nothing) =
472472
sum(A, w) / sum(w)
473-
_mean(A::AbstractArray{T}, w::AbstractWeights{W}, dims::Int) where {T<:Number,W<:Real} =
473+
_mean(A::AbstractArray{T}, w::AbstractWeights{W}, dims::Int) where {T,W} =
474474
_mean!(similar(A, wmeantype(T, W), Base.reduced_indices(axes(A), dims)), A, w, dims)
475475

476476

test/weights.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,4 +443,8 @@ end
443443
@test median(data, f(wt)) quantile(data, f(wt), 0.5) atol = 1e-5
444444
end
445445

446+
@testset "Mismatched eltypes" begin
447+
@test round(mean(Union{Int,Missing}[1,2], weights([1,2])), digits=3) 1.667
448+
end
449+
446450
end # @testset StatsBase.Weights

0 commit comments

Comments
 (0)