Skip to content

Commit 341de70

Browse files
authored
Fix corner cases of cut (#410)
Apply more systematically the rule that all intervals are closed on the left and open on the right except the last one. Throw an error when duplicated breaks this would lead to empty intervals unless `allowempty=true`. Improve handling of -0.0, NaN and Inf.
1 parent 4434fe4 commit 341de70

File tree

3 files changed

+130
-29
lines changed

3 files changed

+130
-29
lines changed

src/extras.jl

Lines changed: 40 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,14 @@ function fill_refs!(refs::AbstractArray, X::AbstractArray,
99
@inbounds for i in eachindex(X)
1010
x = X[i]
1111

12-
if ismissing(x)
12+
if x isa Number && isnan(x)
13+
throw(ArgumentError("NaN values are not allowed in input vector"))
14+
elseif ismissing(x)
1315
refs[i] = 0
14-
elseif x == upper
16+
elseif isequal(x, upper)
1517
refs[i] = n-1
16-
elseif extend !== true && !(lower <= x <= upper)
18+
elseif extend !== true &&
19+
!((isless(lower, x) || isequal(x, lower)) && isless(x, upper))
1720
extend === missing ||
1821
throw(ArgumentError("value $x (at index $i) does not fall inside the breaks: " *
1922
"adapt them manually, or pass extend=true or extend=missing"))
@@ -55,10 +58,10 @@ also accept them.
5558
the intervals; or a function `f(from, to, i; leftclosed, rightclosed)` that generates
5659
the labels from the left and right interval boundaries and the group index. Defaults to
5760
`"[from, to)"` (or `"[from, to]"` for the rightmost interval if `extend == true`).
58-
* `allowempty::Bool=false`: when `false`, an error is raised if some breaks appear
59-
multiple times, generating empty intervals; when `true`, duplicate breaks are allowed
60-
and the intervals they generate are kept as unused levels
61-
(but duplicate labels are not allowed).
61+
* `allowempty::Bool=false`: when `false`, an error is raised if some breaks other than
62+
the last one appear multiple times, generating empty intervals; when `true`,
63+
duplicate breaks are allowed and the intervals they generate are kept as
64+
unused levels (but duplicate labels are not allowed).
6265
6366
# Examples
6467
```jldoctest
@@ -132,14 +135,19 @@ function _cut(x::AbstractArray{T, N}, breaks::AbstractVector,
132135
extend::Union{Bool, Missing},
133136
labels::Union{AbstractVector{<:SupportedTypes},Function},
134137
allowempty::Bool=false) where {T, N}
135-
if !allowempty && !allunique(breaks)
136-
throw(ArgumentError("all breaks must be unique unless `allowempty=true`"))
137-
end
138-
139138
if !issorted(breaks)
140139
breaks = sort(breaks)
141140
end
142141

142+
if any(x -> x isa Number && isnan(x), breaks)
143+
throw(ArgumentError("NaN values are not allowed in breaks"))
144+
end
145+
146+
if !allowempty && !allunique(@view breaks[1:end-1])
147+
throw(ArgumentError("all breaks other than the last one must be unique " *
148+
"unless `allowempty=true`"))
149+
end
150+
143151
if extend === true
144152
xnm = T >: Missing ? skipmissing(x) : x
145153
length(breaks) >= 1 || throw(ArgumentError("at least one break must be provided"))
@@ -158,11 +166,11 @@ function _cut(x::AbstractArray{T, N}, breaks::AbstractVector,
158166
rethrow(err)
159167
end
160168
end
161-
if !ismissing(min_x) && breaks[1] > min_x
169+
if !ismissing(min_x) && isless(min_x, breaks[1])
162170
# this type annotation is needed on Julia<1.7 for stable inference
163171
breaks = [min_x::nonmissingtype(eltype(x)); breaks]
164172
end
165-
if !ismissing(max_x) && breaks[end] < max_x
173+
if !ismissing(max_x) && isless(breaks[end], max_x)
166174
breaks = [breaks; max_x::nonmissingtype(eltype(x))]
167175
end
168176
length(breaks) > 1 ||
@@ -189,16 +197,15 @@ function _cut(x::AbstractArray{T, N}, breaks::AbstractVector,
189197
from = breaks[1:n-1]
190198
to = breaks[2:n]
191199
firstlevel = labels(from[1], to[1], 1,
192-
leftclosed=breaks[1] != breaks[2], rightclosed=false)
200+
leftclosed=!isequal(breaks[1], breaks[2]), rightclosed=false)
193201
levs = Vector{typeof(firstlevel)}(undef, n-1)
194202
levs[1] = firstlevel
195203
for i in 2:n-2
196204
levs[i] = labels(from[i], to[i], i,
197-
leftclosed=breaks[i] != breaks[i+1], rightclosed=false)
205+
leftclosed=!isequal(breaks[i], breaks[i+1]), rightclosed=false)
198206
end
199207
levs[end] = labels(from[end], to[end], n-1,
200-
leftclosed=breaks[end-1] != breaks[end],
201-
rightclosed=true)
208+
leftclosed=true, rightclosed=true)
202209
else
203210
length(labels) == n-1 ||
204211
throw(ArgumentError("labels must be of length $(n-1), but got length $(length(labels))"))
@@ -243,21 +250,28 @@ quantiles.
243250
the labels from the left and right interval boundaries and the group index. Defaults to
244251
`"Qi: [from, to)"` (or `"Qi: [from, to]"` for the rightmost interval).
245252
* `allowempty::Bool=false`: when `false`, an error is raised if some quantiles breakpoints
246-
are equal, generating empty intervals; when `true`, duplicate breaks are allowed
247-
and the intervals they generate are kept as unused levels
248-
(but duplicate labels are not allowed).
253+
other than the last one are equal, generating empty intervals;
254+
when `true`, duplicate breaks are allowed and the intervals they generate are kept as
255+
unused levels (but duplicate labels are not allowed).
249256
"""
250257
function cut(x::AbstractArray, ngroups::Integer;
251258
labels::Union{AbstractVector{<:SupportedTypes},Function}=quantile_formatter,
252259
allowempty::Bool=false)
260+
ngroups >= 1 || throw(ArgumentError("ngroups must be strictly positive (got $ngroups)"))
253261
xnm = eltype(x) >: Missing ? skipmissing(x) : x
254-
breaks = Statistics.quantile(xnm, (1:ngroups-1)/ngroups)
255-
if !allowempty && !allunique(breaks)
256-
n = length(unique(breaks)) - 1
257-
throw(ArgumentError("cannot compute $ngroups quantiles: `quantile` " *
258-
"returned only $n groups due to duplicated values in `x`." *
262+
# Computing extrema is faster than taking 0 and 1 quantiles
263+
min_x, max_x = extrema(xnm)
264+
if (min_x isa Number && isnan(min_x)) ||
265+
(max_x isa Number && isnan(max_x))
266+
throw(ArgumentError("NaN values are not allowed in input vector"))
267+
end
268+
breaks = quantile(xnm, (1:ngroups-1)/ngroups)
269+
breaks = [min_x; breaks; max_x]
270+
if !allowempty && !allunique(@view breaks[1:end-1])
271+
throw(ArgumentError("cannot compute $ngroups quantiles due to " *
272+
"too many duplicated values in `x`. " *
259273
"Pass `allowempty=true` to allow empty quantiles or " *
260274
"choose a lower value for `ngroups`."))
261275
end
262-
cut(x, breaks; extend=true, labels=labels, allowempty=allowempty)
276+
cut(x, breaks; labels=labels, allowempty=allowempty)
263277
end

test/15_extras.jl

Lines changed: 88 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,6 @@ const ≅ = isequal
111111
@test isa(x, CategoricalVector{Union{Int, String, T}})
112112
@test isordered(x)
113113
@test levels(x) == [0, "2", 4, "6", 8]
114-
115-
@test_throws ArgumentError cut([-0.0, 0.0], 2)
116-
@test_throws ArgumentError cut([-0.0, 0.0], 2, labels=[-0.0, 0.0])
117114
end
118115

119116
@testset "cut with missing values in input" begin
@@ -144,6 +141,11 @@ end
144141
@test levels(x) == ["Q1: [2.0, 3.5)", "Q2: [3.5, 5.0]"]
145142
end
146143

144+
@testset "cut(x, n) with invalid n" begin
145+
@test_throws ArgumentError cut(1:10, 0)
146+
@test_throws ArgumentError cut(1:10, -1)
147+
end
148+
147149
@testset "cut with formatter function" begin
148150
my_formatter(from, to, i; leftclosed, rightclosed) = "$i: $from -- $to"
149151

@@ -185,11 +187,20 @@ end
185187
x = [zeros(10); ones(10)]
186188
@test_throws ArgumentError cut(x, [0, 0.1, 0.1, 10])
187189
@test_throws ArgumentError cut(x, 10)
190+
y = cut(x, [0, 0.1, 10, 10])
191+
@test y == [fill("[0.0, 0.1)", 10); fill("[0.1, 10.0)", 10)]
192+
@test levels(y) == ["[0.0, 0.1)", "[0.1, 10.0)", "[10.0, 10.0]"]
188193

189194
@test_throws ArgumentError cut(1:10, [1, 5, 5, 11])
190195
y = cut(1:10, [1, 5, 5, 11], allowempty=true)
191196
@test y == cut(1:10, [1, 5, 11])
192197
@test levels(y) == ["[1, 5)", "(5, 5)", "[5, 11]"]
198+
y = cut(1:10, [1, 5, 11, 11])
199+
@test y == [fill("[1, 5)", 4); fill("[5, 11)", 6)]
200+
@test levels(y) == ["[1, 5)", "[5, 11)", "[11, 11]"]
201+
y = cut(1:10, [1, 5, 10, 10])
202+
@test y == [fill("[1, 5)", 4); fill("[5, 10)", 5); "[10, 10]"]
203+
@test levels(y) == ["[1, 5)", "[5, 10)", "[10, 10]"]
193204

194205
@test_throws ArgumentError cut(1:10, [1, 5, 5, 5, 11])
195206
@test_throws ArgumentError cut(1:10, [1, 5, 5, 11],
@@ -242,6 +253,49 @@ end
242253

243254
fmt = (from, to, i; leftclosed, rightclosed) -> (i % 2 == 0 ? to : 0.0)
244255
@test_throws ArgumentError cut(1:8, 0:2:10, labels=fmt)
256+
257+
@test_throws ArgumentError cut([fill(1, 10); 4], 2)
258+
@test_throws ArgumentError cut([fill(1, 10); 4], 3)
259+
x = cut([fill(1, 10); 4], 2, allowempty=true)
260+
@test unique(x) == ["Q2: [1.0, 4.0]"]
261+
x = cut([fill(1, 10); 4], 3, allowempty=true)
262+
@test unique(x) == ["Q3: [1.0, 4.0]"]
263+
@test levels(x) == ["Q1: (1.0, 1.0)", "Q2: (1.0, 1.0)", "Q3: [1.0, 4.0]"]
264+
265+
x = cut([fill(1, 5); fill(4, 5)], 2)
266+
@test x == [fill("Q1: [1.0, 2.5)", 5); fill("Q2: [2.5, 4.0]", 5)]
267+
@test levels(x) == ["Q1: [1.0, 2.5)", "Q2: [2.5, 4.0]"]
268+
@test_throws ArgumentError cut([fill(1, 5); fill(4, 5)], 3)
269+
x = cut([fill(1, 5); fill(4, 5)], 3, allowempty=true)
270+
@test x == [fill("Q2: [1.0, 4.0)", 5); fill("Q3: [4.0, 4.0]", 5)]
271+
@test levels(x) == ["Q1: (1.0, 1.0)", "Q2: [1.0, 4.0)", "Q3: [4.0, 4.0]"]
272+
end
273+
274+
@testset "cut with -0.0" begin
275+
x = cut([-0.0, 0.0, 0.0, -0.0], 2)
276+
@test x == ["Q1: [-0.0, 0.0)", "Q2: [0.0, 0.0]", "Q2: [0.0, 0.0]", "Q1: [-0.0, 0.0)"]
277+
@test levels(x) == ["Q1: [-0.0, 0.0)", "Q2: [0.0, 0.0]"]
278+
279+
x = cut([-0.0, 0.0, 0.0, -0.0], [-0.0, 0.0, 0.0])
280+
@test x == ["[-0.0, 0.0)", "[0.0, 0.0]", "[0.0, 0.0]", "[-0.0, 0.0)"]
281+
@test levels(x) == ["[-0.0, 0.0)", "[0.0, 0.0]"]
282+
283+
x = cut([-0.0, 0.0, 0.0, -0.0], [-0.0, 0.0])
284+
@test x == fill("[-0.0, 0.0]", 4)
285+
@test levels(x) == ["[-0.0, 0.0]"]
286+
287+
x = cut([-0.0, 0.0, 0.0, -0.0], [0.0], extend=true)
288+
@test x == fill("[-0.0, 0.0]", 4)
289+
@test levels(x) == ["[-0.0, 0.0]"]
290+
291+
x = cut([-0.0, 0.0, 0.0, -0.0], [-0.0], extend=true)
292+
@test x == fill("[-0.0, 0.0]", 4)
293+
@test levels(x) == ["[-0.0, 0.0]"]
294+
295+
x = cut([-0.0, 0.0, 0.0, -0.0], 2, labels=[-0.0, 0.0])
296+
@test x == [-0.0, 0.0, 0.0, -0.0]
297+
298+
@test_throws ArgumentError cut([-0.0, 0.0, 0.0, -0.0], [-0.0, -0.0, 0.0])
245299
end
246300

247301
@testset "cut with extend=true" begin
@@ -276,4 +330,35 @@ end
276330
@test x == ["[-1.0, 0.0)", "[-1.0, 0.0)", "[0.0, 1.0]", "[0.0, 1.0]", "[0.0, 1.0]"]
277331
end
278332

333+
@testset "cut with NaN and Inf" begin
334+
@test_throws ArgumentError("NaN values are not allowed in input vector") cut([1, NaN, 2, 3], [1, 10])
335+
@test_throws ArgumentError("NaN values are not allowed in input vector") cut([1, NaN, 2, 3], [1], extend=true)
336+
@test_throws ArgumentError("NaN values are not allowed in input vector") cut([1, NaN, 2, 3], 2)
337+
@test_throws ArgumentError("NaN values are not allowed in breaks") cut([1, 2], [1, NaN])
338+
339+
x = cut([1, Inf], [1], extend=true)
340+
@test x ["[1.0, Inf]", "[1.0, Inf]"]
341+
@test levels(x) == ["[1.0, Inf]"]
342+
343+
x = cut([1, -Inf], [1], extend=true)
344+
@test x ["[-Inf, 1.0]", "[-Inf, 1.0]"]
345+
@test levels(x) == ["[-Inf, 1.0]"]
346+
347+
x = cut([1:5; Inf], [1, 2, Inf])
348+
@test x ["[1.0, 2.0)"; fill("[2.0, Inf]", 5)]
349+
@test levels(x) == ["[1.0, 2.0)", "[2.0, Inf]"]
350+
351+
x = cut([1:5; -Inf], [-Inf, 2, 5])
352+
@test x ["[-Inf, 2.0)"; fill("[2.0, 5.0]", 4); "[-Inf, 2.0)"]
353+
@test levels(x) == ["[-Inf, 2.0)", "[2.0, 5.0]"]
354+
355+
x = cut([1:5; Inf], 2)
356+
@test x [fill("Q1: [1.0, 3.5)", 3); fill("Q2: [3.5, Inf]", 3)]
357+
@test levels(x) == ["Q1: [1.0, 3.5)", "Q2: [3.5, Inf]"]
358+
359+
x = cut([1:5; -Inf], 2)
360+
@test x [fill("Q1: [-Inf, 2.5)", 2); fill("Q2: [2.5, 5.0]", 3); "Q1: [-Inf, 2.5)"]
361+
@test levels(x) == ["Q1: [-Inf, 2.5)", "Q2: [2.5, 5.0]"]
279362
end
363+
364+
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ module TestCategoricalArrays
1010
using Test
1111
using CategoricalArrays
1212

13+
const = isequal
14+
1315
tests = [
1416
"01_value.jl",
1517
"04_constructors.jl",

0 commit comments

Comments
 (0)