Skip to content

Commit f0cccd6

Browse files
authored
make counting more robust to input datatype (#722)
1 parent 1815e1e commit f0cccd6

File tree

5 files changed

+304
-228
lines changed

5 files changed

+304
-228
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,9 @@ julia = "1"
2828
[extras]
2929
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
3030
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
31+
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
3132
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
3233
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3334

3435
[targets]
35-
test = ["Dates", "DelimitedFiles", "StableRNGs", "Test"]
36+
test = ["Dates", "DelimitedFiles", "OffsetArrays", "StableRNGs", "Test"]

src/counts.jl

Lines changed: 106 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -16,24 +16,24 @@ end
1616

1717
#### functions for counting a single list of integers (1D)
1818
"""
19-
addcounts!(r, x, levels::UnitRange{<:Int}, [wv::AbstractWeights])
19+
addcounts!(r, x, levels::UnitRange{<:Integer}, [wv::AbstractWeights])
2020
2121
Add the number of occurrences in `x` of each value in `levels` to an existing
22-
array `r`. If a weighting vector `wv` is specified, the sum of weights is used
23-
rather than the raw counts.
22+
array `r`. For each `xi ∈ x`, if `xi == levels[j]`, then we increment `r[j]`.
23+
24+
If a weighting vector `wv` is specified, the sum of weights is used rather than the
25+
raw counts.
2426
"""
2527
function addcounts!(r::AbstractArray, x::IntegerArray, levels::IntUnitRange)
26-
# add counts of integers from x to r
28+
# add counts of integers from x that fall within levels to r
2729

28-
k = length(levels)
29-
length(r) == k || throw(DimensionMismatch())
30+
checkbounds(r, axes(levels)...)
3031

31-
m0 = levels[1]
32-
m1 = levels[end]
33-
b = m0 - 1
32+
m0 = first(levels)
33+
m1 = last(levels)
34+
b = m0 - firstindex(levels) # firstindex(levels) == 1 because levels::IntUnitRange
3435

35-
@inbounds for i in 1 : length(x)
36-
xi = x[i]
36+
@inbounds for xi in x
3737
if m0 <= xi <= m1
3838
r[xi - b] += 1
3939
end
@@ -42,15 +42,21 @@ function addcounts!(r::AbstractArray, x::IntegerArray, levels::IntUnitRange)
4242
end
4343

4444
function addcounts!(r::AbstractArray, x::IntegerArray, levels::IntUnitRange, wv::AbstractWeights)
45-
k = length(levels)
46-
length(r) == k || throw(DimensionMismatch())
45+
# add wv weighted counts of integers from x that fall within levels to r
46+
47+
length(x) == length(wv) ||
48+
throw(DimensionMismatch("x and wv must have the same length, got $(length(x)) and $(length(wv))"))
49+
50+
xv = vec(x) # discard shape because weights() discards shape
51+
52+
checkbounds(r, axes(levels)...)
4753

48-
m0 = levels[1]
49-
m1 = levels[end]
54+
m0 = first(levels)
55+
m1 = last(levels)
5056
b = m0 - 1
5157

52-
@inbounds for i in 1 : length(x)
53-
xi = x[i]
58+
@inbounds for i in eachindex(xv, wv)
59+
xi = xv[i]
5460
if m0 <= xi <= m1
5561
r[xi - b] += wv[i]
5662
end
@@ -69,8 +75,8 @@ falling in that range will be considered (the others will be ignored without
6975
raising an error or a warning). If an integer `k` is provided, only values in the
7076
range `1:k` will be considered.
7177
72-
If a weighting vector `wv` is specified, the sum of the weights is used rather than the
73-
raw counts.
78+
If a vector of weights `wv` is provided, the sum of the weights is used rather than the
79+
raw counts.
7480
7581
The output is a vector of length `length(levels)`.
7682
"""
@@ -90,8 +96,10 @@ counts(x::IntegerArray, wv::AbstractWeights) = counts(x, span(x), wv)
9096
proportions(x, levels=span(x), [wv::AbstractWeights])
9197
9298
Return the proportion of values in the range `levels` that occur in `x`.
93-
Equivalent to `counts(x, levels) / length(x)`. If a weighting vector `wv`
94-
is specified, the sum of the weights is used rather than the raw counts.
99+
Equivalent to `counts(x, levels) / length(x)`.
100+
101+
If a vector of weights `wv` is provided, the proportion of weights is computed rather
102+
than the proportion of raw counts.
95103
"""
96104
proportions(x::IntegerArray, levels::IntUnitRange) = counts(x, levels) .* inv(length(x))
97105
proportions(x::IntegerArray, levels::IntUnitRange, wv::AbstractWeights) =
@@ -101,6 +109,9 @@ proportions(x::IntegerArray, levels::IntUnitRange, wv::AbstractWeights) =
101109
proportions(x, k::Integer, [wv::AbstractWeights])
102110
103111
Return the proportion of integers in 1 to `k` that occur in `x`.
112+
113+
If a vector of weights `wv` is provided, the proportion of weights is computed rather
114+
than the proportion of raw counts.
104115
"""
105116
proportions(x::IntegerArray, k::Integer) = proportions(x, 1:k)
106117
proportions(x::IntegerArray, k::Integer, wv::AbstractWeights) = proportions(x, 1:k, wv)
@@ -110,26 +121,22 @@ proportions(x::IntegerArray, wv::AbstractWeights) = proportions(x, span(x), wv)
110121
#### functions for counting a single list of integers (2D)
111122

112123
function addcounts!(r::AbstractArray, x::IntegerArray, y::IntegerArray, levels::NTuple{2,IntUnitRange})
113-
# add counts of integers from x to r
114-
115-
n = length(x)
116-
length(y) == n || throw(DimensionMismatch())
124+
# add counts of pairs from zip(x,y) to r
117125

118126
xlevels, ylevels = levels
119127

120-
kx = length(xlevels)
121-
ky = length(ylevels)
122-
size(r) == (kx, ky) || throw(DimensionMismatch())
123128

124-
mx0 = xlevels[1]
125-
mx1 = xlevels[end]
126-
my0 = ylevels[1]
127-
my1 = ylevels[end]
129+
checkbounds(r, axes(xlevels, 1), axes(ylevels, 1))
130+
131+
mx0 = first(xlevels)
132+
mx1 = last(xlevels)
133+
my0 = first(ylevels)
134+
my1 = last(ylevels)
128135

129136
bx = mx0 - 1
130137
by = my0 - 1
131138

132-
for i = 1:n
139+
for i in eachindex(vec(x), vec(y))
133140
xi = x[i]
134141
yi = y[i]
135142
if (mx0 <= xi <= mx1) && (my0 <= yi <= my1)
@@ -141,28 +148,31 @@ end
141148

142149
function addcounts!(r::AbstractArray, x::IntegerArray, y::IntegerArray,
143150
levels::NTuple{2,IntUnitRange}, wv::AbstractWeights)
144-
# add counts of integers from x to r
151+
# add wv weighted counts of pairs from zip(x,y) to r
152+
153+
length(x) == length(y) == length(wv) ||
154+
throw(DimensionMismatch("x, y, and wv must have the same length, but got $(length(x)), $(length(y)), and $(length(wv))"))
145155

146-
n = length(x)
147-
length(y) == length(wv) == n || throw(DimensionMismatch())
156+
axes(x) == axes(y) ||
157+
throw(DimensionMismatch("x and y must have the same axes, but got $(axes(x)) and $(axes(y))"))
158+
159+
xv, yv = vec(x), vec(y) # discard shape because weights() discards shape
148160

149161
xlevels, ylevels = levels
150162

151-
kx = length(xlevels)
152-
ky = length(ylevels)
153-
size(r) == (kx, ky) || throw(DimensionMismatch())
163+
checkbounds(r, axes(xlevels, 1), axes(ylevels, 1))
154164

155-
mx0 = xlevels[1]
156-
mx1 = xlevels[end]
157-
my0 = ylevels[1]
158-
my1 = ylevels[end]
165+
mx0 = first(xlevels)
166+
mx1 = last(xlevels)
167+
my0 = first(ylevels)
168+
my1 = last(ylevels)
159169

160170
bx = mx0 - 1
161171
by = my0 - 1
162172

163-
for i = 1:n
164-
xi = x[i]
165-
yi = y[i]
173+
for i in eachindex(xv, yv, wv)
174+
xi = xv[i]
175+
yi = yv[i]
166176
if (mx0 <= xi <= mx1) && (my0 <= yi <= my1)
167177
r[xi - bx, yi - by] += wv[i]
168178
end
@@ -235,13 +245,15 @@ end
235245

236246

237247
"""
238-
addcounts!(dict, x[, wv]; alg = :auto)
248+
addcounts!(dict, x; alg = :auto)
249+
addcounts!(dict, x, wv)
239250
240251
Add counts based on `x` to a count map. New entries will be added if new values come up.
252+
241253
If a weighting vector `wv` is specified, the sum of the weights is used rather than the
242254
raw counts.
243255
244-
`alg` can be one of:
256+
`alg` is only allowed for unweighted counting and can be one of:
245257
- `:auto` (default): if `StatsBase.radixsort_safe(eltype(x)) == true` then use
246258
`:radixsort`, otherwise use `:dict`.
247259
@@ -284,9 +296,9 @@ function addcounts_dict!(cm::Dict{T}, x) where T
284296
end
285297

286298
# If the bits type is of small size i.e. it can have up to 65536 distinct values
287-
# then it is always better to apply a counting-sort like reduce algorithm for
299+
# then it is always better to apply a counting-sort like reduce algorithm for
288300
# faster results and less memory usage. However we still wish to enable others
289-
# to write generic algorithms, therefore the methods below still accept the
301+
# to write generic algorithms, therefore the methods below still accept the
290302
# `alg` argument but it is ignored.
291303
function _addcounts!(::Type{Bool}, cm::Dict{Bool}, x::AbstractArray{Bool}; alg = :ignored)
292304
sumx = sum(x)
@@ -335,32 +347,42 @@ const BaseRadixSortSafeTypes = Union{Int8, Int16, Int32, Int64, Int128,
335347
"Can the type be safely sorted by radixsort"
336348
radixsort_safe(::Type{T}) where T = T<:BaseRadixSortSafeTypes
337349

338-
function _addcounts_radix_sort_loop!(cm::Dict{T}, sx::AbstractArray{T}) where T
350+
function _addcounts_radix_sort_loop!(cm::Dict{T}, sx::AbstractVector{T}) where T
339351
isempty(sx) && return cm
340-
last_sx = sx[1]
341-
tmpcount = get(cm, last_sx, 0) + 1
352+
last_sx = first(sx)
353+
start_i = firstindex(sx)
342354

343355
# now the data is sorted: can just run through and accumulate values before
344356
# adding into the Dict
345-
@inbounds for i in 2:length(sx)
357+
@inbounds for i in start_i+1:lastindex(sx)
346358
sxi = sx[i]
347-
if last_sx == sxi
348-
tmpcount += 1
349-
else
350-
cm[last_sx] = tmpcount
359+
if last_sx != sxi
360+
cm[last_sx] = get(cm, last_sx, 0) + i - start_i
351361
last_sx = sxi
352-
tmpcount = get(cm, last_sx, 0) + 1
362+
start_i = i
353363
end
354364
end
355365

356-
cm[sx[end]] = tmpcount
366+
last_sx = last(sx)
367+
cm[last_sx] = get(cm, last_sx, 0) + lastindex(sx) + 1 - start_i
357368

358369
return cm
359370
end
360371

372+
function _alg(x::AbstractArray)
373+
@static if VERSION >= v"1.9.0-DEV"
374+
return Base.DEFAULT_UNSTABLE
375+
else
376+
firstindex(x) == 1 ||
377+
throw(ArgumentError("alg = :radixsort requires either one based indexing or Julia >= 1.9. " *
378+
"Use `alg = :dict` as an alternative."))
379+
return SortingAlgorithms.RadixSort
380+
end
381+
end
382+
361383
function addcounts_radixsort!(cm::Dict{T}, x::AbstractArray{T}) where T
362384
# sort the x using radixsort
363-
sx = sort(x, alg = RadixSort)
385+
sx = sort(vec(x), alg=_alg(x))
364386

365387
# Delegate the loop to a separate function since sort might not
366388
# be inferred in Julia 0.6 after SortingAlgorithms is loaded.
@@ -369,18 +391,24 @@ function addcounts_radixsort!(cm::Dict{T}, x::AbstractArray{T}) where T
369391
end
370392

371393
# fall-back for `x` an iterator
372-
function addcounts_radixsort!(cm::Dict{T}, x) where T
373-
sx = sort!(collect(x), alg = RadixSort)
394+
function addcounts_radixsort!(cm::Dict{T}, x) where T
395+
cx = vec(collect(x))
396+
sx = sort!(cx, alg = _alg(cx))
374397
return _addcounts_radix_sort_loop!(cm, sx)
375398
end
376399

377400
function addcounts!(cm::Dict{T}, x::AbstractArray{T}, wv::AbstractVector{W}) where {T,W<:Real}
378-
n = length(x)
379-
length(wv) == n || throw(DimensionMismatch())
401+
# add wv weighted counts of values from x to cm
402+
403+
length(x) == length(wv) ||
404+
throw(DimensionMismatch("x and wv must have the same length, got $(length(x)) and $(length(wv))"))
405+
406+
xv = vec(x) # discard shape because weights() discards shape
407+
380408
z = zero(W)
381409

382-
for i = 1 : n
383-
@inbounds xi = x[i]
410+
for i in eachindex(xv, wv)
411+
@inbounds xi = xv[i]
384412
@inbounds wi = wv[i]
385413
cm[xi] = get(cm, xi, z) + wi
386414
end
@@ -390,11 +418,14 @@ end
390418

391419
"""
392420
countmap(x; alg = :auto)
393-
countmap(x::AbstractVector, w::AbstractVector{<:Real}; alg = :auto)
421+
countmap(x::AbstractVector, wv::AbstractVector{<:Real})
394422
395-
Return a dictionary mapping each unique value in `x` to its number
396-
of occurrences. A vector of weights `w` can be provided when `x` is a vector.
423+
Return a dictionary mapping each unique value in `x` to its number of occurrences.
397424
425+
If a weighting vector `wv` is specified, the sum of weights is used rather than the
426+
raw counts.
427+
428+
`alg` is only allowed for unweighted counting and can be one of:
398429
- `:auto` (default): if `StatsBase.radixsort_safe(eltype(x)) == true` then use
399430
`:radixsort`, otherwise use `:dict`.
400431
@@ -414,9 +445,12 @@ countmap(x::AbstractArray{T}, wv::AbstractVector{W}) where {T,W<:Real} = addcoun
414445

415446
"""
416447
proportionmap(x)
448+
proportionmap(x::AbstractVector, wv::AbstractVector{<:Real})
449+
450+
Return a dictionary mapping each unique value in `x` to its proportion in `x`.
417451
418-
Return a dictionary mapping each unique value in `x` to its
419-
proportion in `x`.
452+
If a vector of weights `wv` is provided, the proportion of weights is computed rather
453+
than the proportion of raw counts.
420454
"""
421455
proportionmap(x::AbstractArray) = _normalize_countmap(countmap(x), length(x))
422456
proportionmap(x::AbstractArray, wv::AbstractWeights) = _normalize_countmap(countmap(x, wv), sum(wv))

src/weights.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ length(wv::AbstractWeights) = length(wv.values)
2121
sum(wv::AbstractWeights) = wv.sum
2222
isempty(wv::AbstractWeights) = isempty(wv.values)
2323
size(wv::AbstractWeights) = size(wv.values)
24+
Base.axes(wv::AbstractWeights) = Base.axes(wv.values)
2425

2526
Base.dataids(wv::AbstractWeights) = Base.dataids(wv.values)
2627

@@ -301,6 +302,7 @@ sum(wv::UnitWeights{T}) where T = convert(T, length(wv))
301302
isempty(wv::UnitWeights) = iszero(wv.len)
302303
length(wv::UnitWeights) = wv.len
303304
size(wv::UnitWeights) = tuple(length(wv))
305+
Base.axes(wv::UnitWeights) = tuple(Base.OneTo(length(wv)))
304306

305307
Base.convert(::Type{Vector}, wv::UnitWeights{T}) where {T} = ones(T, length(wv))
306308

0 commit comments

Comments
 (0)