From e0fe39c14a3ebca92c87b010ee9f01220dc3feb3 Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Sat, 22 Feb 2025 12:10:12 +0100 Subject: [PATCH 01/12] Choose different quantile cutpoints in `cut(x, n)` `Statistics.quantile` returns values which are not the most appropriate to generate labels. It is more intuitive to choose values from the actual data, which are likely to have fewer decimals and make more sense for users. Unfortunately, since we use intervals closed on the left, we cannot use any of the seven standard definitions of quantiles. Type 1 is the closest, but we have to take the value next to it as a cutpoint to prevent it from being included into the next quantile group. This gives essentially consistent group attributions to R's `Hmisc::cut2` or `cut(x, quantile(x, (0:n)/n, type=1), include.lowest=T))`, though with different cutpoints in labels. --- src/extras.jl | 24 +++++++++++++++++++----- test/15_extras.jl | 30 +++++++++++++++--------------- 2 files changed, 34 insertions(+), 20 deletions(-) diff --git a/src/extras.jl b/src/extras.jl index 2afcef38..7b7d5d0c 100644 --- a/src/extras.jl +++ b/src/extras.jl @@ -42,8 +42,8 @@ default_formatter(from, to, i; leftclosed, rightclosed) = Cut a numeric array into intervals at values `breaks` and return an ordered `CategoricalArray` indicating -the interval into which each entry falls. Intervals are of the form `[lower, upper)`, -i.e. the lower bound is included and the upper bound is excluded, except +the interval into which each entry falls. Intervals are of the form `[lower, upper)` +(closed on the left), i.e. the lower bound is included and the upper bound is excluded, except the last interval, which is closed on both ends, i.e. `[lower, upper]`. If `x` accepts missing values (i.e. `eltype(x) >: Missing`) the returned array will @@ -233,12 +233,27 @@ Provide the default label format for the `cut(x, ngroups)` method. quantile_formatter(from, to, i; leftclosed, rightclosed) = string("Q", i, ": ", leftclosed ? "[" : "(", from, ", ", to, rightclosed ? "]" : ")") +function _quantile!(v::AbstractVector, p::AbstractVector) + n = length(v) + n > 0 || throw(ArgumentError("cannot compute quantiles of empty data vector")) + sort!(v) + return map(p) do i + v[clamp(ceil(Int, n*i), 0, n-1) + firstindex(v)] + end +end +_quantile(x::AbstractArray, p::AbstractVector) = + _quantile!(Base.copymutable(vec(x)), p) +_quantile(x, p::AbstractVector) = _quantile!(collect(x), p) + """ cut(x::AbstractArray, ngroups::Integer; labels::Union{AbstractVector{<:AbstractString},Function}, allowempty::Bool=false) -Cut a numeric array into `ngroups` quantiles, determined using `quantile`. +Cut a numeric array into `ngroups` quantiles. + +Cutpoints differ from those returned by `Statistics.quantile` as they are suited +for intervals closed on the left and taken from actual values in `x`. If `x` contains `missing` values, they are automatically skipped when computing quantiles. @@ -265,8 +280,7 @@ function cut(x::AbstractArray, ngroups::Integer; (max_x isa Number && isnan(max_x)) throw(ArgumentError("NaN values are not allowed in input vector")) end - breaks = quantile(xnm, (1:ngroups-1)/ngroups) - breaks = [min_x; breaks; max_x] + breaks = _quantile(xnm, (0:ngroups)/ngroups) if !allowempty && !allunique(@view breaks[1:end-1]) throw(ArgumentError("cannot compute $ngroups quantiles due to " * "too many duplicated values in `x`. " * diff --git a/test/15_extras.jl b/test/15_extras.jl index 1aaf8dc7..8440141c 100644 --- a/test/15_extras.jl +++ b/test/15_extras.jl @@ -127,18 +127,18 @@ end @testset "cut([5, 4, 3, 2], 2)" begin x = @inferred cut([5, 4, 3, 2], 2) - @test x == ["Q2: [3.5, 5.0]", "Q2: [3.5, 5.0]", "Q1: [2.0, 3.5)", "Q1: [2.0, 3.5)"] + @test x == ["Q2: [4, 5]", "Q2: [4, 5]", "Q1: [2, 4)", "Q1: [2, 4)"] @test isa(x, CategoricalArray) @test isordered(x) - @test levels(x) == ["Q1: [2.0, 3.5)", "Q2: [3.5, 5.0]"] + @test levels(x) == ["Q1: [2, 4)", "Q2: [4, 5]"] end @testset "cut(x, n) with missing values" begin x = @inferred cut([5, 4, 3, missing, 2], 2) - @test x ≅ ["Q2: [3.5, 5.0]", "Q2: [3.5, 5.0]", "Q1: [2.0, 3.5)", missing, "Q1: [2.0, 3.5)"] + @test x ≅ ["Q2: [4, 5]", "Q2: [4, 5]", "Q1: [2, 4)", missing, "Q1: [2, 4)"] @test isa(x, CategoricalArray) @test isordered(x) - @test levels(x) == ["Q1: [2.0, 3.5)", "Q2: [3.5, 5.0]"] + @test levels(x) == ["Q1: [2, 4)", "Q2: [4, 5]"] end @testset "cut(x, n) with invalid n" begin @@ -257,18 +257,18 @@ end @test_throws ArgumentError cut([fill(1, 10); 4], 2) @test_throws ArgumentError cut([fill(1, 10); 4], 3) x = cut([fill(1, 10); 4], 2, allowempty=true) - @test unique(x) == ["Q2: [1.0, 4.0]"] + @test unique(x) == ["Q2: [1, 4]"] x = cut([fill(1, 10); 4], 3, allowempty=true) - @test unique(x) == ["Q3: [1.0, 4.0]"] - @test levels(x) == ["Q1: (1.0, 1.0)", "Q2: (1.0, 1.0)", "Q3: [1.0, 4.0]"] + @test unique(x) == ["Q3: [1, 4]"] + @test levels(x) == ["Q1: (1, 1)", "Q2: (1, 1)", "Q3: [1, 4]"] x = cut([fill(1, 5); fill(4, 5)], 2) - @test x == [fill("Q1: [1.0, 2.5)", 5); fill("Q2: [2.5, 4.0]", 5)] - @test levels(x) == ["Q1: [1.0, 2.5)", "Q2: [2.5, 4.0]"] + @test x == [fill("Q1: [1, 4)", 5); fill("Q2: [4, 4]", 5)] + @test levels(x) == ["Q1: [1, 4)", "Q2: [4, 4]"] @test_throws ArgumentError cut([fill(1, 5); fill(4, 5)], 3) x = cut([fill(1, 5); fill(4, 5)], 3, allowempty=true) - @test x == [fill("Q2: [1.0, 4.0)", 5); fill("Q3: [4.0, 4.0]", 5)] - @test levels(x) == ["Q1: (1.0, 1.0)", "Q2: [1.0, 4.0)", "Q3: [4.0, 4.0]"] + @test x == [fill("Q2: [1, 4)", 5); fill("Q3: [4, 4]", 5)] + @test levels(x) == ["Q1: (1, 1)", "Q2: [1, 4)", "Q3: [4, 4]"] end @testset "cut with -0.0" begin @@ -353,12 +353,12 @@ end @test levels(x) == ["[-Inf, 2.0)", "[2.0, 5.0]"] x = cut([1:5; Inf], 2) - @test x ≅ [fill("Q1: [1.0, 3.5)", 3); fill("Q2: [3.5, Inf]", 3)] - @test levels(x) == ["Q1: [1.0, 3.5)", "Q2: [3.5, Inf]"] + @test x ≅ [fill("Q1: [1.0, 4.0)", 3); fill("Q2: [4.0, Inf]", 3)] + @test levels(x) == ["Q1: [1.0, 4.0)", "Q2: [4.0, Inf]"] x = cut([1:5; -Inf], 2) - @test x ≅ [fill("Q1: [-Inf, 2.5)", 2); fill("Q2: [2.5, 5.0]", 3); "Q1: [-Inf, 2.5)"] - @test levels(x) == ["Q1: [-Inf, 2.5)", "Q2: [2.5, 5.0]"] + @test x ≅ [fill("Q1: [-Inf, 3.0)", 2); fill("Q2: [3.0, 5.0]", 3); "Q1: [-Inf, 3.0)"] + @test levels(x) == ["Q1: [-Inf, 3.0)", "Q2: [3.0, 5.0]"] end end \ No newline at end of file From bbed3a272c8460ecbde5aebc4db1ad11cb4bf9ef Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Fri, 18 Apr 2025 12:04:43 +0200 Subject: [PATCH 02/12] Improve choice of cutpoints --- src/extras.jl | 54 ++++++++++++++++++++++++++--------------------- test/15_extras.jl | 14 ++++++------ 2 files changed, 37 insertions(+), 31 deletions(-) diff --git a/src/extras.jl b/src/extras.jl index 7b7d5d0c..3b0b7b66 100644 --- a/src/extras.jl +++ b/src/extras.jl @@ -77,17 +77,17 @@ julia> cut(-1:0.5:1, [0, 1], extend=true) julia> cut(-1:0.5:1, 2) 5-element CategoricalArray{String,1,UInt32}: - "Q1: [-1.0, 0.0)" - "Q1: [-1.0, 0.0)" - "Q2: [0.0, 1.0]" - "Q2: [0.0, 1.0]" - "Q2: [0.0, 1.0]" + "Q1: [-1.0, 0.5)" + "Q1: [-1.0, 0.5)" + "Q1: [-1.0, 0.5)" + "Q2: [0.5, 1.0]" + "Q2: [0.5, 1.0]" julia> cut(-1:0.5:1, 2, labels=["A", "B"]) 5-element CategoricalArray{String,1,UInt32}: "A" "A" - "B" + "A" "B" "B" @@ -95,7 +95,7 @@ julia> cut(-1:0.5:1, 2, labels=[-0.5, +0.5]) 5-element CategoricalArray{Float64,1,UInt32}: -0.5 -0.5 - 0.5 + -0.5 0.5 0.5 @@ -104,11 +104,11 @@ fmt (generic function with 1 method) julia> cut(-1:0.5:1, 3, labels=fmt) 5-element CategoricalArray{String,1,UInt32}: - "grp 1 (-1.0//-0.3333333333333335)" - "grp 1 (-1.0//-0.3333333333333335)" - "grp 2 (-0.3333333333333335//0.33333333333333326)" - "grp 3 (0.33333333333333326//1.0)" - "grp 3 (0.33333333333333326//1.0)" + "grp 1 (-1.0//0.0)" + "grp 1 (-1.0//0.0)" + "grp 2 (0.0//1.0)" + "grp 2 (0.0//1.0)" + "grp 3 (1.0//1.0)" ``` """ @inline function cut(x::AbstractArray, breaks::AbstractVector; @@ -233,17 +233,21 @@ Provide the default label format for the `cut(x, ngroups)` method. quantile_formatter(from, to, i; leftclosed, rightclosed) = string("Q", i, ": ", leftclosed ? "[" : "(", from, ", ", to, rightclosed ? "]" : ")") -function _quantile!(v::AbstractVector, p::AbstractVector) +function _quantile!(v::AbstractVector, ps::AbstractVector) n = length(v) n > 0 || throw(ArgumentError("cannot compute quantiles of empty data vector")) - sort!(v) - return map(p) do i - v[clamp(ceil(Int, n*i), 0, n-1) + firstindex(v)] + return map(ps) do p + i = clamp(ceil(Int, n*p), 1, n) + firstindex(v) - 1 + q = v[i] + # Take next distinct value even if quantile falls in a series of duplicated values + @inbounds for j in (i+1):lastindex(v) + q_prev = q + q = v[j] + q_prev != q && break + end + return q end end -_quantile(x::AbstractArray, p::AbstractVector) = - _quantile!(Base.copymutable(vec(x)), p) -_quantile(x, p::AbstractVector) = _quantile!(collect(x), p) """ cut(x::AbstractArray, ngroups::Integer; @@ -253,7 +257,9 @@ _quantile(x, p::AbstractVector) = _quantile!(collect(x), p) Cut a numeric array into `ngroups` quantiles. Cutpoints differ from those returned by `Statistics.quantile` as they are suited -for intervals closed on the left and taken from actual values in `x`. +for intervals closed on the left and taken from actual values in `x`. However, +group assignments are identical to those which would be obtained with type 1 +quantiles if intervals were closed on the right. If `x` contains `missing` values, they are automatically skipped when computing quantiles. @@ -273,14 +279,14 @@ function cut(x::AbstractArray, ngroups::Integer; labels::Union{AbstractVector{<:SupportedTypes},Function}=quantile_formatter, allowempty::Bool=false) ngroups >= 1 || throw(ArgumentError("ngroups must be strictly positive (got $ngroups)")) - xnm = eltype(x) >: Missing ? skipmissing(x) : x - # Computing extrema is faster than taking 0 and 1 quantiles - min_x, max_x = extrema(xnm) + xnm = eltype(x) >: Missing ? sort!(collect(skipmissing(x))) : sort(x) + min_x, max_x = first(xnm), last(xnm) if (min_x isa Number && isnan(min_x)) || (max_x isa Number && isnan(max_x)) throw(ArgumentError("NaN values are not allowed in input vector")) end - breaks = _quantile(xnm, (0:ngroups)/ngroups) + qs = _quantile!(xnm, (1:(ngroups-1))/ngroups) + breaks = [min_x; qs; max_x] if !allowempty && !allunique(@view breaks[1:end-1]) throw(ArgumentError("cannot compute $ngroups quantiles due to " * "too many duplicated values in `x`. " * diff --git a/test/15_extras.jl b/test/15_extras.jl index 8440141c..0a686464 100644 --- a/test/15_extras.jl +++ b/test/15_extras.jl @@ -254,21 +254,21 @@ end fmt = (from, to, i; leftclosed, rightclosed) -> (i % 2 == 0 ? to : 0.0) @test_throws ArgumentError cut(1:8, 0:2:10, labels=fmt) - @test_throws ArgumentError cut([fill(1, 10); 4], 2) + x = cut([fill(1, 10); 4], 2) + @test x == [fill("Q1: [1, 4)", 10); "Q2: [4, 4]"] + @test levels(x) == ["Q1: [1, 4)", "Q2: [4, 4]"] @test_throws ArgumentError cut([fill(1, 10); 4], 3) - x = cut([fill(1, 10); 4], 2, allowempty=true) - @test unique(x) == ["Q2: [1, 4]"] x = cut([fill(1, 10); 4], 3, allowempty=true) - @test unique(x) == ["Q3: [1, 4]"] - @test levels(x) == ["Q1: (1, 1)", "Q2: (1, 1)", "Q3: [1, 4]"] + @test x == [fill("Q1: [1, 4)", 10); "Q3: [4, 4]"] + @test levels(x) == ["Q1: [1, 4)", "Q2: (4, 4)", "Q3: [4, 4]"] x = cut([fill(1, 5); fill(4, 5)], 2) @test x == [fill("Q1: [1, 4)", 5); fill("Q2: [4, 4]", 5)] @test levels(x) == ["Q1: [1, 4)", "Q2: [4, 4]"] @test_throws ArgumentError cut([fill(1, 5); fill(4, 5)], 3) x = cut([fill(1, 5); fill(4, 5)], 3, allowempty=true) - @test x == [fill("Q2: [1, 4)", 5); fill("Q3: [4, 4]", 5)] - @test levels(x) == ["Q1: (1, 1)", "Q2: [1, 4)", "Q3: [4, 4]"] + @test x == [fill("Q1: [1, 4)", 5); fill("Q3: [4, 4]", 5)] + @test levels(x) == ["Q1: [1, 4)", "Q2: (4, 4)", "Q3: [4, 4]"] end @testset "cut with -0.0" begin From 9dd935bb831174d5bfd988a2961b2d2bf0091c47 Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Sun, 27 Apr 2025 21:46:29 +0200 Subject: [PATCH 03/12] WIP --- src/extras.jl | 42 ++++++++++++++++++++++++++---------------- test/15_extras.jl | 47 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 16 deletions(-) diff --git a/src/extras.jl b/src/extras.jl index 3b0b7b66..f3bc17ca 100644 --- a/src/extras.jl +++ b/src/extras.jl @@ -233,20 +233,28 @@ Provide the default label format for the `cut(x, ngroups)` method. quantile_formatter(from, to, i; leftclosed, rightclosed) = string("Q", i, ": ", leftclosed ? "[" : "(", from, ", ", to, rightclosed ? "]" : ")") -function _quantile!(v::AbstractVector, ps::AbstractVector) - n = length(v) - n > 0 || throw(ArgumentError("cannot compute quantiles of empty data vector")) - return map(ps) do p - i = clamp(ceil(Int, n*p), 1, n) + firstindex(v) - 1 - q = v[i] - # Take next distinct value even if quantile falls in a series of duplicated values - @inbounds for j in (i+1):lastindex(v) - q_prev = q - q = v[j] - q_prev != q && break +"""Find first value in data which is greater than each quantile in ``qs``.""" +function find_breaks(v::AbstractVector, qs::AbstractVector) + n = length(qs) + breaks = similar(v, n) + n == 0 && return breaks + + i = 1 + q = qs[1] + @inbounds for x in v + if x > q + breaks[i] = x + i += 1 + i > n && break + q = qs[i] end - return q end + # if last values in x are equal to q, breaks were not initialized + for i in i:n + breaks[i] = q + end + @show breaks + return breaks end """ @@ -279,14 +287,16 @@ function cut(x::AbstractArray, ngroups::Integer; labels::Union{AbstractVector{<:SupportedTypes},Function}=quantile_formatter, allowempty::Bool=false) ngroups >= 1 || throw(ArgumentError("ngroups must be strictly positive (got $ngroups)")) - xnm = eltype(x) >: Missing ? sort!(collect(skipmissing(x))) : sort(x) - min_x, max_x = first(xnm), last(xnm) + sorted_x = eltype(x) >: Missing ? sort!(collect(skipmissing(x))) : sort(x) + min_x, max_x = first(sorted_x), last(sorted_x) if (min_x isa Number && isnan(min_x)) || (max_x isa Number && isnan(max_x)) throw(ArgumentError("NaN values are not allowed in input vector")) end - qs = _quantile!(xnm, (1:(ngroups-1))/ngroups) - breaks = [min_x; qs; max_x] + qs = quantile!(sorted_x, (1:(ngroups-1))/ngroups, sorted=true) + @show qs, min_x, max_x + breaks = [min_x; find_breaks(sorted_x, qs); max_x] + @show breaks if !allowempty && !allunique(@view breaks[1:end-1]) throw(ArgumentError("cannot compute $ngroups quantiles due to " * "too many duplicated values in `x`. " * diff --git a/test/15_extras.jl b/test/15_extras.jl index 0a686464..859cd349 100644 --- a/test/15_extras.jl +++ b/test/15_extras.jl @@ -361,4 +361,51 @@ end @test levels(x) == ["Q1: [-Inf, 3.0)", "Q2: [3.0, 5.0]"] end +@testset "cut in corner cases" begin + # In this case, cut(x, quantile(x, (0:36)/36)) in R generates + # an empty "(143,172]" level + # and qcut(x, 36) in Polars misses that level. + # Our approach uses different breaks at 143 and 182 + x = [23, 23, 60, 76, 84, 95, 101, 108, 111, 133, 137, 143, 143, 143, 182, + 206, 214, 241, 258, 262, 280, 289, 303, 312, 321, 323, 352, 353, 354, + 368, 369, 373, 374, 384, 385, 386, 387, 392, 405, 406, 410, 421, 430, + 430, 431, 442, 464, 474, 478, 479, 496, 516, 530, 534, 549, 554, 568, + 575, 589, 590, 591, 592, 595, 596, 603, 625, 632, 632, 638, 640, 640, + 645, 648, 690, 704, 748, 758, 771, 772, 803, 835, 839, 853, 869, 873, + 874, 887, 911, 920, 923, 928, 933, 943, 945, 945, 947, 951, 965, 978, 980] + + @test cut(x, 36) == + ["Q17: [442, 474)", "Q8: [280, 312)", "Q2: [76, 101)", "Q9: [312, 323)", + "Q14: [387, 406)", "Q30: [835, 869)", "Q17: [442, 474)", "Q35: [947, 965)", + "Q2: [76, 101)", "Q11: [354, 373)", "Q32: [887, 923)", "Q12: [373, 385)", + "Q24: [603, 638)", "Q29: [772, 835)", "Q24: [603, 638)", "Q15: [406, 430)", + "Q11: [354, 373)", "Q23: [592, 603)", "Q3: [101, 133)", "Q16: [430, 442)", + "Q34: [933, 947)", "Q27: [648, 748)", "Q28: [748, 772)", "Q28: [748, 772)", + "Q21: [568, 589)", "Q18: [474, 496)", "Q32: [887, 923)", "Q11: [354, 373)", + "Q7: [241, 280)", "Q3: [101, 133)", "Q19: [496, 534)", "Q13: [385, 387)", + "Q36: [965, 980]", "Q33: [923, 933)", "Q16: [430, 442)", "Q36: [965, 980]", + "Q27: [648, 748)", "Q24: [603, 638)", "Q32: [887, 923)", "Q4: [133, 182)", + "Q22: [589, 592)", "Q1: [23, 76)", "Q5: [182, 206)", "Q28: [748, 772)", + "Q30: [835, 869)", "Q31: [869, 887)", "Q22: [589, 592)", "Q15: [406, 430)", + "Q31: [869, 887)", "Q19: [496, 534)", "Q26: [640, 648)", "Q8: [280, 312)", + "Q18: [474, 496)", "Q14: [387, 406)", "Q7: [241, 280)", "Q30: [835, 869)", + "Q3: [101, 133)", "Q9: [312, 323)", "Q4: [133, 182)", "Q24: [603, 638)", + "Q20: [534, 568)", "Q25: [638, 640)", "Q20: [534, 568)", "Q23: [592, 603)", + "Q12: [373, 385)", "Q27: [648, 748)", "Q29: [772, 835)", "Q6: [206, 241)", + "Q34: [933, 947)", "Q16: [430, 442)", "Q26: [640, 648)", "Q15: [406, 430)", + "Q12: [373, 385)", "Q26: [640, 648)", "Q19: [496, 534)", "Q4: [133, 182)", + "Q34: [933, 947)", "Q31: [869, 887)", "Q10: [323, 354)", "Q21: [568, 589)", + "Q4: [133, 182)", "Q7: [241, 280)", "Q35: [947, 965)", "Q14: [387, 406)", + "Q18: [474, 496)", "Q34: [933, 947)", "Q20: [534, 568)", "Q22: [589, 592)", + "Q33: [923, 933)", "Q10: [323, 354)", "Q13: [385, 387)", "Q1: [23, 76)", + "Q36: [965, 980]", "Q2: [76, 101)", "Q23: [592, 603)", "Q6: [206, 241)", + "Q1: [23, 76)", "Q10: [323, 354)", "Q4: [133, 182)", "Q8: [280, 312)"] + + @test cut([0, 1, 1, 1, 1], 2) == + ["Q1: [0, 1)", "Q2: [1, 1]", "Q2: [1, 1]", "Q2: [1, 1]", "Q2: [1, 1]"] + + @test cut([1, 1, 1, 1, 2], 2) == + ["Q1: [1, 2)", "Q1: [1, 2)", "Q1: [1, 2)", "Q1: [1, 2)", "Q2: [2, 2]"] +end + end \ No newline at end of file From cac54b298296d1cbadccb5c4a96838bbb739fbb3 Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Sun, 27 Apr 2025 23:06:39 +0200 Subject: [PATCH 04/12] Yet another approach --- src/extras.jl | 15 +++++----- test/15_extras.jl | 73 ++++++++++++++--------------------------------- 2 files changed, 30 insertions(+), 58 deletions(-) diff --git a/src/extras.jl b/src/extras.jl index f3bc17ca..1c5f3596 100644 --- a/src/extras.jl +++ b/src/extras.jl @@ -233,7 +233,10 @@ Provide the default label format for the `cut(x, ngroups)` method. quantile_formatter(from, to, i; leftclosed, rightclosed) = string("Q", i, ": ", leftclosed ? "[" : "(", from, ", ", to, rightclosed ? "]" : ")") -"""Find first value in data which is greater than each quantile in ``qs``.""" +""" +Find first value in (sorted) `v` which is greater than or equal to each quantile +in (sorted) `qs`. +""" function find_breaks(v::AbstractVector, qs::AbstractVector) n = length(qs) breaks = similar(v, n) @@ -242,7 +245,8 @@ function find_breaks(v::AbstractVector, qs::AbstractVector) i = 1 q = qs[1] @inbounds for x in v - if x > q + # Use isless and isequal to differentiate -0.0 from 0.0 + if isless(q, x) || isequal(q, x) breaks[i] = x i += 1 i > n && break @@ -253,7 +257,6 @@ function find_breaks(v::AbstractVector, qs::AbstractVector) for i in i:n breaks[i] = q end - @show breaks return breaks end @@ -264,10 +267,8 @@ end Cut a numeric array into `ngroups` quantiles. -Cutpoints differ from those returned by `Statistics.quantile` as they are suited -for intervals closed on the left and taken from actual values in `x`. However, -group assignments are identical to those which would be obtained with type 1 -quantiles if intervals were closed on the right. +This is equivalent to `cut(x, quantile(x, (0:ngroups)/ngroups))`, +but breaks are taken from actual data values instead of estimated quantiles. If `x` contains `missing` values, they are automatically skipped when computing quantiles. diff --git a/test/15_extras.jl b/test/15_extras.jl index 859cd349..b995e3bf 100644 --- a/test/15_extras.jl +++ b/test/15_extras.jl @@ -254,12 +254,21 @@ end fmt = (from, to, i; leftclosed, rightclosed) -> (i % 2 == 0 ? to : 0.0) @test_throws ArgumentError cut(1:8, 0:2:10, labels=fmt) - x = cut([fill(1, 10); 4], 2) - @test x == [fill("Q1: [1, 4)", 10); "Q2: [4, 4]"] - @test levels(x) == ["Q1: [1, 4)", "Q2: [4, 4]"] + @test_throws ArgumentError cut([fill(1, 10); 4], 2) + x = cut([fill(1, 10); 4], 2, allowempty=true) + @test unique(x) == ["Q2: [1, 4]"] + @test levels(x) == ["Q1: (1, 1)", "Q2: [1, 4]"] @test_throws ArgumentError cut([fill(1, 10); 4], 3) x = cut([fill(1, 10); 4], 3, allowempty=true) - @test x == [fill("Q1: [1, 4)", 10); "Q3: [4, 4]"] + @test unique(x) == ["Q3: [1, 4]"] + @test levels(x) == ["Q1: (1, 1)", "Q2: (1, 1)", "Q3: [1, 4]"] + + x = cut([fill(4, 10); 1], 2) + @test x == [fill("Q2: [4, 4]", 10); "Q1: [1, 4)"] + @test levels(x) == ["Q1: [1, 4)"; "Q2: [4, 4]"] + @test_throws ArgumentError cut([fill(4, 10); 1], 3) + x = cut([fill(4, 10); 1], 3, allowempty=true) + @test x == [fill("Q3: [4, 4]", 10); "Q1: [1, 4)"] @test levels(x) == ["Q1: [1, 4)", "Q2: (4, 4)", "Q3: [4, 4]"] x = cut([fill(1, 5); fill(4, 5)], 2) @@ -267,8 +276,8 @@ end @test levels(x) == ["Q1: [1, 4)", "Q2: [4, 4]"] @test_throws ArgumentError cut([fill(1, 5); fill(4, 5)], 3) x = cut([fill(1, 5); fill(4, 5)], 3, allowempty=true) - @test x == [fill("Q1: [1, 4)", 5); fill("Q3: [4, 4]", 5)] - @test levels(x) == ["Q1: [1, 4)", "Q2: (4, 4)", "Q3: [4, 4]"] + @test x == [fill("Q2: [1, 4)", 5); fill("Q3: [4, 4]", 5)] + @test levels(x) == ["Q1: (1, 1)", "Q2: [1, 4)", "Q3: [4, 4]"] end @testset "cut with -0.0" begin @@ -361,51 +370,13 @@ end @test levels(x) == ["Q1: [-Inf, 3.0)", "Q2: [3.0, 5.0]"] end -@testset "cut in corner cases" begin - # In this case, cut(x, quantile(x, (0:36)/36)) in R generates - # an empty "(143,172]" level - # and qcut(x, 36) in Polars misses that level. - # Our approach uses different breaks at 143 and 182 - x = [23, 23, 60, 76, 84, 95, 101, 108, 111, 133, 137, 143, 143, 143, 182, - 206, 214, 241, 258, 262, 280, 289, 303, 312, 321, 323, 352, 353, 354, - 368, 369, 373, 374, 384, 385, 386, 387, 392, 405, 406, 410, 421, 430, - 430, 431, 442, 464, 474, 478, 479, 496, 516, 530, 534, 549, 554, 568, - 575, 589, 590, 591, 592, 595, 596, 603, 625, 632, 632, 638, 640, 640, - 645, 648, 690, 704, 748, 758, 771, 772, 803, 835, 839, 853, 869, 873, - 874, 887, 911, 920, 923, 928, 933, 943, 945, 945, 947, 951, 965, 978, 980] - - @test cut(x, 36) == - ["Q17: [442, 474)", "Q8: [280, 312)", "Q2: [76, 101)", "Q9: [312, 323)", - "Q14: [387, 406)", "Q30: [835, 869)", "Q17: [442, 474)", "Q35: [947, 965)", - "Q2: [76, 101)", "Q11: [354, 373)", "Q32: [887, 923)", "Q12: [373, 385)", - "Q24: [603, 638)", "Q29: [772, 835)", "Q24: [603, 638)", "Q15: [406, 430)", - "Q11: [354, 373)", "Q23: [592, 603)", "Q3: [101, 133)", "Q16: [430, 442)", - "Q34: [933, 947)", "Q27: [648, 748)", "Q28: [748, 772)", "Q28: [748, 772)", - "Q21: [568, 589)", "Q18: [474, 496)", "Q32: [887, 923)", "Q11: [354, 373)", - "Q7: [241, 280)", "Q3: [101, 133)", "Q19: [496, 534)", "Q13: [385, 387)", - "Q36: [965, 980]", "Q33: [923, 933)", "Q16: [430, 442)", "Q36: [965, 980]", - "Q27: [648, 748)", "Q24: [603, 638)", "Q32: [887, 923)", "Q4: [133, 182)", - "Q22: [589, 592)", "Q1: [23, 76)", "Q5: [182, 206)", "Q28: [748, 772)", - "Q30: [835, 869)", "Q31: [869, 887)", "Q22: [589, 592)", "Q15: [406, 430)", - "Q31: [869, 887)", "Q19: [496, 534)", "Q26: [640, 648)", "Q8: [280, 312)", - "Q18: [474, 496)", "Q14: [387, 406)", "Q7: [241, 280)", "Q30: [835, 869)", - "Q3: [101, 133)", "Q9: [312, 323)", "Q4: [133, 182)", "Q24: [603, 638)", - "Q20: [534, 568)", "Q25: [638, 640)", "Q20: [534, 568)", "Q23: [592, 603)", - "Q12: [373, 385)", "Q27: [648, 748)", "Q29: [772, 835)", "Q6: [206, 241)", - "Q34: [933, 947)", "Q16: [430, 442)", "Q26: [640, 648)", "Q15: [406, 430)", - "Q12: [373, 385)", "Q26: [640, 648)", "Q19: [496, 534)", "Q4: [133, 182)", - "Q34: [933, 947)", "Q31: [869, 887)", "Q10: [323, 354)", "Q21: [568, 589)", - "Q4: [133, 182)", "Q7: [241, 280)", "Q35: [947, 965)", "Q14: [387, 406)", - "Q18: [474, 496)", "Q34: [933, 947)", "Q20: [534, 568)", "Q22: [589, 592)", - "Q33: [923, 933)", "Q10: [323, 354)", "Q13: [385, 387)", "Q1: [23, 76)", - "Q36: [965, 980]", "Q2: [76, 101)", "Q23: [592, 603)", "Q6: [206, 241)", - "Q1: [23, 76)", "Q10: [323, 354)", "Q4: [133, 182)", "Q8: [280, 312)"] - - @test cut([0, 1, 1, 1, 1], 2) == - ["Q1: [0, 1)", "Q2: [1, 1]", "Q2: [1, 1]", "Q2: [1, 1]", "Q2: [1, 1]"] - - @test cut([1, 1, 1, 1, 2], 2) == - ["Q1: [1, 2)", "Q1: [1, 2)", "Q1: [1, 2)", "Q1: [1, 2)", "Q2: [2, 2]"] +@testset "cut when quantile falls exactly on a data value" begin + x = cut([11, 14, 43, 54, 54, 56, 73, 79, 84, 84], 3) + @test x == + ["Q1: [11, 54)", "Q1: [11, 54)", "Q1: [11, 54)", + "Q2: [54, 73)", "Q2: [54, 73)", "Q2: [54, 73)", + "Q3: [73, 84]", "Q3: [73, 84]", "Q3: [73, 84]", "Q3: [73, 84]"] + @test levels(x) == ["Q1: [11, 54)", "Q2: [54, 73)", "Q3: [73, 84]"] end end \ No newline at end of file From 364651fcf3cabb407c72ace9cd2a03bcfd9af9ee Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Sun, 27 Apr 2025 23:12:40 +0200 Subject: [PATCH 05/12] Small cleanup --- src/extras.jl | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/extras.jl b/src/extras.jl index 1c5f3596..3f901d63 100644 --- a/src/extras.jl +++ b/src/extras.jl @@ -253,10 +253,6 @@ function find_breaks(v::AbstractVector, qs::AbstractVector) q = qs[i] end end - # if last values in x are equal to q, breaks were not initialized - for i in i:n - breaks[i] = q - end return breaks end @@ -295,9 +291,7 @@ function cut(x::AbstractArray, ngroups::Integer; throw(ArgumentError("NaN values are not allowed in input vector")) end qs = quantile!(sorted_x, (1:(ngroups-1))/ngroups, sorted=true) - @show qs, min_x, max_x breaks = [min_x; find_breaks(sorted_x, qs); max_x] - @show breaks if !allowempty && !allunique(@view breaks[1:end-1]) throw(ArgumentError("cannot compute $ngroups quantiles due to " * "too many duplicated values in `x`. " * From bf0b31041f21e83377c3399ec5e5fec7d6859966 Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Sun, 27 Apr 2025 23:35:42 +0200 Subject: [PATCH 06/12] Fix doctests --- src/extras.jl | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/extras.jl b/src/extras.jl index 3f901d63..11207326 100644 --- a/src/extras.jl +++ b/src/extras.jl @@ -77,25 +77,25 @@ julia> cut(-1:0.5:1, [0, 1], extend=true) julia> cut(-1:0.5:1, 2) 5-element CategoricalArray{String,1,UInt32}: - "Q1: [-1.0, 0.5)" - "Q1: [-1.0, 0.5)" - "Q1: [-1.0, 0.5)" - "Q2: [0.5, 1.0]" - "Q2: [0.5, 1.0]" + "Q1: [-1.0, 0.0)" + "Q1: [-1.0, 0.0)" + "Q2: [0.0, 1.0]" + "Q2: [0.0, 1.0]" + "Q2: [0.0, 1.0]" julia> cut(-1:0.5:1, 2, labels=["A", "B"]) 5-element CategoricalArray{String,1,UInt32}: "A" "A" - "A" "B" - "B" + "B" + "B" julia> cut(-1:0.5:1, 2, labels=[-0.5, +0.5]) 5-element CategoricalArray{Float64,1,UInt32}: -0.5 -0.5 - -0.5 + 0.5 0.5 0.5 @@ -106,9 +106,9 @@ julia> cut(-1:0.5:1, 3, labels=fmt) 5-element CategoricalArray{String,1,UInt32}: "grp 1 (-1.0//0.0)" "grp 1 (-1.0//0.0)" - "grp 2 (0.0//1.0)" - "grp 2 (0.0//1.0)" - "grp 3 (1.0//1.0)" + "grp 2 (0.0//0.5)" + "grp 3 (0.5//1.0)" + "grp 3 (0.5//1.0)" ``` """ @inline function cut(x::AbstractArray, breaks::AbstractVector; From daaa0cce7b1b1cac3b67b079da2402b89b50661c Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Sun, 27 Apr 2025 23:37:49 +0200 Subject: [PATCH 07/12] Indentation --- test/15_extras.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/15_extras.jl b/test/15_extras.jl index b995e3bf..af4f79f5 100644 --- a/test/15_extras.jl +++ b/test/15_extras.jl @@ -373,9 +373,9 @@ end @testset "cut when quantile falls exactly on a data value" begin x = cut([11, 14, 43, 54, 54, 56, 73, 79, 84, 84], 3) @test x == - ["Q1: [11, 54)", "Q1: [11, 54)", "Q1: [11, 54)", - "Q2: [54, 73)", "Q2: [54, 73)", "Q2: [54, 73)", - "Q3: [73, 84]", "Q3: [73, 84]", "Q3: [73, 84]", "Q3: [73, 84]"] + ["Q1: [11, 54)", "Q1: [11, 54)", "Q1: [11, 54)", + "Q2: [54, 73)", "Q2: [54, 73)", "Q2: [54, 73)", + "Q3: [73, 84]", "Q3: [73, 84]", "Q3: [73, 84]", "Q3: [73, 84]"] @test levels(x) == ["Q1: [11, 54)", "Q2: [54, 73)", "Q3: [73, 84]"] end From 062efb7abd1d0bb137bf1f49cb5c07e1aea53ca5 Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Wed, 7 May 2025 22:40:57 +0200 Subject: [PATCH 08/12] Simplify default `cut` labels 1) The quantile number isn't needed in most cases in the label, and anyway it's shown when printing an ordered `CategoricalValue`. Only use it by default when `allowempty=true` to avoid data-dependent errors if there are duplicate levels. 2) Round breaks by default to a number of significant digits chosen by `sigdigits`. This number is increased if necessary for breaks to remain unique. This generates labels which are not completely correct as rounding may make the left break greater than a value which is included in the interval, but this is generally minor and expected. Taking the floor rather than rounding would be more correct, but it can generate unexpected labels due to floating point trickiness (e.g. `floor(0.0003, sigdigits=4)` gives 0.0002999). This is what R does. Add a deprecation to avoid breaking custom `labels` functions which did not accept `sigdigits`. --- Project.toml | 2 +- src/CategoricalArrays.jl | 4 +- src/extras.jl | 201 +++++++++++++++++++++++++++++++-------- test/15_extras.jl | 148 ++++++++++++++++++---------- 4 files changed, 261 insertions(+), 94 deletions(-) diff --git a/Project.toml b/Project.toml index 4593b00b..a83d5d02 100644 --- a/Project.toml +++ b/Project.toml @@ -25,7 +25,7 @@ CategoricalArraysSentinelArraysExt = "SentinelArrays" CategoricalArraysStructTypesExt = "StructTypes" [compat] -Compat = "3.37, 4" +Compat = "3.47, 4.10" DataAPI = "1.6" JSON = "0.15, 0.16, 0.17, 0.18, 0.19, 0.20, 0.21" JSON3 = "1.1.2" diff --git a/src/CategoricalArrays.jl b/src/CategoricalArrays.jl index a28cba94..e8865c76 100644 --- a/src/CategoricalArrays.jl +++ b/src/CategoricalArrays.jl @@ -11,10 +11,12 @@ module CategoricalArrays import DataAPI: unwrap export unwrap + using Compat + @compat public default_formatter, numbered_formatter + using DataAPI using Missings using Printf - import Compat # JuliaLang/julia#36810 if VERSION < v"1.5.2" diff --git a/src/extras.jl b/src/extras.jl index 11207326..7690c666 100644 --- a/src/extras.jl +++ b/src/extras.jl @@ -27,17 +27,56 @@ function fill_refs!(refs::AbstractArray, X::AbstractArray, end end +const CUT_FMT = Printf.Format("%.*g") + +""" + CategoricalArrays.default_formatter(from, to, i::Integer; + leftclosed::Bool, rightclosed::Bool, + sigdigits::Integer) + +Provide the default label format for the `cut(x, breaks)` method, +which is `"[from, to)"` if `leftclosed` is `true` and `"[from, to)"` otherwise. + +If they are floating points values, breaks are turned into to strings using +`@sprintf("%.*g", sigdigits, break)` +(or `to` using `@sprintf("%.*g", sigdigits, break)` for the last break). """ - default_formatter(from, to, i; leftclosed, rightclosed) +function default_formatter(from, to, i::Integer; + leftclosed::Bool, rightclosed::Bool, + sigdigits::Integer) + from_str = from isa AbstractFloat ? + Printf.format(CUT_FMT, sigdigits, from) : + string(from) + to_str = to isa AbstractFloat ? + Printf.format(CUT_FMT, sigdigits, to) : + string(to) + string(leftclosed ? "[" : "(", from_str, ", ", to_str, rightclosed ? "]" : ")") +end -Provide the default label format for the `cut(x, breaks)` method. """ -default_formatter(from, to, i; leftclosed, rightclosed) = - string(leftclosed ? "[" : "(", from, ", ", to, rightclosed ? "]" : ")") + CategoricalArrays.numbered_formatter(from, to, i::Integer; + leftclosed::Bool, rightclosed::Bool, + sigdigits::Integer) + +Provide the default label format for the `cut(x, ngroups)` method +when `allowempty=true`, which is `"i: [from, to)"` if `leftclosed` +is `true` and `"i: [from, to)"` otherwise. + +If they are floating points values, breaks are turned into to strings using +`@sprintf("%.*g", sigdigits, breaks)` +(or `to` using `@sprintf("%.*g", sigdigits, break)` for the last break). +""" +numbered_formatter(from, to, i::Integer; + leftclosed::Bool, rightclosed::Bool, + sigdigits::Integer) = + string(i, ": ", + default_formatter(from, to, i, leftclosed=leftclosed, rightclosed=rightclosed, + sigdigits=sigdigits)) @doc raw""" cut(x::AbstractArray, breaks::AbstractVector; labels::Union{AbstractVector,Function}, + sigdigits::Integer=3, extend::Union{Bool,Missing}=false, allowempty::Bool=false) Cut a numeric array into intervals at values `breaks` @@ -54,10 +93,15 @@ also accept them. in `x` fall outside of the breaks; when `true`, breaks are automatically added to include all values in `x`; when `missing`, values outside of the breaks generate `missing` entries. * `labels::Union{AbstractVector, Function}`: a vector of strings, characters - or numbers giving the names to use for - the intervals; or a function `f(from, to, i; leftclosed, rightclosed)` that generates + or numbers giving the names to use for the intervals; or a function + `f(from, to, i::Integer; leftclosed::Bool, rightclosed::Bool, sigdigits::Integer)` that generates the labels from the left and right interval boundaries and the group index. Defaults to - `"[from, to)"` (or `"[from, to]"` for the rightmost interval if `extend == true`). + [`CategoricalArrays.default_formatter`](@ref), giving `"[from, to)"` (or `"[from, to]"` + for the rightmost interval if `extend == true`). +* `sigdigits::Integer=3`: the minimum number of significant digits to use in labels. + This value is increased automatically if necessary so that rounded breaks are unique. + Only used for floating point types and when `labels` is a function, in which case it + is passed to it as a keyword argument. * `allowempty::Bool=false`: when `false`, an error is raised if some breaks other than the last one appear multiple times, generating empty intervals; when `true`, duplicate breaks are allowed and the intervals they generate are kept as @@ -69,19 +113,19 @@ julia> using CategoricalArrays julia> cut(-1:0.5:1, [0, 1], extend=true) 5-element CategoricalArray{String,1,UInt32}: - "[-1.0, 0.0)" - "[-1.0, 0.0)" - "[0.0, 1.0]" - "[0.0, 1.0]" - "[0.0, 1.0]" + "[-1, 0)" + "[-1, 0)" + "[0, 1]" + "[0, 1]" + "[0, 1]" julia> cut(-1:0.5:1, 2) 5-element CategoricalArray{String,1,UInt32}: - "Q1: [-1.0, 0.0)" - "Q1: [-1.0, 0.0)" - "Q2: [0.0, 1.0]" - "Q2: [0.0, 1.0]" - "Q2: [0.0, 1.0]" + "[-1, 0)" + "[-1, 0)" + "[0, 1]" + "[0, 1]" + "[0, 1]" julia> cut(-1:0.5:1, 2, labels=["A", "B"]) 5-element CategoricalArray{String,1,UInt32}: @@ -114,6 +158,7 @@ julia> cut(-1:0.5:1, 3, labels=fmt) @inline function cut(x::AbstractArray, breaks::AbstractVector; extend::Union{Bool, Missing}=false, labels::Union{AbstractVector{<:SupportedTypes},Function}=default_formatter, + sigdigits::Integer=3, allowmissing::Union{Bool, Nothing}=nothing, allow_missing::Union{Bool, Nothing}=nothing, allowempty::Bool=false) @@ -127,14 +172,15 @@ julia> cut(-1:0.5:1, 3, labels=fmt) :cut) extend = missing end - return _cut(x, breaks, extend, labels, allowempty) + return _cut(x, breaks, extend, labels, sigdigits, allowempty) end # Separate function for inferability (thanks to inlining of cut) function _cut(x::AbstractArray{T, N}, breaks::AbstractVector, extend::Union{Bool, Missing}, labels::Union{AbstractVector{<:SupportedTypes},Function}, - allowempty::Bool=false) where {T, N} + sigdigits::Integer, + allowempty::Bool) where {T, N} if !issorted(breaks) breaks = sort(breaks) end @@ -191,21 +237,55 @@ function _cut(x::AbstractArray{T, N}, breaks::AbstractVector, end end + # Find minimal number of digits so that distinct breaks remain so + if eltype(breaks) <: AbstractFloat + while true + local i + for outer i in 2:lastindex(breaks) + b1 = breaks[i-1] + b2 = breaks[i] + isequal(b1, b2) && continue + + b1_str = Printf.format(CUT_FMT, sigdigits, b1) + b2_str = Printf.format(CUT_FMT, sigdigits, b2) + if b1_str == b2_str + sigdigits += 1 + break + end + end + i == lastindex(breaks) && break + end + end n = length(breaks) n >= 2 || throw(ArgumentError("at least two breaks must be provided when extend is not true")) if labels isa Function from = breaks[1:n-1] to = breaks[2:n] - firstlevel = labels(from[1], to[1], 1, - leftclosed=!isequal(breaks[1], breaks[2]), rightclosed=false) + local firstlevel + try + firstlevel = labels(from[1], to[1], 1, + leftclosed=!isequal(breaks[1], breaks[2]), rightclosed=false, + sigdigits=sigdigits) + catch + # Support functions defined before v1.0, where sigdigits did not exist + Base.depwarn("`labels` function is now required to accept a `sigdigits` keyword argument", + :cut) + labels_orig = labels + labels = (from, to, i; leftclosed, rightclosed, sigdigits) -> + labels_orig(from, to, i; leftclosed, rightclosed) + firstlevel = labels_orig(from[1], to[1], 1, + leftclosed=!isequal(breaks[1], breaks[2]), rightclosed=false) + end levs = Vector{typeof(firstlevel)}(undef, n-1) levs[1] = firstlevel for i in 2:n-2 levs[i] = labels(from[i], to[i], i, - leftclosed=!isequal(breaks[i], breaks[i+1]), rightclosed=false) + leftclosed=!isequal(breaks[i], breaks[i+1]), rightclosed=false, + sigdigits=sigdigits) end levs[end] = labels(from[end], to[end], n-1, - leftclosed=true, rightclosed=true) + leftclosed=true, rightclosed=true, + sigdigits=sigdigits) else length(labels) == n-1 || throw(ArgumentError("labels must be of length $(n-1), but got length $(length(labels))")) @@ -225,14 +305,6 @@ function _cut(x::AbstractArray{T, N}, breaks::AbstractVector, CategoricalArray{S, N}(refs, pool) end -""" - quantile_formatter(from, to, i; leftclosed, rightclosed) - -Provide the default label format for the `cut(x, ngroups)` method. -""" -quantile_formatter(from, to, i; leftclosed, rightclosed) = - string("Q", i, ": ", leftclosed ? "[" : "(", from, ", ", to, rightclosed ? "]" : ")") - """ Find first value in (sorted) `v` which is greater than or equal to each quantile in (sorted) `qs`. @@ -240,25 +312,30 @@ in (sorted) `qs`. function find_breaks(v::AbstractVector, qs::AbstractVector) n = length(qs) breaks = similar(v, n) - n == 0 && return breaks + breaks_prev = similar(v, n) + n == 0 && return (breaks, breaks_prev) i = 1 q = qs[1] - @inbounds for x in v + @inbounds for j in eachindex(v) + x = v[j] # Use isless and isequal to differentiate -0.0 from 0.0 if isless(q, x) || isequal(q, x) breaks[i] = x + # FIXME : handle duplicated breaks + breaks_prev[i] = v[clamp(j-1, firstindex(v), lastindex(v))] i += 1 i > n && break q = qs[i] end end - return breaks + return (breaks, breaks_prev) end """ cut(x::AbstractArray, ngroups::Integer; labels::Union{AbstractVector{<:AbstractString},Function}, + sigdigits::Integer=3, allowempty::Bool=false) Cut a numeric array into `ngroups` quantiles. @@ -271,17 +348,25 @@ quantiles. # Keyword arguments * `labels::Union{AbstractVector, Function}`: a vector of strings, characters - or numbers giving the names to use for - the intervals; or a function `f(from, to, i; leftclosed, rightclosed)` that generates + or numbers giving the names to use for the intervals; or a function + `f(from, to, i::Integer; leftclosed::Bool, rightclosed::Bool, sigdigits::Integer)` that generates the labels from the left and right interval boundaries and the group index. Defaults to - `"Qi: [from, to)"` (or `"Qi: [from, to]"` for the rightmost interval). + [`CategoricalArrays.default_formatter`](@ref), giving `"[from, to)"` (or `"[from, to]"` + for the rightmost interval if `extend == true`) if `allowempty=false`, otherwise to + [`CategoricalArrays.numbered_formatter`](@ref), which prefixes the label with the quantile + number to ensure uniqueness. +* `sigdigits::Integer=3`: the minimum number of significant digits to use when rounding + breaks for inclusion in generated labels. This value is increased automatically if necessary + so that rounded breaks are unique. Only used for floating point types and when `labels` is a + function, in which case it is passed to it as a keyword argument. * `allowempty::Bool=false`: when `false`, an error is raised if some quantiles breakpoints other than the last one are equal, generating empty intervals; when `true`, duplicate breaks are allowed and the intervals they generate are kept as unused levels (but duplicate labels are not allowed). """ function cut(x::AbstractArray, ngroups::Integer; - labels::Union{AbstractVector{<:SupportedTypes},Function}=quantile_formatter, + labels::Union{AbstractVector{<:SupportedTypes},Function,Nothing}=nothing, + sigdigits::Integer=3, allowempty::Bool=false) ngroups >= 1 || throw(ArgumentError("ngroups must be strictly positive (got $ngroups)")) sorted_x = eltype(x) >: Missing ? sort!(collect(skipmissing(x))) : sort(x) @@ -291,12 +376,48 @@ function cut(x::AbstractArray, ngroups::Integer; throw(ArgumentError("NaN values are not allowed in input vector")) end qs = quantile!(sorted_x, (1:(ngroups-1))/ngroups, sorted=true) - breaks = [min_x; find_breaks(sorted_x, qs); max_x] + breaks, breaks_prev = find_breaks(sorted_x, qs) + breaks = [min_x; breaks; max_x] if !allowempty && !allunique(@view breaks[1:end-1]) throw(ArgumentError("cannot compute $ngroups quantiles due to " * "too many duplicated values in `x`. " * "Pass `allowempty=true` to allow empty quantiles or " * "choose a lower value for `ngroups`.")) end - cut(x, breaks; labels=labels, allowempty=allowempty) + if labels === nothing + labels = allowempty ? numbered_formatter : default_formatter + + if eltype(breaks) <: AbstractFloat + while true + local i + for outer i in 2:lastindex(breaks) + b1 = breaks[i-1] + b2 = breaks[i] + isequal(b1, b2) && continue + + # Find minimal number of digits so that `floor` does not + # return a value that is lower than value immediately below break + # We skip the first break, which is the minimum and has no equivalent + # in `breaks_prev` + b1_rounded = round(b1, sigdigits=sigdigits) + b2_rounded = round(b2, sigdigits=sigdigits) + if i < lastindex(breaks) && + (isequal(b2_rounded, breaks_prev[i-1]) || isless(b2_rounded, breaks_prev[i-1])) + sigdigits += 1 + break + end + + # Find minimal number of digits so that breaks are unique + b1_str = Printf.format(CUT_FMT, sigdigits, b1) + b2_str = Printf.format(CUT_FMT, sigdigits, b2) + if b1_str == b2_str + sigdigits += 1 + break + end + end + i == lastindex(breaks) && break + end + end + end + return cut(x, breaks; labels=labels, sigdigits=sigdigits, allowempty=allowempty) end diff --git a/test/15_extras.jl b/test/15_extras.jl index af4f79f5..c45ee5ab 100644 --- a/test/15_extras.jl +++ b/test/15_extras.jl @@ -93,10 +93,10 @@ const ≅ = isequal @test levels(x) == ["b", "a"] x = @inferred cut(Matrix{Union{Float64, T}}([-1.1 3.0; 1.456 10.394]), [-2.134, 3.0, 12.5]) - @test x == ["[-2.134, 3.0)" "[3.0, 12.5]"; "[-2.134, 3.0)" "[3.0, 12.5]"] + @test x == ["[-2.1, 3)" "[3, 12]"; "[-2.1, 3)" "[3, 12]"] @test isa(x, CategoricalMatrix{Union{String, T}}) @test isordered(x) - @test levels(x) == ["[-2.134, 3.0)", "[3.0, 12.5]"] + @test levels(x) == ["[-2.1, 3)", "[3, 12]"] labels = 0:2:8 x = @inferred cut(Vector{Union{T, Int}}(1:8), 0:2:10, labels=labels) @@ -127,18 +127,18 @@ end @testset "cut([5, 4, 3, 2], 2)" begin x = @inferred cut([5, 4, 3, 2], 2) - @test x == ["Q2: [4, 5]", "Q2: [4, 5]", "Q1: [2, 4)", "Q1: [2, 4)"] + @test x == ["[4, 5]", "[4, 5]", "[2, 4)", "[2, 4)"] @test isa(x, CategoricalArray) @test isordered(x) - @test levels(x) == ["Q1: [2, 4)", "Q2: [4, 5]"] + @test levels(x) == ["[2, 4)", "[4, 5]"] end @testset "cut(x, n) with missing values" begin x = @inferred cut([5, 4, 3, missing, 2], 2) - @test x ≅ ["Q2: [4, 5]", "Q2: [4, 5]", "Q1: [2, 4)", missing, "Q1: [2, 4)"] + @test x ≅ ["[4, 5]", "[4, 5]", "[2, 4)", missing, "[2, 4)"] @test isa(x, CategoricalArray) @test isordered(x) - @test levels(x) == ["Q1: [2, 4)", "Q2: [4, 5]"] + @test levels(x) == ["[2, 4)", "[4, 5]"] end @testset "cut(x, n) with invalid n" begin @@ -147,7 +147,7 @@ end end @testset "cut with formatter function" begin - my_formatter(from, to, i; leftclosed, rightclosed) = "$i: $from -- $to" + my_formatter(from, to, i; leftclosed, rightclosed, sigdigits) = "$i: $from -- $to" x = 0.15:0.20:0.95 p = [0, 0.4, 0.8, 1.0] @@ -155,20 +155,24 @@ end a = @inferred cut(x, p, labels=my_formatter) @test a == ["1: 0.0 -- 0.4", "1: 0.0 -- 0.4", "2: 0.4 -- 0.8", "2: 0.4 -- 0.8", "3: 0.8 -- 1.0"] + my_old_formatter(from, to, i; leftclosed, rightclosed) = "$i: $from -- $to" + a = @test_deprecated r"`labels`.*" cut(x, p, labels=my_old_formatter) + @test a == ["1: 0.0 -- 0.4", "1: 0.0 -- 0.4", "2: 0.4 -- 0.8", "2: 0.4 -- 0.8", "3: 0.8 -- 1.0"] + # GH 274 - my_formatter_2(from, to, i; leftclosed, rightclosed) = "$i: $(from+1) -- $(to+1)" + my_formatter_2(from, to, i; leftclosed, rightclosed, sigdigits) = "$i: $(from+1) -- $(to+1)" a = @inferred cut(x, p, labels=my_formatter_2) @test a == ["1: 1.0 -- 1.4", "1: 1.0 -- 1.4", "2: 1.4 -- 1.8", "2: 1.4 -- 1.8", "3: 1.8 -- 2.0"] for T in (Union{}, Missing) - labels = (from, to, i; leftclosed, rightclosed) -> (to+from)/2 + labels = (from, to, i; leftclosed, rightclosed, sigdigits) -> (to+from)/2 a = @inferred cut(Vector{Union{T, Int}}(1:8), 0:2:10, labels=labels) @test a == [1.0, 3.0, 3.0, 5.0, 5.0, 7.0, 7.0, 9.0] @test isa(a, CategoricalVector{Union{Float64, T}}) @test isordered(a) @test levels(a) == [1.0, 3.0, 5.0, 7.0, 9.0] - labels = (from, to, i; leftclosed, rightclosed) -> "$((to+from)/2)" + labels = (from, to, i; leftclosed, rightclosed, sigdigits) -> "$((to+from)/2)" a = @inferred cut(Vector{Union{T, Int}}(1:8), 0:2:10, labels=labels) @test a == string.([1.0, 3.0, 3.0, 5.0, 5.0, 7.0, 7.0, 9.0]) @test isa(a, CategoricalVector{Union{String, T}}) @@ -188,8 +192,8 @@ end @test_throws ArgumentError cut(x, [0, 0.1, 0.1, 10]) @test_throws ArgumentError cut(x, 10) y = cut(x, [0, 0.1, 10, 10]) - @test y == [fill("[0.0, 0.1)", 10); fill("[0.1, 10.0)", 10)] - @test levels(y) == ["[0.0, 0.1)", "[0.1, 10.0)", "[10.0, 10.0]"] + @test y == [fill("[0, 0.1)", 10); fill("[0.1, 10)", 10)] + @test levels(y) == ["[0, 0.1)", "[0.1, 10)", "[10, 10]"] @test_throws ArgumentError cut(1:10, [1, 5, 5, 11]) y = cut(1:10, [1, 5, 5, 11], allowempty=true) @@ -251,55 +255,55 @@ end @test_throws ArgumentError cut(1:8, 0:2:10, labels=[0, 1, 1, 2, 3]) @test_throws ArgumentError cut(1:8, [0, 2, 2, 6, 8, 10], labels=[0, 1, 1, 2, 3], allowempty=true) - fmt = (from, to, i; leftclosed, rightclosed) -> (i % 2 == 0 ? to : 0.0) + fmt = (from, to, i; leftclosed, rightclosed, sigdigits) -> (i % 2 == 0 ? to : 0.0) @test_throws ArgumentError cut(1:8, 0:2:10, labels=fmt) @test_throws ArgumentError cut([fill(1, 10); 4], 2) x = cut([fill(1, 10); 4], 2, allowempty=true) - @test unique(x) == ["Q2: [1, 4]"] - @test levels(x) == ["Q1: (1, 1)", "Q2: [1, 4]"] + @test unique(x) == ["2: [1, 4]"] + @test levels(x) == ["1: (1, 1)", "2: [1, 4]"] @test_throws ArgumentError cut([fill(1, 10); 4], 3) x = cut([fill(1, 10); 4], 3, allowempty=true) - @test unique(x) == ["Q3: [1, 4]"] - @test levels(x) == ["Q1: (1, 1)", "Q2: (1, 1)", "Q3: [1, 4]"] + @test unique(x) == ["3: [1, 4]"] + @test levels(x) == ["1: (1, 1)", "2: (1, 1)", "3: [1, 4]"] x = cut([fill(4, 10); 1], 2) - @test x == [fill("Q2: [4, 4]", 10); "Q1: [1, 4)"] - @test levels(x) == ["Q1: [1, 4)"; "Q2: [4, 4]"] + @test x == [fill("[4, 4]", 10); "[1, 4)"] + @test levels(x) == ["[1, 4)"; "[4, 4]"] @test_throws ArgumentError cut([fill(4, 10); 1], 3) x = cut([fill(4, 10); 1], 3, allowempty=true) - @test x == [fill("Q3: [4, 4]", 10); "Q1: [1, 4)"] - @test levels(x) == ["Q1: [1, 4)", "Q2: (4, 4)", "Q3: [4, 4]"] + @test x == [fill("3: [4, 4]", 10); "1: [1, 4)"] + @test levels(x) == ["1: [1, 4)", "2: (4, 4)", "3: [4, 4]"] x = cut([fill(1, 5); fill(4, 5)], 2) - @test x == [fill("Q1: [1, 4)", 5); fill("Q2: [4, 4]", 5)] - @test levels(x) == ["Q1: [1, 4)", "Q2: [4, 4]"] + @test x == [fill("[1, 4)", 5); fill("[4, 4]", 5)] + @test levels(x) == ["[1, 4)", "[4, 4]"] @test_throws ArgumentError cut([fill(1, 5); fill(4, 5)], 3) x = cut([fill(1, 5); fill(4, 5)], 3, allowempty=true) - @test x == [fill("Q2: [1, 4)", 5); fill("Q3: [4, 4]", 5)] - @test levels(x) == ["Q1: (1, 1)", "Q2: [1, 4)", "Q3: [4, 4]"] + @test x == [fill("2: [1, 4)", 5); fill("3: [4, 4]", 5)] + @test levels(x) == ["1: (1, 1)", "2: [1, 4)", "3: [4, 4]"] end @testset "cut with -0.0" begin x = cut([-0.0, 0.0, 0.0, -0.0], 2) - @test x == ["Q1: [-0.0, 0.0)", "Q2: [0.0, 0.0]", "Q2: [0.0, 0.0]", "Q1: [-0.0, 0.0)"] - @test levels(x) == ["Q1: [-0.0, 0.0)", "Q2: [0.0, 0.0]"] + @test x == ["[-0, 0)", "[0, 0]", "[0, 0]", "[-0, 0)"] + @test levels(x) == ["[-0, 0)", "[0, 0]"] x = cut([-0.0, 0.0, 0.0, -0.0], [-0.0, 0.0, 0.0]) - @test x == ["[-0.0, 0.0)", "[0.0, 0.0]", "[0.0, 0.0]", "[-0.0, 0.0)"] - @test levels(x) == ["[-0.0, 0.0)", "[0.0, 0.0]"] + @test x == ["[-0, 0)", "[0, 0]", "[0, 0]", "[-0, 0)"] + @test levels(x) == ["[-0, 0)", "[0, 0]"] x = cut([-0.0, 0.0, 0.0, -0.0], [-0.0, 0.0]) - @test x == fill("[-0.0, 0.0]", 4) - @test levels(x) == ["[-0.0, 0.0]"] + @test x == fill("[-0, 0]", 4) + @test levels(x) == ["[-0, 0]"] x = cut([-0.0, 0.0, 0.0, -0.0], [0.0], extend=true) - @test x == fill("[-0.0, 0.0]", 4) - @test levels(x) == ["[-0.0, 0.0]"] + @test x == fill("[-0, 0]", 4) + @test levels(x) == ["[-0, 0]"] x = cut([-0.0, 0.0, 0.0, -0.0], [-0.0], extend=true) - @test x == fill("[-0.0, 0.0]", 4) - @test levels(x) == ["[-0.0, 0.0]"] + @test x == fill("[-0, 0]", 4) + @test levels(x) == ["[-0, 0]"] x = cut([-0.0, 0.0, 0.0, -0.0], 2, labels=[-0.0, 0.0]) @test x == [-0.0, 0.0, 0.0, -0.0] @@ -336,7 +340,7 @@ end @test levels(x) == [-0.0, 0.0] x = @inferred cut(-1:0.5:1, [0, 1], extend=true) - @test x == ["[-1.0, 0.0)", "[-1.0, 0.0)", "[0.0, 1.0]", "[0.0, 1.0]", "[0.0, 1.0]"] + @test x == ["[-1, 0)", "[-1, 0)", "[0, 1]", "[0, 1]", "[0, 1]"] end @testset "cut with NaN and Inf" begin @@ -346,37 +350,77 @@ end @test_throws ArgumentError("NaN values are not allowed in breaks") cut([1, 2], [1, NaN]) x = cut([1, Inf], [1], extend=true) - @test x ≅ ["[1.0, Inf]", "[1.0, Inf]"] - @test levels(x) == ["[1.0, Inf]"] + @test x ≅ ["[1, Inf]", "[1, Inf]"] + @test levels(x) == ["[1, Inf]"] x = cut([1, -Inf], [1], extend=true) - @test x ≅ ["[-Inf, 1.0]", "[-Inf, 1.0]"] - @test levels(x) == ["[-Inf, 1.0]"] + @test x ≅ ["[-Inf, 1]", "[-Inf, 1]"] + @test levels(x) == ["[-Inf, 1]"] x = cut([1:5; Inf], [1, 2, Inf]) - @test x ≅ ["[1.0, 2.0)"; fill("[2.0, Inf]", 5)] - @test levels(x) == ["[1.0, 2.0)", "[2.0, Inf]"] + @test x ≅ ["[1, 2)"; fill("[2, Inf]", 5)] + @test levels(x) == ["[1, 2)", "[2, Inf]"] x = cut([1:5; -Inf], [-Inf, 2, 5]) - @test x ≅ ["[-Inf, 2.0)"; fill("[2.0, 5.0]", 4); "[-Inf, 2.0)"] - @test levels(x) == ["[-Inf, 2.0)", "[2.0, 5.0]"] + @test x ≅ ["[-Inf, 2)"; fill("[2, 5]", 4); "[-Inf, 2)"] + @test levels(x) == ["[-Inf, 2)", "[2, 5]"] x = cut([1:5; Inf], 2) - @test x ≅ [fill("Q1: [1.0, 4.0)", 3); fill("Q2: [4.0, Inf]", 3)] - @test levels(x) == ["Q1: [1.0, 4.0)", "Q2: [4.0, Inf]"] + @test x ≅ [fill("[1, 4)", 3); fill("[4, Inf]", 3)] + @test levels(x) == ["[1, 4)", "[4, Inf]"] x = cut([1:5; -Inf], 2) - @test x ≅ [fill("Q1: [-Inf, 3.0)", 2); fill("Q2: [3.0, 5.0]", 3); "Q1: [-Inf, 3.0)"] - @test levels(x) == ["Q1: [-Inf, 3.0)", "Q2: [3.0, 5.0]"] + @test x ≅ [fill("[-Inf, 3)", 2); fill("[3, 5]", 3); "[-Inf, 3)"] + @test levels(x) == ["[-Inf, 3)", "[3, 5]"] end @testset "cut when quantile falls exactly on a data value" begin x = cut([11, 14, 43, 54, 54, 56, 73, 79, 84, 84], 3) @test x == - ["Q1: [11, 54)", "Q1: [11, 54)", "Q1: [11, 54)", - "Q2: [54, 73)", "Q2: [54, 73)", "Q2: [54, 73)", - "Q3: [73, 84]", "Q3: [73, 84]", "Q3: [73, 84]", "Q3: [73, 84]"] - @test levels(x) == ["Q1: [11, 54)", "Q2: [54, 73)", "Q3: [73, 84]"] + ["[11, 54)", "[11, 54)", "[11, 54)", + "[54, 73)", "[54, 73)", "[54, 73)", + "[73, 84]", "[73, 84]", "[73, 84]", "[73, 84]"] + @test levels(x) == ["[11, 54)", "[54, 73)", "[73, 84]"] +end + +@testset "cut computation of sigdigits" begin + x = cut([1.2, 1.3, 2], 2) + @test levels(x) == ["[1.2, 1.3)", "[1.3, 2]"] + + x = cut([1.0, 2.0, 3.0], 2) + @test levels(x) == ["[1, 2)", "[2, 3]"] + + x = cut([1.00002, 1.00003, 2], 2) + @test levels(x) == ["[1.00002, 1.00003)", "[1.00003, 2]"] + + x = cut([1.00002, 1.00003, 1.00005, 2], 2) + @test levels(x) == ["[1, 1.0001)", "[1.0001, 2]"] + + x = cut([1.00001, 1.00002, 1.00002, 2], 2) + @test levels(x) == ["[1.00001, 1.00002)", "[1.00002, 2]"] + + x = cut([1.00001, 1.00003, 1.1, 2], 2) + @test levels(x) == ["[1, 1.1)", "[1.1, 2]"] + + # @sprintf with %g uses scientific notation even in some cases + # where classic notation would be shorter + x = cut([1.0, 10.0, 100.0, 1000.0], [1.0, 10.0, 100.0, 1000.0]) + @test levels(x) == ["[1, 10)", "[10, 100)", "[100, 1e+03]"] + # But integers are rendered using plain `string` + x = cut([1, 10, 100], [1, 10, 100, 1000]) + @test levels(x) == ["[1, 10)", "[10, 100)", "[100, 1000]"] + + # Extreme case + x = cut([8.85718832925723e-7, 8.572446994052413e-7, 1.40217695121027e-7, 8.966449714804087e-7, + 3.070384341319470e-7, 3.070384341319471e-7, 1.8520709563325888e-7, 5.630461710066611e-7, + 6.781422109070843e-7, 4.776113711396994e-7, 0.2538909094146984, 0.5249665525921473, + 0.8321957380046366, 0.9648282851978118, 0.36084175275805797, 0.7851054639425253, + 0.6875195857202754, 0.614940093507575, 0.6224944997292978, 0.6055683461790675, + 5.349085340927365e11, 1.3471583229449602e11, 6.538893396835975e11, 4.826316844547661e11, + 8.803607035550856e11, 1.8174694671397316e10, 1.6709745443719125e11, 3.2050577954311835e11, + 1.6134999167460663e11, 7.396308745225059e11], 3) + @test levels(x) == ["[1.4e-07, 0.25)", "[0.25, 1.8e+10)", "[1.8e+10, 8.8e+11]"] + end end \ No newline at end of file From e1acb3829e62fe57ca263f61e2a8504374cc329e Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Sat, 17 May 2025 10:42:35 +0200 Subject: [PATCH 09/12] Simplify logic --- src/extras.jl | 55 +++++++++++++-------------------------------------- 1 file changed, 14 insertions(+), 41 deletions(-) diff --git a/src/extras.jl b/src/extras.jl index 7690c666..cbc5e14c 100644 --- a/src/extras.jl +++ b/src/extras.jl @@ -88,6 +88,11 @@ the last interval, which is closed on both ends, i.e. `[lower, upper]`. If `x` accepts missing values (i.e. `eltype(x) >: Missing`) the returned array will also accept them. +!!! note + For floating point data, breaks may be rounded to `sigdigits` significant digits + when generating interval labels, meaning that they may not reflect exactly the cutpoints + used. + # Keyword arguments * `extend::Union{Bool, Missing}=false`: when `false`, an error is raised if some values in `x` fall outside of the breaks; when `true`, breaks are automatically added to include @@ -312,24 +317,20 @@ in (sorted) `qs`. function find_breaks(v::AbstractVector, qs::AbstractVector) n = length(qs) breaks = similar(v, n) - breaks_prev = similar(v, n) - n == 0 && return (breaks, breaks_prev) + n == 0 && return breaks i = 1 q = qs[1] - @inbounds for j in eachindex(v) - x = v[j] + @inbounds for x in v # Use isless and isequal to differentiate -0.0 from 0.0 if isless(q, x) || isequal(q, x) breaks[i] = x - # FIXME : handle duplicated breaks - breaks_prev[i] = v[clamp(j-1, firstindex(v), lastindex(v))] i += 1 i > n && break q = qs[i] end end - return (breaks, breaks_prev) + return breaks end """ @@ -346,6 +347,11 @@ but breaks are taken from actual data values instead of estimated quantiles. If `x` contains `missing` values, they are automatically skipped when computing quantiles. +!!! note + For floating point data, breaks may be rounded to `sigdigits` significant digits + when generating interval labels, meaning that they may not reflect exactly the cutpoints + used. + # Keyword arguments * `labels::Union{AbstractVector, Function}`: a vector of strings, characters or numbers giving the names to use for the intervals; or a function @@ -376,8 +382,7 @@ function cut(x::AbstractArray, ngroups::Integer; throw(ArgumentError("NaN values are not allowed in input vector")) end qs = quantile!(sorted_x, (1:(ngroups-1))/ngroups, sorted=true) - breaks, breaks_prev = find_breaks(sorted_x, qs) - breaks = [min_x; breaks; max_x] + breaks = [min_x; find_breaks(sorted_x, qs); max_x] if !allowempty && !allunique(@view breaks[1:end-1]) throw(ArgumentError("cannot compute $ngroups quantiles due to " * "too many duplicated values in `x`. " * @@ -386,38 +391,6 @@ function cut(x::AbstractArray, ngroups::Integer; end if labels === nothing labels = allowempty ? numbered_formatter : default_formatter - - if eltype(breaks) <: AbstractFloat - while true - local i - for outer i in 2:lastindex(breaks) - b1 = breaks[i-1] - b2 = breaks[i] - isequal(b1, b2) && continue - - # Find minimal number of digits so that `floor` does not - # return a value that is lower than value immediately below break - # We skip the first break, which is the minimum and has no equivalent - # in `breaks_prev` - b1_rounded = round(b1, sigdigits=sigdigits) - b2_rounded = round(b2, sigdigits=sigdigits) - if i < lastindex(breaks) && - (isequal(b2_rounded, breaks_prev[i-1]) || isless(b2_rounded, breaks_prev[i-1])) - sigdigits += 1 - break - end - - # Find minimal number of digits so that breaks are unique - b1_str = Printf.format(CUT_FMT, sigdigits, b1) - b2_str = Printf.format(CUT_FMT, sigdigits, b2) - if b1_str == b2_str - sigdigits += 1 - break - end - end - i == lastindex(breaks) && break - end - end end return cut(x, breaks; labels=labels, sigdigits=sigdigits, allowempty=allowempty) end From 311e593bbd3faa92763c7552c9f96fc34ab4ad32 Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Sat, 17 May 2025 13:57:11 +0200 Subject: [PATCH 10/12] Fix test --- test/15_extras.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/15_extras.jl b/test/15_extras.jl index c45ee5ab..5df7860b 100644 --- a/test/15_extras.jl +++ b/test/15_extras.jl @@ -93,10 +93,10 @@ const ≅ = isequal @test levels(x) == ["b", "a"] x = @inferred cut(Matrix{Union{Float64, T}}([-1.1 3.0; 1.456 10.394]), [-2.134, 3.0, 12.5]) - @test x == ["[-2.1, 3)" "[3, 12]"; "[-2.1, 3)" "[3, 12]"] + @test x == ["[-2.13, 3)" "[3, 12.5]"; "[-2.13, 3)" "[3, 12.5]"] @test isa(x, CategoricalMatrix{Union{String, T}}) @test isordered(x) - @test levels(x) == ["[-2.1, 3)", "[3, 12]"] + @test levels(x) == ["[-2.13, 3)", "[3, 12.5]"] labels = 0:2:8 x = @inferred cut(Vector{Union{T, Int}}(1:8), 0:2:10, labels=labels) @@ -419,7 +419,7 @@ end 5.349085340927365e11, 1.3471583229449602e11, 6.538893396835975e11, 4.826316844547661e11, 8.803607035550856e11, 1.8174694671397316e10, 1.6709745443719125e11, 3.2050577954311835e11, 1.6134999167460663e11, 7.396308745225059e11], 3) - @test levels(x) == ["[1.4e-07, 0.25)", "[0.25, 1.8e+10)", "[1.8e+10, 8.8e+11]"] + @test levels(x) == ["[1.4e-07, 0.254)", "[0.254, 1.82e+10)", "[1.82e+10, 8.8e+11]"] end From 1d24d8458b62864a6e26429638d87d505ccbd645 Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Sat, 17 May 2025 15:02:55 +0200 Subject: [PATCH 11/12] Support weighted quantiles in `cut` This requires adding an extension point for StatsBase. Unfortunately more copies of the data and weights are done than necessary as StatsBase does not support in-place weighted quantile! on pre-sorted data nor taking a view of weights vectors (JuliaStats/StatsBase.jl#723). --- Project.toml | 6 +++- ext/CategoricalArraysStatsBaseExt.jl | 13 ++++++++ src/CategoricalArrays.jl | 1 + src/extras.jl | 44 +++++++++++++++++++++++----- test/15_extras.jl | 24 +++++++++++++++ 5 files changed, 79 insertions(+), 9 deletions(-) create mode 100644 ext/CategoricalArraysStatsBaseExt.jl diff --git a/Project.toml b/Project.toml index 83c8fb18..c9fb9563 100644 --- a/Project.toml +++ b/Project.toml @@ -16,6 +16,7 @@ Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45" JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" SentinelArrays = "91c51154-3ec4-41a3-a24f-3f23e20d615c" StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" @@ -23,6 +24,7 @@ StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" CategoricalArraysArrowExt = "Arrow" CategoricalArraysJSONExt = "JSON" CategoricalArraysRecipesBaseExt = "RecipesBase" +CategoricalArraysStatsBaseExt = "StatsBase" CategoricalArraysSentinelArraysExt = "SentinelArrays" CategoricalArraysStructTypesExt = "StructTypes" @@ -37,6 +39,7 @@ RecipesBase = "1.1" Requires = "1" SentinelArrays = "1" Statistics = "1" +StatsBase = "0.15, 0.16, 0.17, 0.18, 0.19, 0.20, 0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.30, 0.31, 0.32, 0.33, 0.34" StructTypes = "1" julia = "1.6" @@ -49,8 +52,9 @@ Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" SentinelArrays = "91c51154-3ec4-41a3-a24f-3f23e20d615c" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Arrow", "Dates", "JSON", "JSON3", "Plots", "PooledArrays", "RecipesBase", "SentinelArrays", "StructTypes", "Test"] +test = ["Arrow", "Dates", "JSON", "JSON3", "Plots", "PooledArrays", "RecipesBase", "SentinelArrays", "StatsBase", "StructTypes", "Test"] diff --git a/ext/CategoricalArraysStatsBaseExt.jl b/ext/CategoricalArraysStatsBaseExt.jl new file mode 100644 index 00000000..8cbd5c61 --- /dev/null +++ b/ext/CategoricalArraysStatsBaseExt.jl @@ -0,0 +1,13 @@ +module CategoricalArraysStatsBaseExt + +if isdefined(Base, :get_extension) + import CategoricalArrays: _wquantile + using StatsBase +else + import ..CategoricalArrays: _wquantile + using ..StatsBase +end + +_wquantile(x::AbstractArray, w::AbstractWeights, p::AbstractVector) = quantile(x, w, p) + +end diff --git a/src/CategoricalArrays.jl b/src/CategoricalArrays.jl index 272f9b5a..b19a7a4a 100644 --- a/src/CategoricalArrays.jl +++ b/src/CategoricalArrays.jl @@ -47,6 +47,7 @@ module CategoricalArrays @require JSON="682c06a0-de6a-54ab-a142-c8b1cf79cde6" include("../ext/CategoricalArraysJSONExt.jl") @require RecipesBase="3cdcf5f2-1ef4-517c-9805-6587b60abb01" include("../ext/CategoricalArraysRecipesBaseExt.jl") @require SentinelArrays="91c51154-3ec4-41a3-a24f-3f23e20d615c" include("../ext/CategoricalArraysSentinelArraysExt.jl") + @require StatsBase="2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" include("../ext/CategoricalArraysStatsBaseExt.jl") @require StructTypes="856f2bd8-1eba-4b0a-8007-ebc267875bd4" include("../ext/CategoricalArraysStructTypesExt.jl") end end diff --git a/src/extras.jl b/src/extras.jl index cbc5e14c..9b631676 100644 --- a/src/extras.jl +++ b/src/extras.jl @@ -333,11 +333,17 @@ function find_breaks(v::AbstractVector, qs::AbstractVector) return breaks end +# AbstractWeights method is defined in StatsBase extension +# There is no in-place weighted quantile method in StatsBase +_wquantile(x::AbstractArray, w::AbstractVector, p::AbstractVector) = + throw(ArgumentError("`weights` must be an `AbstractWeights` vector from StatsBase.jl")) + """ cut(x::AbstractArray, ngroups::Integer; labels::Union{AbstractVector{<:AbstractString},Function}, sigdigits::Integer=3, - allowempty::Bool=false) + allowempty::Bool=false, + weights::Union{AbstractWeights, Nothing}=nothing) Cut a numeric array into `ngroups` quantiles. @@ -369,19 +375,41 @@ quantiles. other than the last one are equal, generating empty intervals; when `true`, duplicate breaks are allowed and the intervals they generate are kept as unused levels (but duplicate labels are not allowed). +* `weights::Union{AbstractWeights, Nothing}=nothing`: observations weights to used when + computing quantiles (see `quantile` documentation in StatsBase). """ function cut(x::AbstractArray, ngroups::Integer; labels::Union{AbstractVector{<:SupportedTypes},Function,Nothing}=nothing, sigdigits::Integer=3, - allowempty::Bool=false) + allowempty::Bool=false, + weights::Union{AbstractVector, Nothing}=nothing) ngroups >= 1 || throw(ArgumentError("ngroups must be strictly positive (got $ngroups)")) - sorted_x = eltype(x) >: Missing ? sort!(collect(skipmissing(x))) : sort(x) - min_x, max_x = first(sorted_x), last(sorted_x) - if (min_x isa Number && isnan(min_x)) || - (max_x isa Number && isnan(max_x)) - throw(ArgumentError("NaN values are not allowed in input vector")) + if weights === nothing + sorted_x = eltype(x) >: Missing ? sort!(collect(skipmissing(x))) : sort(x) + min_x, max_x = first(sorted_x), last(sorted_x) + if (min_x isa Number && isnan(min_x)) || + (max_x isa Number && isnan(max_x)) + throw(ArgumentError("NaN values are not allowed in input vector")) + end + qs = quantile!(sorted_x, (1:(ngroups-1))/ngroups, sorted=true) + else + if eltype(x) >: Missing + nm_inds = findall(!ismissing, x) + nm_x = view(x, nm_inds) + # TODO: use a view once this is supported (JuliaStats/StatsBase.jl#723) + nm_weights = weights[nm_inds] + else + nm_x = x + nm_weights = weights + end + sorted_x = sort(nm_x) + min_x, max_x = first(sorted_x), last(sorted_x) + if (min_x isa Number && isnan(min_x)) || + (max_x isa Number && isnan(max_x)) + throw(ArgumentError("NaN values are not allowed in input vector")) + end + qs = _wquantile(nm_x, nm_weights, (1:(ngroups-1))/ngroups) end - qs = quantile!(sorted_x, (1:(ngroups-1))/ngroups, sorted=true) breaks = [min_x; find_breaks(sorted_x, qs); max_x] if !allowempty && !allunique(@view breaks[1:end-1]) throw(ArgumentError("cannot compute $ngroups quantiles due to " * diff --git a/test/15_extras.jl b/test/15_extras.jl index 5df7860b..80dc14b7 100644 --- a/test/15_extras.jl +++ b/test/15_extras.jl @@ -1,6 +1,8 @@ module TestExtras using Test using CategoricalArrays +using StatsBase +using Missings const ≅ = isequal @@ -423,4 +425,26 @@ end end +@testset "cut with weighted quantiles" begin + @test_throws ArgumentError cut(1:3, 3, weights=1:3) + + x = collect(Float64, 1:100) + w = fweights(repeat(1:10, inner=10)) + y = cut(x, 10, weights=w) + @test levelcode.(y) == levelcode.(cut(x, quantile(x, w, (0:10)./10))) + @test levels(y) == ["[1, 29)", "[29, 43)", "[43, 53)", "[53, 62)", "[62, 70)", + "[70, 77)", "[77, 83)", "[83, 89)", "[89, 95)", "[95, 100]"] + + mx = allowmissing(x) + mx[2] = mx[10] = missing + nm_inds = .!ismissing.(mx) + y = cut(mx, 10, weights=w) + @test levelcode.(y) ≅ levelcode.(cut(mx, quantile(x[nm_inds], w[nm_inds], (0:10)./10))) + @test levels(y) == ["[1, 30)", "[30, 43)", "[43, 53)", "[53, 62)", "[62, 70)", + "[70, 77)", "[77, 83)", "[83, 89)", "[89, 95)", "[95, 100]"] + + x[5] = NaN + @test_throws ArgumentError cut(x, 3, weights=w) +end + end \ No newline at end of file From 085bd6069e908f55da39309a1268b7c12649772e Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Wed, 21 May 2025 19:25:38 +0200 Subject: [PATCH 12/12] Fix CI --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index cf06ecdf..a9262e93 100644 --- a/Project.toml +++ b/Project.toml @@ -57,4 +57,4 @@ StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Arrow", "Dates", "JSON", "JSON3", "PooledArrays", "RecipesBase", "RecipesPipeline", "SentinelArrays", "StructTypes", "Test"] +test = ["Arrow", "Dates", "JSON", "JSON3", "PooledArrays", "RecipesBase", "RecipesPipeline", "SentinelArrays", "StatsBase", "StructTypes", "Test"]