Skip to content

Commit b16588b

Browse files
authored
Choose different quantile cutpoints in cut(x, n) (#416)
`Statistics.quantile` returns values which are not the most appropriate to generate labels. It is more intuitive to choose values from the actual data, which are likely to have fewer decimals and make more sense for users. Since intervals are closed on the left, we just have to use the first data value greater than or equal to the quantile. This doesn't change group assignments (only labels).
1 parent 07b955f commit b16588b

File tree

2 files changed

+74
-31
lines changed

2 files changed

+74
-31
lines changed

src/extras.jl

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ default_formatter(from, to, i; leftclosed, rightclosed) =
4242
4343
Cut a numeric array into intervals at values `breaks`
4444
and return an ordered `CategoricalArray` indicating
45-
the interval into which each entry falls. Intervals are of the form `[lower, upper)`,
46-
i.e. the lower bound is included and the upper bound is excluded, except
45+
the interval into which each entry falls. Intervals are of the form `[lower, upper)`
46+
(closed on the left), i.e. the lower bound is included and the upper bound is excluded, except
4747
the last interval, which is closed on both ends, i.e. `[lower, upper]`.
4848
4949
If `x` accepts missing values (i.e. `eltype(x) >: Missing`) the returned array will
@@ -81,15 +81,15 @@ julia> cut(-1:0.5:1, 2)
8181
"Q1: [-1.0, 0.0)"
8282
"Q2: [0.0, 1.0]"
8383
"Q2: [0.0, 1.0]"
84-
"Q2: [0.0, 1.0]"
84+
"Q2: [0.0, 1.0]"
8585
8686
julia> cut(-1:0.5:1, 2, labels=["A", "B"])
8787
5-element CategoricalArray{String,1,UInt32}:
8888
"A"
8989
"A"
9090
"B"
9191
"B"
92-
"B"
92+
"B"
9393
9494
julia> cut(-1:0.5:1, 2, labels=[-0.5, +0.5])
9595
5-element CategoricalArray{Float64,1,UInt32}:
@@ -104,11 +104,11 @@ fmt (generic function with 1 method)
104104
105105
julia> cut(-1:0.5:1, 3, labels=fmt)
106106
5-element CategoricalArray{String,1,UInt32}:
107-
"grp 1 (-1.0//-0.3333333333333335)"
108-
"grp 1 (-1.0//-0.3333333333333335)"
109-
"grp 2 (-0.3333333333333335//0.33333333333333326)"
110-
"grp 3 (0.33333333333333326//1.0)"
111-
"grp 3 (0.33333333333333326//1.0)"
107+
"grp 1 (-1.0//0.0)"
108+
"grp 1 (-1.0//0.0)"
109+
"grp 2 (0.0//0.5)"
110+
"grp 3 (0.5//1.0)"
111+
"grp 3 (0.5//1.0)"
112112
```
113113
"""
114114
@inline function cut(x::AbstractArray, breaks::AbstractVector;
@@ -221,12 +221,38 @@ Provide the default label format for the `cut(x, ngroups)` method.
221221
quantile_formatter(from, to, i; leftclosed, rightclosed) =
222222
string("Q", i, ": ", leftclosed ? "[" : "(", from, ", ", to, rightclosed ? "]" : ")")
223223

224+
"""
225+
Find first value in (sorted) `v` which is greater than or equal to each quantile
226+
in (sorted) `qs`.
227+
"""
228+
function find_breaks(v::AbstractVector, qs::AbstractVector)
229+
n = length(qs)
230+
breaks = similar(v, n)
231+
n == 0 && return breaks
232+
233+
i = 1
234+
q = qs[1]
235+
@inbounds for x in v
236+
# Use isless and isequal to differentiate -0.0 from 0.0
237+
if isless(q, x) || isequal(q, x)
238+
breaks[i] = x
239+
i += 1
240+
i > n && break
241+
q = qs[i]
242+
end
243+
end
244+
return breaks
245+
end
246+
224247
"""
225248
cut(x::AbstractArray, ngroups::Integer;
226249
labels::Union{AbstractVector{<:AbstractString},Function},
227250
allowempty::Bool=false)
228251
229-
Cut a numeric array into `ngroups` quantiles, determined using `quantile`.
252+
Cut a numeric array into `ngroups` quantiles.
253+
254+
This is equivalent to `cut(x, quantile(x, (0:ngroups)/ngroups))`,
255+
but breaks are taken from actual data values instead of estimated quantiles.
230256
231257
If `x` contains `missing` values, they are automatically skipped when computing
232258
quantiles.
@@ -246,15 +272,14 @@ function cut(x::AbstractArray, ngroups::Integer;
246272
labels::Union{AbstractVector{<:SupportedTypes},Function}=quantile_formatter,
247273
allowempty::Bool=false)
248274
ngroups >= 1 || throw(ArgumentError("ngroups must be strictly positive (got $ngroups)"))
249-
xnm = eltype(x) >: Missing ? skipmissing(x) : x
250-
# Computing extrema is faster than taking 0 and 1 quantiles
251-
min_x, max_x = extrema(xnm)
275+
sorted_x = eltype(x) >: Missing ? sort!(collect(skipmissing(x))) : sort(x)
276+
min_x, max_x = first(sorted_x), last(sorted_x)
252277
if (min_x isa Number && isnan(min_x)) ||
253278
(max_x isa Number && isnan(max_x))
254279
throw(ArgumentError("NaN values are not allowed in input vector"))
255280
end
256-
breaks = quantile(xnm, (1:ngroups-1)/ngroups)
257-
breaks = [min_x; breaks; max_x]
281+
qs = quantile!(sorted_x, (1:(ngroups-1))/ngroups, sorted=true)
282+
breaks = [min_x; find_breaks(sorted_x, qs); max_x]
258283
if !allowempty && !allunique(@view breaks[1:end-1])
259284
throw(ArgumentError("cannot compute $ngroups quantiles due to " *
260285
"too many duplicated values in `x`. " *

test/15_extras.jl

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -127,18 +127,18 @@ end
127127

128128
@testset "cut([5, 4, 3, 2], 2)" begin
129129
x = @inferred cut([5, 4, 3, 2], 2)
130-
@test x == ["Q2: [3.5, 5.0]", "Q2: [3.5, 5.0]", "Q1: [2.0, 3.5)", "Q1: [2.0, 3.5)"]
130+
@test x == ["Q2: [4, 5]", "Q2: [4, 5]", "Q1: [2, 4)", "Q1: [2, 4)"]
131131
@test isa(x, CategoricalArray)
132132
@test isordered(x)
133-
@test levels(x) == ["Q1: [2.0, 3.5)", "Q2: [3.5, 5.0]"]
133+
@test levels(x) == ["Q1: [2, 4)", "Q2: [4, 5]"]
134134
end
135135

136136
@testset "cut(x, n) with missing values" begin
137137
x = @inferred cut([5, 4, 3, missing, 2], 2)
138-
@test x ≅ ["Q2: [3.5, 5.0]", "Q2: [3.5, 5.0]", "Q1: [2.0, 3.5)", missing, "Q1: [2.0, 3.5)"]
138+
@test x ≅ ["Q2: [4, 5]", "Q2: [4, 5]", "Q1: [2, 4)", missing, "Q1: [2, 4)"]
139139
@test isa(x, CategoricalArray)
140140
@test isordered(x)
141-
@test levels(x) == ["Q1: [2.0, 3.5)", "Q2: [3.5, 5.0]"]
141+
@test levels(x) == ["Q1: [2, 4)", "Q2: [4, 5]"]
142142
end
143143

144144
@testset "cut(x, n) with invalid n" begin
@@ -255,20 +255,29 @@ end
255255
@test_throws ArgumentError cut(1:8, 0:2:10, labels=fmt)
256256

257257
@test_throws ArgumentError cut([fill(1, 10); 4], 2)
258-
@test_throws ArgumentError cut([fill(1, 10); 4], 3)
259258
x = cut([fill(1, 10); 4], 2, allowempty=true)
260-
@test unique(x) == ["Q2: [1.0, 4.0]"]
259+
@test unique(x) == ["Q2: [1, 4]"]
260+
@test levels(x) == ["Q1: (1, 1)", "Q2: [1, 4]"]
261+
@test_throws ArgumentError cut([fill(1, 10); 4], 3)
261262
x = cut([fill(1, 10); 4], 3, allowempty=true)
262-
@test unique(x) == ["Q3: [1.0, 4.0]"]
263-
@test levels(x) == ["Q1: (1.0, 1.0)", "Q2: (1.0, 1.0)", "Q3: [1.0, 4.0]"]
263+
@test unique(x) == ["Q3: [1, 4]"]
264+
@test levels(x) == ["Q1: (1, 1)", "Q2: (1, 1)", "Q3: [1, 4]"]
265+
266+
x = cut([fill(4, 10); 1], 2)
267+
@test x == [fill("Q2: [4, 4]", 10); "Q1: [1, 4)"]
268+
@test levels(x) == ["Q1: [1, 4)"; "Q2: [4, 4]"]
269+
@test_throws ArgumentError cut([fill(4, 10); 1], 3)
270+
x = cut([fill(4, 10); 1], 3, allowempty=true)
271+
@test x == [fill("Q3: [4, 4]", 10); "Q1: [1, 4)"]
272+
@test levels(x) == ["Q1: [1, 4)", "Q2: (4, 4)", "Q3: [4, 4]"]
264273

265274
x = cut([fill(1, 5); fill(4, 5)], 2)
266-
@test x == [fill("Q1: [1.0, 2.5)", 5); fill("Q2: [2.5, 4.0]", 5)]
267-
@test levels(x) == ["Q1: [1.0, 2.5)", "Q2: [2.5, 4.0]"]
275+
@test x == [fill("Q1: [1, 4)", 5); fill("Q2: [4, 4]", 5)]
276+
@test levels(x) == ["Q1: [1, 4)", "Q2: [4, 4]"]
268277
@test_throws ArgumentError cut([fill(1, 5); fill(4, 5)], 3)
269278
x = cut([fill(1, 5); fill(4, 5)], 3, allowempty=true)
270-
@test x == [fill("Q2: [1.0, 4.0)", 5); fill("Q3: [4.0, 4.0]", 5)]
271-
@test levels(x) == ["Q1: (1.0, 1.0)", "Q2: [1.0, 4.0)", "Q3: [4.0, 4.0]"]
279+
@test x == [fill("Q2: [1, 4)", 5); fill("Q3: [4, 4]", 5)]
280+
@test levels(x) == ["Q1: (1, 1)", "Q2: [1, 4)", "Q3: [4, 4]"]
272281
end
273282

274283
@testset "cut with -0.0" begin
@@ -353,12 +362,21 @@ end
353362
@test levels(x) == ["[-Inf, 2.0)", "[2.0, 5.0]"]
354363

355364
x = cut([1:5; Inf], 2)
356-
@test x ≅ [fill("Q1: [1.0, 3.5)", 3); fill("Q2: [3.5, Inf]", 3)]
357-
@test levels(x) == ["Q1: [1.0, 3.5)", "Q2: [3.5, Inf]"]
365+
@test x ≅ [fill("Q1: [1.0, 4.0)", 3); fill("Q2: [4.0, Inf]", 3)]
366+
@test levels(x) == ["Q1: [1.0, 4.0)", "Q2: [4.0, Inf]"]
358367

359368
x = cut([1:5; -Inf], 2)
360-
@test x ≅ [fill("Q1: [-Inf, 2.5)", 2); fill("Q2: [2.5, 5.0]", 3); "Q1: [-Inf, 2.5)"]
361-
@test levels(x) == ["Q1: [-Inf, 2.5)", "Q2: [2.5, 5.0]"]
369+
@test x ≅ [fill("Q1: [-Inf, 3.0)", 2); fill("Q2: [3.0, 5.0]", 3); "Q1: [-Inf, 3.0)"]
370+
@test levels(x) == ["Q1: [-Inf, 3.0)", "Q2: [3.0, 5.0]"]
371+
end
372+
373+
@testset "cut when quantile falls exactly on a data value" begin
374+
x = cut([11, 14, 43, 54, 54, 56, 73, 79, 84, 84], 3)
375+
@test x ==
376+
["Q1: [11, 54)", "Q1: [11, 54)", "Q1: [11, 54)",
377+
"Q2: [54, 73)", "Q2: [54, 73)", "Q2: [54, 73)",
378+
"Q3: [73, 84]", "Q3: [73, 84]", "Q3: [73, 84]", "Q3: [73, 84]"]
379+
@test levels(x) == ["Q1: [11, 54)", "Q2: [54, 73)", "Q3: [73, 84]"]
362380
end
363381

364382
end

0 commit comments

Comments
 (0)