Skip to content

Commit 854a541

Browse files
authored
Adds a weighted mode call (#611)
* Adds a weighted mode call. Use `sort` to make `modes` result consistent on different test architectures. Mention weighted option in mode docstrings. * Apply suggestions from code review. Co-authored-by: Milan Bouchet-Valat <[email protected]>
1 parent 8075e28 commit 854a541

File tree

2 files changed

+60
-3
lines changed

2 files changed

+60
-3
lines changed

src/scalarstats.jl

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,11 @@ end
4747
# compute mode, given the range of integer values
4848
"""
4949
mode(a, [r])
50+
mode(a::AbstractArray, wv::AbstractWeights)
5051
5152
Return the mode (most common number) of an array, optionally
52-
over a specified range `r`. If several modes exist, the first
53-
one (in order of appearance) is returned.
53+
over a specified range `r` or weighted via a vector `wv`.
54+
If several modes exist, the first one (in order of appearance) is returned.
5455
"""
5556
function mode(a::AbstractArray{T}, r::UnitRange{T}) where T<:Integer
5657
isempty(a) && throw(ArgumentError("mode is not defined for empty collections"))
@@ -75,9 +76,10 @@ end
7576

7677
"""
7778
modes(a, [r])::Vector
79+
mode(a::AbstractArray, wv::AbstractWeights)::Vector
7880
7981
Return all modes (most common numbers) of an array, optionally over a
80-
specified range `r`.
82+
specified range `r` or weighted via vector `wv`.
8183
"""
8284
function modes(a::AbstractArray{T}, r::UnitRange{T}) where T<:Integer
8385
r0 = r[1]
@@ -158,6 +160,47 @@ function modes(a)
158160
return [x for (x, c) in cnts if c == mc]
159161
end
160162

163+
# Weighted mode of arbitrary vectors of values
164+
function mode(a::AbstractVector, wv::AbstractWeights{T}) where T <: Real
165+
isempty(a) && throw(ArgumentError("mode is not defined for empty collections"))
166+
length(a) == length(wv) ||
167+
throw(ArgumentError("data and weight vectors must be the same size, got $(length(a)) and $(length(wv))"))
168+
169+
# Iterate through the data
170+
mv = first(a)
171+
mw = first(wv)
172+
weights = Dict{eltype(a), T}()
173+
for (x, w) in zip(a, wv)
174+
_w = get!(weights, x, zero(T)) + w
175+
if _w > mw
176+
mv = x
177+
mw = _w
178+
end
179+
weights[x] = _w
180+
end
181+
182+
return mv
183+
end
184+
185+
function modes(a::AbstractVector, wv::AbstractWeights{T}) where T <: Real
186+
isempty(a) && throw(ArgumentError("mode is not defined for empty collections"))
187+
length(a) == length(wv) ||
188+
throw(ArgumentError("data and weight vectors must be the same size, got $(length(a)) and $(length(wv))"))
189+
190+
# Iterate through the data
191+
mw = first(wv)
192+
weights = Dict{eltype(a), T}()
193+
for (x, w) in zip(a, wv)
194+
_w = get!(weights, x, zero(T)) + w
195+
if _w > mw
196+
mw = _w
197+
end
198+
weights[x] = _w
199+
end
200+
201+
# find values corresponding to maximum counts
202+
return [x for (x, w) in weights if w == mw]
203+
end
161204

162205
#############################
163206
#

test/scalarstats.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,24 @@ using Statistics
4444
@test modes(skipmissing([1, missing, missing, 3, 2, 2, missing])) == [2]
4545
@test sort(modes(skipmissing([1, missing, 3, 3, 2, 2, missing]))) == [2, 3]
4646

47+
d1 = [1, 2, 3, 3, 4, 5, 5, 3]
48+
d2 = ['a', 'b', 'c', 'c', 'd', 'e', 'e', 'c']
49+
wv = weights([0.1:0.1:0.7; 0.1])
50+
@test mode(d1) == 3
51+
@test mode(d2) == 'c'
52+
@test mode(d1, wv) == 5
53+
@test mode(d2, wv) == 'e'
54+
@test sort(modes(d1[1:end-1], weights(ones(7)))) == [3, 5]
55+
@test sort(modes(d1, weights([.9, .1, .1, .1, .9, .1, .1, .1]))) == [1, 4]
56+
4757
@test_throws ArgumentError mode(Int[])
4858
@test_throws ArgumentError modes(Int[])
4959
@test_throws ArgumentError mode(Any[])
5060
@test_throws ArgumentError modes(Any[])
61+
@test_throws ArgumentError mode([], weights(Float64[]))
62+
@test_throws ArgumentError modes([], weights(Float64[]))
63+
@test_throws ArgumentError mode([1, 2, 3], weights([0.1, 0.3]))
64+
@test_throws ArgumentError modes([1, 2, 3], weights([0.1, 0.3]))
5165

5266
## zscores
5367

0 commit comments

Comments
 (0)