Skip to content

Commit bbed3a2

Browse files
committed
Improve choice of cutpoints
1 parent e0fe39c commit bbed3a2

File tree

2 files changed

+37
-31
lines changed

2 files changed

+37
-31
lines changed

src/extras.jl

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -77,25 +77,25 @@ julia> cut(-1:0.5:1, [0, 1], extend=true)
7777
7878
julia> cut(-1:0.5:1, 2)
7979
5-element CategoricalArray{String,1,UInt32}:
80-
"Q1: [-1.0, 0.0)"
81-
"Q1: [-1.0, 0.0)"
82-
"Q2: [0.0, 1.0]"
83-
"Q2: [0.0, 1.0]"
84-
"Q2: [0.0, 1.0]"
80+
"Q1: [-1.0, 0.5)"
81+
"Q1: [-1.0, 0.5)"
82+
"Q1: [-1.0, 0.5)"
83+
"Q2: [0.5, 1.0]"
84+
"Q2: [0.5, 1.0]"
8585
8686
julia> cut(-1:0.5:1, 2, labels=["A", "B"])
8787
5-element CategoricalArray{String,1,UInt32}:
8888
"A"
8989
"A"
90-
"B"
90+
"A"
9191
"B"
9292
"B"
9393
9494
julia> cut(-1:0.5:1, 2, labels=[-0.5, +0.5])
9595
5-element CategoricalArray{Float64,1,UInt32}:
9696
-0.5
9797
-0.5
98-
0.5
98+
-0.5
9999
0.5
100100
0.5
101101
@@ -104,11 +104,11 @@ fmt (generic function with 1 method)
104104
105105
julia> cut(-1:0.5:1, 3, labels=fmt)
106106
5-element CategoricalArray{String,1,UInt32}:
107-
"grp 1 (-1.0//-0.3333333333333335)"
108-
"grp 1 (-1.0//-0.3333333333333335)"
109-
"grp 2 (-0.3333333333333335//0.33333333333333326)"
110-
"grp 3 (0.33333333333333326//1.0)"
111-
"grp 3 (0.33333333333333326//1.0)"
107+
"grp 1 (-1.0//0.0)"
108+
"grp 1 (-1.0//0.0)"
109+
"grp 2 (0.0//1.0)"
110+
"grp 2 (0.0//1.0)"
111+
"grp 3 (1.0//1.0)"
112112
```
113113
"""
114114
@inline function cut(x::AbstractArray, breaks::AbstractVector;
@@ -233,17 +233,21 @@ Provide the default label format for the `cut(x, ngroups)` method.
233233
quantile_formatter(from, to, i; leftclosed, rightclosed) =
234234
string("Q", i, ": ", leftclosed ? "[" : "(", from, ", ", to, rightclosed ? "]" : ")")
235235

236-
function _quantile!(v::AbstractVector, p::AbstractVector)
236+
function _quantile!(v::AbstractVector, ps::AbstractVector)
237237
n = length(v)
238238
n > 0 || throw(ArgumentError("cannot compute quantiles of empty data vector"))
239-
sort!(v)
240-
return map(p) do i
241-
v[clamp(ceil(Int, n*i), 0, n-1) + firstindex(v)]
239+
return map(ps) do p
240+
i = clamp(ceil(Int, n*p), 1, n) + firstindex(v) - 1
241+
q = v[i]
242+
# Use the next distinct value after the quantile as the cutpoint, skipping over any run
# of values equal to the quantile, so that all of them fall below the cutpoint
243+
@inbounds for j in (i+1):lastindex(v)
244+
q_prev = q
245+
q = v[j]
246+
q_prev != q && break
247+
end
248+
return q
242249
end
243250
end
244-
_quantile(x::AbstractArray, p::AbstractVector) =
245-
_quantile!(Base.copymutable(vec(x)), p)
246-
_quantile(x, p::AbstractVector) = _quantile!(collect(x), p)
247251

248252
"""
249253
cut(x::AbstractArray, ngroups::Integer;
@@ -253,7 +257,9 @@ _quantile(x, p::AbstractVector) = _quantile!(collect(x), p)
253257
Cut a numeric array into `ngroups` quantiles.
254258
255259
Cutpoints differ from those returned by `Statistics.quantile` as they are suited
256-
for intervals closed on the left and taken from actual values in `x`.
260+
for intervals closed on the left and taken from actual values in `x`. However,
261+
group assignments are identical to those which would be obtained with type 1
262+
quantiles (Hyndman and Fan's definition 1, i.e. the inverse of the empirical
distribution function) if intervals were closed on the right.
257263
258264
If `x` contains `missing` values, they are automatically skipped when computing
259265
quantiles.
@@ -273,14 +279,14 @@ function cut(x::AbstractArray, ngroups::Integer;
273279
labels::Union{AbstractVector{<:SupportedTypes},Function}=quantile_formatter,
274280
allowempty::Bool=false)
275281
ngroups >= 1 || throw(ArgumentError("ngroups must be strictly positive (got $ngroups)"))
276-
xnm = eltype(x) >: Missing ? skipmissing(x) : x
277-
# Computing extrema is faster than taking 0 and 1 quantiles
278-
min_x, max_x = extrema(xnm)
282+
xnm = eltype(x) >: Missing ? sort!(collect(skipmissing(x))) : sort(x)
283+
min_x, max_x = first(xnm), last(xnm)
279284
if (min_x isa Number && isnan(min_x)) ||
280285
(max_x isa Number && isnan(max_x))
281286
throw(ArgumentError("NaN values are not allowed in input vector"))
282287
end
283-
breaks = _quantile(xnm, (0:ngroups)/ngroups)
288+
qs = _quantile!(xnm, (1:(ngroups-1))/ngroups)
289+
breaks = [min_x; qs; max_x]
284290
if !allowempty && !allunique(@view breaks[1:end-1])
285291
throw(ArgumentError("cannot compute $ngroups quantiles due to " *
286292
"too many duplicated values in `x`. " *

test/15_extras.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -254,21 +254,21 @@ end
254254
fmt = (from, to, i; leftclosed, rightclosed) -> (i % 2 == 0 ? to : 0.0)
255255
@test_throws ArgumentError cut(1:8, 0:2:10, labels=fmt)
256256

257-
@test_throws ArgumentError cut([fill(1, 10); 4], 2)
257+
x = cut([fill(1, 10); 4], 2)
258+
@test x == [fill("Q1: [1, 4)", 10); "Q2: [4, 4]"]
259+
@test levels(x) == ["Q1: [1, 4)", "Q2: [4, 4]"]
258260
@test_throws ArgumentError cut([fill(1, 10); 4], 3)
259-
x = cut([fill(1, 10); 4], 2, allowempty=true)
260-
@test unique(x) == ["Q2: [1, 4]"]
261261
x = cut([fill(1, 10); 4], 3, allowempty=true)
262-
@test unique(x) == ["Q3: [1, 4]"]
263-
@test levels(x) == ["Q1: (1, 1)", "Q2: (1, 1)", "Q3: [1, 4]"]
262+
@test x == [fill("Q1: [1, 4)", 10); "Q3: [4, 4]"]
263+
@test levels(x) == ["Q1: [1, 4)", "Q2: (4, 4)", "Q3: [4, 4]"]
264264

265265
x = cut([fill(1, 5); fill(4, 5)], 2)
266266
@test x == [fill("Q1: [1, 4)", 5); fill("Q2: [4, 4]", 5)]
267267
@test levels(x) == ["Q1: [1, 4)", "Q2: [4, 4]"]
268268
@test_throws ArgumentError cut([fill(1, 5); fill(4, 5)], 3)
269269
x = cut([fill(1, 5); fill(4, 5)], 3, allowempty=true)
270-
@test x == [fill("Q2: [1, 4)", 5); fill("Q3: [4, 4]", 5)]
271-
@test levels(x) == ["Q1: (1, 1)", "Q2: [1, 4)", "Q3: [4, 4]"]
270+
@test x == [fill("Q1: [1, 4)", 5); fill("Q3: [4, 4]", 5)]
271+
@test levels(x) == ["Q1: [1, 4)", "Q2: (4, 4)", "Q3: [4, 4]"]
272272
end
273273

274274
@testset "cut with -0.0" begin

0 commit comments

Comments
 (0)