Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 40 additions & 15 deletions src/extras.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ default_formatter(from, to, i; leftclosed, rightclosed) =

Cut a numeric array into intervals at values `breaks`
and return an ordered `CategoricalArray` indicating
the interval into which each entry falls. Intervals are of the form `[lower, upper)`,
i.e. the lower bound is included and the upper bound is excluded, except
the interval into which each entry falls. Intervals are of the form `[lower, upper)`
(closed on the left), i.e. the lower bound is included and the upper bound is excluded, except
the last interval, which is closed on both ends, i.e. `[lower, upper]`.

If `x` accepts missing values (i.e. `eltype(x) >: Missing`) the returned array will
Expand Down Expand Up @@ -81,15 +81,15 @@ julia> cut(-1:0.5:1, 2)
"Q1: [-1.0, 0.0)"
"Q2: [0.0, 1.0]"
"Q2: [0.0, 1.0]"
"Q2: [0.0, 1.0]"
"Q2: [0.0, 1.0]"

julia> cut(-1:0.5:1, 2, labels=["A", "B"])
5-element CategoricalArray{String,1,UInt32}:
"A"
"A"
"B"
"B"
"B"
"B"

julia> cut(-1:0.5:1, 2, labels=[-0.5, +0.5])
5-element CategoricalArray{Float64,1,UInt32}:
Expand All @@ -104,11 +104,11 @@ fmt (generic function with 1 method)

julia> cut(-1:0.5:1, 3, labels=fmt)
5-element CategoricalArray{String,1,UInt32}:
"grp 1 (-1.0//-0.3333333333333335)"
"grp 1 (-1.0//-0.3333333333333335)"
"grp 2 (-0.3333333333333335//0.33333333333333326)"
"grp 3 (0.33333333333333326//1.0)"
"grp 3 (0.33333333333333326//1.0)"
"grp 1 (-1.0//0.0)"
"grp 1 (-1.0//0.0)"
"grp 2 (0.0//0.5)"
"grp 3 (0.5//1.0)"
"grp 3 (0.5//1.0)"
```
"""
@inline function cut(x::AbstractArray, breaks::AbstractVector;
Expand Down Expand Up @@ -233,12 +233,38 @@ Provide the default label format for the `cut(x, ngroups)` method.
function quantile_formatter(from, to, i; leftclosed, rightclosed)
    # Build a label of the form "Q<i>: <open><from>, <to><close>", where each
    # delimiter is square when the corresponding bound is closed and round otherwise.
    opening = leftclosed ? "[" : "("
    closing = rightclosed ? "]" : ")"
    return string("Q", i, ": ", opening, from, ", ", to, closing)
end

"""
Find first value in (sorted) `v` which is greater than or equal to each quantile
in (sorted) `qs`.
"""
function find_breaks(v::AbstractVector, qs::AbstractVector)
    # For each quantile in `qs` (sorted), pick the first value of `v` (sorted) that is
    # greater than or equal to it. Returns a vector with `length(qs)` entries and the
    # element type of `v`. Throws `ArgumentError` if breaks are requested from an
    # empty `v`, since no value could then be returned for them.
    n = length(qs)
    breaks = similar(v, n)
    n == 0 && return breaks
    isempty(v) &&
        throw(ArgumentError("cannot find breaks for $n quantiles in an empty vector"))

    i = 1
    q = qs[1]
    @inbounds for x in v
        # Use isless and isequal to differentiate -0.0 from 0.0.
        # This is a `while` rather than an `if` because a single value can be the
        # first one at or above several consecutive quantiles; testing only once per
        # value would leave the later entries of `breaks` uninitialized.
        while isless(q, x) || isequal(q, x)
            breaks[i] = x
            i += 1
            i > n && return breaks
            q = qs[i]
        end
    end
    # Any remaining quantile is greater than every value in `v`: fall back to the
    # largest value so that no entry is left uninitialized.
    breaks[i:end] .= last(v)
    return breaks
end

"""
cut(x::AbstractArray, ngroups::Integer;
labels::Union{AbstractVector{<:AbstractString},Function},
allowempty::Bool=false)

Cut a numeric array into `ngroups` quantiles, determined using `quantile`.
Cut a numeric array into `ngroups` quantiles.

This is equivalent to `cut(x, quantile(x, (0:ngroups)/ngroups))`,
but breaks are taken from actual data values instead of estimated quantiles.

If `x` contains `missing` values, they are automatically skipped when computing
quantiles.
Expand All @@ -258,15 +284,14 @@ function cut(x::AbstractArray, ngroups::Integer;
labels::Union{AbstractVector{<:SupportedTypes},Function}=quantile_formatter,
allowempty::Bool=false)
ngroups >= 1 || throw(ArgumentError("ngroups must be strictly positive (got $ngroups)"))
xnm = eltype(x) >: Missing ? skipmissing(x) : x
# Computing extrema is faster than taking 0 and 1 quantiles
min_x, max_x = extrema(xnm)
sorted_x = eltype(x) >: Missing ? sort!(collect(skipmissing(x))) : sort(x)
min_x, max_x = first(sorted_x), last(sorted_x)
if (min_x isa Number && isnan(min_x)) ||
(max_x isa Number && isnan(max_x))
throw(ArgumentError("NaN values are not allowed in input vector"))
end
breaks = quantile(xnm, (1:ngroups-1)/ngroups)
breaks = [min_x; breaks; max_x]
qs = quantile!(sorted_x, (1:(ngroups-1))/ngroups, sorted=true)
breaks = [min_x; find_breaks(sorted_x, qs); max_x]
if !allowempty && !allunique(@view breaks[1:end-1])
throw(ArgumentError("cannot compute $ngroups quantiles due to " *
"too many duplicated values in `x`. " *
Expand Down
50 changes: 34 additions & 16 deletions test/15_extras.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,18 +127,18 @@ end

# Checks `cut` with 2 groups on an unsorted integer vector: the labels assigned to each
# entry, that the result is an ordered `CategoricalArray`, and the resulting levels.
# NOTE(review): this diff view shows two `x ==` and two `levels(x) ==` assertions with
# different expected labels (float bounds such as 3.5 vs. integer bounds 2/4/5).
# Presumably the float-formatted lines are the ones removed by the change and the
# integer-formatted lines replace them — confirm against the actual test file.
@testset "cut([5, 4, 3, 2], 2)" begin
# `@inferred` additionally fails if the return type of `cut` cannot be inferred
x = @inferred cut([5, 4, 3, 2], 2)
@test x == ["Q2: [3.5, 5.0]", "Q2: [3.5, 5.0]", "Q1: [2.0, 3.5)", "Q1: [2.0, 3.5)"]
@test x == ["Q2: [4, 5]", "Q2: [4, 5]", "Q1: [2, 4)", "Q1: [2, 4)"]
@test isa(x, CategoricalArray)
@test isordered(x)
@test levels(x) == ["Q1: [2.0, 3.5)", "Q2: [3.5, 5.0]"]
@test levels(x) == ["Q1: [2, 4)", "Q2: [4, 5]"]
end

# Same scenario as cutting `[5, 4, 3, 2]` into 2 groups, but with a `missing` entry in
# the input: the expected output keeps `missing` at the corresponding position.
# NOTE(review): this diff view shows two `x ≅` and two `levels(x) ==` assertions with
# different expected labels (float bounds vs. integer bounds). Presumably the
# float-formatted lines are the removed ones — confirm against the actual test file.
@testset "cut(x, n) with missing values" begin
x = @inferred cut([5, 4, 3, missing, 2], 2)
# `≅` is used instead of `==` because the expected vector contains `missing`
@test x ≅ ["Q2: [3.5, 5.0]", "Q2: [3.5, 5.0]", "Q1: [2.0, 3.5)", missing, "Q1: [2.0, 3.5)"]
@test x ≅ ["Q2: [4, 5]", "Q2: [4, 5]", "Q1: [2, 4)", missing, "Q1: [2, 4)"]
@test isa(x, CategoricalArray)
@test isordered(x)
@test levels(x) == ["Q1: [2.0, 3.5)", "Q2: [3.5, 5.0]"]
@test levels(x) == ["Q1: [2, 4)", "Q2: [4, 5]"]
end

@testset "cut(x, n) with invalid n" begin
Expand Down Expand Up @@ -255,20 +255,29 @@ end
@test_throws ArgumentError cut(1:8, 0:2:10, labels=fmt)

@test_throws ArgumentError cut([fill(1, 10); 4], 2)
@test_throws ArgumentError cut([fill(1, 10); 4], 3)
x = cut([fill(1, 10); 4], 2, allowempty=true)
@test unique(x) == ["Q2: [1.0, 4.0]"]
@test unique(x) == ["Q2: [1, 4]"]
@test levels(x) == ["Q1: (1, 1)", "Q2: [1, 4]"]
@test_throws ArgumentError cut([fill(1, 10); 4], 3)
x = cut([fill(1, 10); 4], 3, allowempty=true)
@test unique(x) == ["Q3: [1.0, 4.0]"]
@test levels(x) == ["Q1: (1.0, 1.0)", "Q2: (1.0, 1.0)", "Q3: [1.0, 4.0]"]
@test unique(x) == ["Q3: [1, 4]"]
@test levels(x) == ["Q1: (1, 1)", "Q2: (1, 1)", "Q3: [1, 4]"]

x = cut([fill(4, 10); 1], 2)
@test x == [fill("Q2: [4, 4]", 10); "Q1: [1, 4)"]
@test levels(x) == ["Q1: [1, 4)"; "Q2: [4, 4]"]
@test_throws ArgumentError cut([fill(4, 10); 1], 3)
x = cut([fill(4, 10); 1], 3, allowempty=true)
@test x == [fill("Q3: [4, 4]", 10); "Q1: [1, 4)"]
@test levels(x) == ["Q1: [1, 4)", "Q2: (4, 4)", "Q3: [4, 4]"]

x = cut([fill(1, 5); fill(4, 5)], 2)
@test x == [fill("Q1: [1.0, 2.5)", 5); fill("Q2: [2.5, 4.0]", 5)]
@test levels(x) == ["Q1: [1.0, 2.5)", "Q2: [2.5, 4.0]"]
@test x == [fill("Q1: [1, 4)", 5); fill("Q2: [4, 4]", 5)]
@test levels(x) == ["Q1: [1, 4)", "Q2: [4, 4]"]
@test_throws ArgumentError cut([fill(1, 5); fill(4, 5)], 3)
x = cut([fill(1, 5); fill(4, 5)], 3, allowempty=true)
@test x == [fill("Q2: [1.0, 4.0)", 5); fill("Q3: [4.0, 4.0]", 5)]
@test levels(x) == ["Q1: (1.0, 1.0)", "Q2: [1.0, 4.0)", "Q3: [4.0, 4.0]"]
@test x == [fill("Q2: [1, 4)", 5); fill("Q3: [4, 4]", 5)]
@test levels(x) == ["Q1: (1, 1)", "Q2: [1, 4)", "Q3: [4, 4]"]
end

@testset "cut with -0.0" begin
Expand Down Expand Up @@ -353,12 +362,21 @@ end
@test levels(x) == ["[-Inf, 2.0)", "[2.0, 5.0]"]

x = cut([1:5; Inf], 2)
@test x ≅ [fill("Q1: [1.0, 3.5)", 3); fill("Q2: [3.5, Inf]", 3)]
@test levels(x) == ["Q1: [1.0, 3.5)", "Q2: [3.5, Inf]"]
@test x ≅ [fill("Q1: [1.0, 4.0)", 3); fill("Q2: [4.0, Inf]", 3)]
@test levels(x) == ["Q1: [1.0, 4.0)", "Q2: [4.0, Inf]"]

x = cut([1:5; -Inf], 2)
@test x ≅ [fill("Q1: [-Inf, 2.5)", 2); fill("Q2: [2.5, 5.0]", 3); "Q1: [-Inf, 2.5)"]
@test levels(x) == ["Q1: [-Inf, 2.5)", "Q2: [2.5, 5.0]"]
@test x ≅ [fill("Q1: [-Inf, 3.0)", 2); fill("Q2: [3.0, 5.0]", 3); "Q1: [-Inf, 3.0)"]
@test levels(x) == ["Q1: [-Inf, 3.0)", "Q2: [3.0, 5.0]"]
end

@testset "cut when quantile falls exactly on a data value" begin
    # Ten sorted integers cut into three groups; the expected interval bounds
    # (54 and 73) are values present in the data.
    data = [11, 14, 43, 54, 54, 56, 73, 79, 84, 84]
    expected_levels = ["Q1: [11, 54)", "Q2: [54, 73)", "Q3: [73, 84]"]
    # Entries 1-3 fall in the first group, 4-6 in the second and 7-10 in the third
    expected_labels = [fill(expected_levels[1], 3);
                       fill(expected_levels[2], 3);
                       fill(expected_levels[3], 4)]

    x = cut(data, 3)
    @test x == expected_labels
    @test levels(x) == expected_levels
end

end
Loading