Skip to content

Commit 9dd935b

Browse files
committed
WIP
1 parent bbed3a2 commit 9dd935b

File tree

2 files changed

+73
-16
lines changed

2 files changed

+73
-16
lines changed

src/extras.jl

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -233,20 +233,28 @@ Provide the default label format for the `cut(x, ngroups)` method.
233233
function quantile_formatter(from, to, i; leftclosed, rightclosed)
    # Bracket style reflects whether each bound belongs to the interval
    opening = leftclosed ? "[" : "("
    closing = rightclosed ? "]" : ")"
    # Label looks like "Q<i>: [from, to)"
    return string("Q", i, ": ", opening, from, ", ", to, closing)
end
235235

236-
function _quantile!(v::AbstractVector, ps::AbstractVector)
237-
n = length(v)
238-
n > 0 || throw(ArgumentError("cannot compute quantiles of empty data vector"))
239-
return map(ps) do p
240-
i = clamp(ceil(Int, n*p), 1, n) + firstindex(v) - 1
241-
q = v[i]
242-
# Take next distinct value even if quantile falls in a series of duplicated values
243-
@inbounds for j in (i+1):lastindex(v)
244-
q_prev = q
245-
q = v[j]
246-
q_prev != q && break
"""
    find_breaks(v::AbstractVector, qs::AbstractVector)

Return, for each quantile in `qs`, the first value in `v` which is strictly
greater than that quantile.

Both `v` and `qs` are expected to be sorted in increasing order: `v` is scanned
once, and quantiles are consumed in order. A single value of `v` may serve as
the break for several consecutive quantiles when they all fall below it.

The result is allocated with `similar(v, length(qs))`. Trailing quantiles for
which no value in `v` is strictly greater (e.g. because the last values in `v`
are equal to the quantile) are all mapped to the first such quantile,
converted to the element type of the result.
"""
function find_breaks(v::AbstractVector, qs::AbstractVector)
    n = length(qs)
    breaks = similar(v, n)
    n == 0 && return breaks

    # Offset so that `qs` does not need to use one-based indexing
    off = firstindex(qs) - 1
    i = 1
    for x in v
        # Consume every quantile that `x` exceeds, not just the first one:
        # several quantiles can fall between two consecutive data values
        while i <= n && x > qs[off + i]
            breaks[i] = x
            i += 1
        end
        i > n && break
    end
    # If the last values in `v` are equal to the current quantile, the
    # remaining breaks were not initialized: fill them with that quantile
    if i <= n
        breaks[i:n] .= qs[off + i]
    end
    return breaks
end
251259

252260
"""
@@ -279,14 +287,16 @@ function cut(x::AbstractArray, ngroups::Integer;
279287
labels::Union{AbstractVector{<:SupportedTypes},Function}=quantile_formatter,
280288
allowempty::Bool=false)
281289
ngroups >= 1 || throw(ArgumentError("ngroups must be strictly positive (got $ngroups)"))
282-
xnm = eltype(x) >: Missing ? sort!(collect(skipmissing(x))) : sort(x)
283-
min_x, max_x = first(xnm), last(xnm)
290+
sorted_x = eltype(x) >: Missing ? sort!(collect(skipmissing(x))) : sort(x)
291+
min_x, max_x = first(sorted_x), last(sorted_x)
284292
if (min_x isa Number && isnan(min_x)) ||
285293
(max_x isa Number && isnan(max_x))
286294
throw(ArgumentError("NaN values are not allowed in input vector"))
287295
end
288-
qs = _quantile!(xnm, (1:(ngroups-1))/ngroups)
289-
breaks = [min_x; qs; max_x]
296+
qs = quantile!(sorted_x, (1:(ngroups-1))/ngroups, sorted=true)
297+
@show qs, min_x, max_x
298+
breaks = [min_x; find_breaks(sorted_x, qs); max_x]
299+
@show breaks
290300
if !allowempty && !allunique(@view breaks[1:end-1])
291301
throw(ArgumentError("cannot compute $ngroups quantiles due to " *
292302
"too many duplicated values in `x`. " *

test/15_extras.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,4 +361,51 @@ end
361361
@test levels(x) == ["Q1: [-Inf, 3.0)", "Q2: [3.0, 5.0]"]
362362
end
363363

364+
@testset "cut in corner cases" begin
# Checks the labels returned by `cut(x, ngroups)` on inputs containing
# runs of duplicated values.
365+
# In this case, cut(x, quantile(x, (0:36)/36)) in R generates
366+
# an empty "(143,172]" level
367+
# and qcut(x, 36) in Polars misses that level.
368+
# Our approach uses different breaks at 143 and 182
369+
x = [23, 23, 60, 76, 84, 95, 101, 108, 111, 133, 137, 143, 143, 143, 182,
370+
206, 214, 241, 258, 262, 280, 289, 303, 312, 321, 323, 352, 353, 354,
371+
368, 369, 373, 374, 384, 385, 386, 387, 392, 405, 406, 410, 421, 430,
372+
430, 431, 442, 464, 474, 478, 479, 496, 516, 530, 534, 549, 554, 568,
373+
575, 589, 590, 591, 592, 595, 596, 603, 625, 632, 632, 638, 640, 640,
374+
645, 648, 690, 704, 748, 758, 771, 772, 803, 835, 839, 853, 869, 873,
375+
874, 887, 911, 920, 923, 928, 933, 943, 945, 945, 947, 951, 965, 978, 980]
376+
# NOTE(review): `x` above is listed in increasing order, but the expected
# labels below are not in increasing quantile order (the first one is "Q17").
# Confirm that the element order of `x` matches the expected vector.
377+
@test cut(x, 36) ==
378+
["Q17: [442, 474)", "Q8: [280, 312)", "Q2: [76, 101)", "Q9: [312, 323)",
379+
"Q14: [387, 406)", "Q30: [835, 869)", "Q17: [442, 474)", "Q35: [947, 965)",
380+
"Q2: [76, 101)", "Q11: [354, 373)", "Q32: [887, 923)", "Q12: [373, 385)",
381+
"Q24: [603, 638)", "Q29: [772, 835)", "Q24: [603, 638)", "Q15: [406, 430)",
382+
"Q11: [354, 373)", "Q23: [592, 603)", "Q3: [101, 133)", "Q16: [430, 442)",
383+
"Q34: [933, 947)", "Q27: [648, 748)", "Q28: [748, 772)", "Q28: [748, 772)",
384+
"Q21: [568, 589)", "Q18: [474, 496)", "Q32: [887, 923)", "Q11: [354, 373)",
385+
"Q7: [241, 280)", "Q3: [101, 133)", "Q19: [496, 534)", "Q13: [385, 387)",
386+
"Q36: [965, 980]", "Q33: [923, 933)", "Q16: [430, 442)", "Q36: [965, 980]",
387+
"Q27: [648, 748)", "Q24: [603, 638)", "Q32: [887, 923)", "Q4: [133, 182)",
388+
"Q22: [589, 592)", "Q1: [23, 76)", "Q5: [182, 206)", "Q28: [748, 772)",
389+
"Q30: [835, 869)", "Q31: [869, 887)", "Q22: [589, 592)", "Q15: [406, 430)",
390+
"Q31: [869, 887)", "Q19: [496, 534)", "Q26: [640, 648)", "Q8: [280, 312)",
391+
"Q18: [474, 496)", "Q14: [387, 406)", "Q7: [241, 280)", "Q30: [835, 869)",
392+
"Q3: [101, 133)", "Q9: [312, 323)", "Q4: [133, 182)", "Q24: [603, 638)",
393+
"Q20: [534, 568)", "Q25: [638, 640)", "Q20: [534, 568)", "Q23: [592, 603)",
394+
"Q12: [373, 385)", "Q27: [648, 748)", "Q29: [772, 835)", "Q6: [206, 241)",
395+
"Q34: [933, 947)", "Q16: [430, 442)", "Q26: [640, 648)", "Q15: [406, 430)",
396+
"Q12: [373, 385)", "Q26: [640, 648)", "Q19: [496, 534)", "Q4: [133, 182)",
397+
"Q34: [933, 947)", "Q31: [869, 887)", "Q10: [323, 354)", "Q21: [568, 589)",
398+
"Q4: [133, 182)", "Q7: [241, 280)", "Q35: [947, 965)", "Q14: [387, 406)",
399+
"Q18: [474, 496)", "Q34: [933, 947)", "Q20: [534, 568)", "Q22: [589, 592)",
400+
"Q33: [923, 933)", "Q10: [323, 354)", "Q13: [385, 387)", "Q1: [23, 76)",
401+
"Q36: [965, 980]", "Q2: [76, 101)", "Q23: [592, 603)", "Q6: [206, 241)",
402+
"Q1: [23, 76)", "Q10: [323, 354)", "Q4: [133, 182)", "Q8: [280, 312)"]
403+
# Duplicated values at the upper end of the data: the last group is
# expected to collapse to the single value 1, labelled "[1, 1]".
404+
@test cut([0, 1, 1, 1, 1], 2) ==
405+
["Q1: [0, 1)", "Q2: [1, 1]", "Q2: [1, 1]", "Q2: [1, 1]", "Q2: [1, 1]"]
406+
# Duplicated values at the lower end of the data: the first group is
# expected to extend up to the next distinct value, 2.
407+
@test cut([1, 1, 1, 1, 2], 2) ==
408+
["Q1: [1, 2)", "Q1: [1, 2)", "Q1: [1, 2)", "Q1: [1, 2)", "Q2: [2, 2]"]
409+
end
410+
364411
end

0 commit comments

Comments
 (0)