Skip to content

Commit 591d045

Browse files
authored
Make countmap support iterators (#605)
1 parent 0ffe2e2 commit 591d045

File tree

2 files changed

+61
-10
lines changed

2 files changed

+61
-10
lines changed

src/counts.jl

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,9 @@ raw counts.
255255
- `:dict`: use `Dict`-based method which is generally slower but uses less
256256
RAM and is safe for any data type.
257257
"""
258-
function addcounts!(cm::Dict{T}, x::AbstractArray{T}; alg = :auto) where T
258+
function addcounts!(cm::Dict, x; alg = :auto)
    # Lift the element type reported by `eltype(x)` into an argument so that the
    # `_addcounts!` methods can be selected on it; `x` may be any iterable.
    return _addcounts!(eltype(x), cm, x; alg = alg)
end
259+
260+
function _addcounts!(::Type{T}, cm::Dict, x; alg = :auto) where T
259261
# if it's safe to be sorted using radixsort then it should be faster
260262
# albeit using more RAM
261263
if radixsort_safe(T) && (alg == :auto || alg == :radixsort)
@@ -269,7 +271,7 @@ function addcounts!(cm::Dict{T}, x::AbstractArray{T}; alg = :auto) where T
269271
end
270272

271273
"""Dict-based addcounts method"""
272-
function addcounts_dict!(cm::Dict{T}, x::AbstractArray{T}) where T
274+
function addcounts_dict!(cm::Dict{T}, x) where T
273275
for v in x
274276
index = ht_keyindex2!(cm, v)
275277
if index > 0
@@ -286,14 +288,27 @@ end
286288
# faster results and less memory usage. However we still wish to enable others
287289
# to write generic algorithms, therefore the methods below still accept the
288290
# `alg` argument but it is ignored.
289-
function addcounts!(cm::Dict{Bool}, x::AbstractArray{Bool}; alg = :ignored)
291+
function _addcounts!(::Type{Bool}, cm::Dict{Bool}, x::AbstractArray{Bool}; alg = :ignored)
    # Only two keys are possible, so no sorting or per-element hashing is needed:
    # count the `true`s and derive the `false`s from the array length.
    # `alg` is accepted for interface compatibility and is not used.
    ntrue = count(x)
    cm[true] = get(cm, true, 0) + ntrue
    cm[false] = get(cm, false, 0) + length(x) - ntrue
    return cm
end
295297

296-
function addcounts!(cm::Dict{T}, x::AbstractArray{T}; alg = :ignored) where T <: Union{UInt8, UInt16, Int8, Int16}
298+
# specialized for `Bool` iterator
299+
function _addcounts!(::Type{Bool}, cm::Dict{Bool}, x; alg = :ignored)
    # `x` is only iterated, never indexed or measured with `length`, so both the
    # number of elements and the number of `true`s are tallied in a single pass.
    # `alg` is accepted for interface compatibility and is not used.
    ntrue = 0
    total = 0
    for flag in x
        total += 1
        ntrue += flag
    end
    cm[true] = get(cm, true, 0) + ntrue
    cm[false] = get(cm, false, 0) + total - ntrue
    return cm
end
310+
311+
function _addcounts!(::Type{T}, cm::Dict{T}, x; alg = :ignored) where T <: Union{UInt8, UInt16, Int8, Int16}
297312
counts = zeros(Int, 2^(8sizeof(T)))
298313

299314
@inbounds for xi in x
@@ -318,8 +333,7 @@ const BaseRadixSortSafeTypes = Union{Int8, Int16, Int32, Int64, Int128,
318333
Float32, Float64}
319334

320335
"Can the type be safely sorted by radixsort"
321-
radixsort_safe(::Type{T}) where {T<:BaseRadixSortSafeTypes} = true
322-
radixsort_safe(::Type) = false
336+
function radixsort_safe(::Type{T}) where {T}
    # A plain subtype test returns `false` for every type outside the union,
    # so no separate fallback method is required.
    return T <: BaseRadixSortSafeTypes
end
323337

324338
function _addcounts_radix_sort_loop!(cm::Dict{T}, sx::AbstractArray{T}) where T
325339
last_sx = sx[1]
@@ -353,6 +367,12 @@ function addcounts_radixsort!(cm::Dict{T}, x::AbstractArray{T}) where T
353367
return _addcounts_radix_sort_loop!(cm, sx)
354368
end
355369

370+
# fall-back for `x` an iterator
371+
function addcounts_radixsort!(cm::Dict{T}, x) where T
    # An arbitrary iterator cannot be sorted in place, so materialize it first.
    sx = collect(x)
    # The counting loop begins by reading `sx[1]`; with nothing to count, return
    # `cm` untouched instead of letting that access throw a `BoundsError`.
    isempty(sx) && return cm
    sort!(sx, alg = RadixSort)
    return _addcounts_radix_sort_loop!(cm, sx)
end
375+
356376
function addcounts!(cm::Dict{T}, x::AbstractArray{T}, wv::AbstractVector{W}) where {T,W<:Real}
357377
n = length(x)
358378
length(wv) == n || throw(DimensionMismatch())
@@ -386,7 +406,7 @@ of occurrences.
386406
- `:dict`: use `Dict`-based method which is generally slower but uses less
387407
RAM and is safe for any data type.
388408
"""
389-
countmap(x::AbstractArray{T}; alg = :auto) where {T} = addcounts!(Dict{T,Int}(), x; alg = alg)
409+
function countmap(x; alg = :auto)
    # The key type comes from `eltype(x)`; an iterator that does not report a
    # concrete element type (e.g. a generator) therefore yields `Dict{Any,Int}`.
    counts = Dict{eltype(x),Int}()
    return addcounts!(counts, x; alg = alg)
end
390410
countmap(x::AbstractArray{T}, wv::AbstractVector{W}) where {T,W<:Real} = addcounts!(Dict{T,W}(), x, wv)
391411

392412

test/counts.jl

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,14 @@ cm = countmap(x)
8080
@test cm["a"] == 3
8181
@test cm["b"] == 2
8282
@test cm["c"] == 1
83+
84+
# iterator, non-radixsort
85+
cm_missing = countmap(skipmissing(x))
86+
cm_any_itr = countmap((i for i in x))
87+
@test cm_missing == cm_any_itr == cm
88+
@test cm_missing isa Dict{String, Int}
89+
@test cm_any_itr isa Dict{Any, Int}
90+
8391
pm = proportionmap(x)
8492
@test pm["a"] ≈ (1/2)
8593
@test pm["b"] ≈ (1/3)
@@ -91,6 +99,15 @@ xx = repeat([6, 1, 3, 1], outer=100_000)
9199
cm = countmap(xx)
92100
@test cm == Dict(1 => 200_000, 3 => 100_000, 6 => 100_000)
93101

102+
# with iterator
103+
cm_missing = countmap(skipmissing(xx))
104+
@test cm_missing isa Dict{Int, Int}
105+
@test cm_missing == cm
106+
107+
cm_any_itr = countmap((i for i in xx))
108+
@test cm_any_itr isa Dict{Any,Int} # no knowledge about type
109+
@test cm_any_itr == cm  # the generator-based counts must match the array-based ones
110+
94111
# testing the radixsort-based addcounts
95112
xx = repeat([6, 1, 3, 1], outer=100_000)
96113
cm = Dict{Int, Int}()
@@ -99,11 +116,20 @@ StatsBase.addcounts_radixsort!(cm,xx)
99116
xx2 = repeat([7, 1, 3, 1], outer=100_000)
100117
StatsBase.addcounts_radixsort!(cm,xx2)
101118
@test cm == Dict(1 => 400_000, 3 => 200_000, 6 => 100_000, 7 => 100_000)
119+
# with iterator
120+
cm_missing = Dict{Int, Int}()
121+
StatsBase.addcounts_radixsort!(cm_missing,skipmissing(xx))
122+
@test cm_missing == Dict(1 => 200_000, 3 => 100_000, 6 => 100_000)
123+
StatsBase.addcounts_radixsort!(cm_missing,skipmissing(xx2))
124+
@test cm_missing == Dict(1 => 400_000, 3 => 200_000, 6 => 100_000, 7 => 100_000)
102125

103126
# testing the Dict-based addcounts
104127
cm = Dict{Int, Int}()
128+
cm_itr = Dict{Int, Int}()
105129
StatsBase.addcounts_dict!(cm,xx)
106-
@test cm == Dict(1 => 200_000, 3 => 100_000, 6 => 100_000)
130+
StatsBase.addcounts_dict!(cm_itr,skipmissing(xx))
131+
@test cm_itr == cm == Dict(1 => 200_000, 3 => 100_000, 6 => 100_000)
132+
@test cm_itr isa Dict{Int, Int}
107133

108134
cm = countmap(x, weights(w))
109135
@test cm["a"] == 5.5
@@ -119,11 +145,16 @@ pm = proportionmap(x, weights(w))
119145

120146
# testing small bits type
121147
bx = [true, false, true, true, false]
122-
@test countmap(bx) == Dict(true => 3, false => 2)
148+
cm_bx_missing = countmap(skipmissing(bx))
149+
@test cm_bx_missing == countmap(bx) == Dict(true => 3, false => 2)
150+
@test cm_bx_missing isa Dict{Bool, Int}
123151

124152
for T in [UInt8, UInt16, Int8, Int16]
125153
tx = T[typemin(T), 8, typemax(T), 19, 8]
126-
@test countmap(tx) == Dict(typemin(T) => 1, typemax(T) => 1, 8 => 2, 19 => 1)
154+
tx_missing = skipmissing(T[typemin(T), 8, typemax(T), 19, 8])
155+
cm_tx_missing = countmap(tx_missing)
156+
@test cm_tx_missing == countmap(tx) == Dict(typemin(T) => 1, typemax(T) => 1, 8 => 2, 19 => 1)
157+
@test cm_tx_missing isa Dict{T, Int}
127158
end
128159

129160
@testset "views" begin

0 commit comments

Comments
 (0)