Skip to content

Commit d1a3461

Browse files
committed
Use labels kw, group index, add tests
1 parent bc50544 commit d1a3461

File tree

2 files changed

+35
-12
lines changed

2 files changed

+35
-12
lines changed

src/extras.jl

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,13 @@ function fill_refs!(refs::AbstractArray, X::AbstractArray{>: Missing},
4141
end
4242
end
4343

44+
"""
45+
_default_formatter_
46+
47+
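Return the default label for an interval produced by `cut`: `"[from, to)"`, or `"[from, to]"` when `extend=true`. The group index `i` is accepted but not used.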
48+
"""
49+
function _default_formatter_(from, to, i; extend=false)
    # The group index `i` is part of the formatter interface but is not used here.
    # A closed upper bound ("]") is emitted only when `extend` is true.
    closing = extend ? "]" : ")"
    return string("[", from, ", ", to, closing)
end
50+
4451
"""
4552
cut(x::AbstractArray, breaks::AbstractVector;
4653
extend::Bool=false, labels::Union{AbstractVector,Function}=_default_formatter_, allow_missing::Bool=false)
@@ -56,16 +63,16 @@ also accept them.
5663
* `extend::Bool=false`: when `false`, an error is raised if some values in `x` fall
5764
outside of the breaks; when `true`, breaks are automatically added to include all
5865
values in `x`, and the upper bound is included in the last interval.
59-
* `labels::AbstractVector=[]`: a vector of strings giving the names to use for the
60-
intervals; if empty, default labels are used.
61-
* `label_formatter::Function`: a function `f(from,to;extend=false)` that generates the labels from the left and right interval boundaries. Defaults to `string("[", from, ", ", to, extend ? "]" : ")")`, e.g. `"[1, 5)"`.
66+
* `labels::Union{AbstractVector,Function}=_default_formatter_`: a vector of strings giving the names to use for the
67+
intervals, or a function `f(from, to, i; extend=false)` that generates each label from the left and right interval boundaries and the group index `i`. The default formatter returns `string("[", from, ", ", to, extend ? "]" : ")")`, e.g. `"[1, 5)"`.
6268
* `allow_missing::Bool=false`: when `true`, values outside of breaks result in missing values;
6369
only supported when `x` accepts missing values.
6470
"""
6571
function cut(x::AbstractArray{T, N}, breaks::AbstractVector;
66-
extend::Bool=false, labels::AbstractVector{U}=String[],
67-
label_formatter=_default_formatter_,
72+
extend::Bool=false, labels=_default_formatter_,
6873
allow_missing::Bool=false) where {T, N, U<:AbstractString}
74+
(labels isa AbstractVector) || (labels isa Function) || throw(ArgumentError("labels must be a formatter function or an AbstractVector"))
75+
6976
if !issorted(breaks)
7077
breaks = sort(breaks)
7178
end
@@ -94,7 +101,7 @@ function cut(x::AbstractArray{T, N}, breaks::AbstractVector;
94101
end
95102

96103
n = length(breaks)
97-
if isempty(labels)
104+
if labels isa Function
98105
@static if VERSION >= v"0.7.0-DEV.4524"
99106
from = map(x -> sprint(show, x, context=:compact=>true), breaks[1:n-1])
100107
to = map(x -> sprint(show, x, context=:compact=>true), breaks[2:n])
@@ -104,12 +111,12 @@ function cut(x::AbstractArray{T, N}, breaks::AbstractVector;
104111
end
105112
levs = Vector{String}(undef, n-1)
106113
for i in 1:n-2
107-
levs[i] = label_formatter(from[i], to[i])
114+
levs[i] = labels(from[i], to[i], i)
108115
end
109116
if extend
110-
levs[end] = label_formatter(from[end], to[end], extend=extend)
117+
levs[end] = labels(from[end], to[end], n-1, extend=extend)
111118
else
112-
levs[end] = label_formatter(from[end], to[end])
119+
levs[end] = labels(from[end], to[end], n-1)
113120
end
114121
else
115122
length(labels) == n-1 || throw(ArgumentError("labels must be of length $(n-1), but got length $(length(labels))"))
@@ -130,8 +137,8 @@ Cut a numeric array into `ngroups` quantiles, determined using
130137
[`quantile`](@ref).
131138
"""
132139
cut(x::AbstractArray, ngroups::Integer;
133-
labels::AbstractVector{U}=String[], label_formatter=_default_formatter_) where {U<:AbstractString} =
134-
cut(x, Statistics.quantile(x, (1:ngroups-1)/ngroups); extend=true, labels=labels, label_formatter=label_formatter)
140+
labels=_default_formatter_) =
141+
cut(x, Statistics.quantile(x, (1:ngroups-1)/ngroups); extend=true, labels=labels)
142+
135143

136-
_default_formatter_(from, to; extend=false) = string("[", from, ", ", to, extend ? "]" : ")")
137144

test/15_extras.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,4 +108,20 @@ end
108108
@test levels(x) == ["[2.0, 3.5)", "[3.5, 5.0]"]
109109
end
110110

111+
@testset "formatter function" begin
112+
my_formatter1(from,to,i;extend=false) = "group $i"
113+
my_formatter2(from,to,i;extend=false) = "$i: $from -- $to"
114+
function my_formatter3(from,to,i;extend=true)
115+
percentile(x) = Int(round(100 * parse.(Float64,x),digits=0))
116+
string("P",percentile(from),"P",percentile(to))
117+
end
118+
119+
x = collect(0.15:0.20:0.95)
120+
p = [0, 0.4, 0.8, 1.0]
121+
122+
@test cut(x, p, labels=my_formatter1) == ["group 1", "group 1", "group 2", "group 2", "group 3"]
123+
@test cut(x, p, labels=my_formatter2) == ["1: 0.0 -- 0.4", "1: 0.0 -- 0.4", "2: 0.4 -- 0.8", "2: 0.4 -- 0.8", "3: 0.8 -- 1.0"]
124+
@test cut(x, p, labels=my_formatter3) == ["P0P40" , "P0P40" , "P40P80" , "P40P80" , "P80P100"]
125+
end
126+
111127
end

0 commit comments

Comments
 (0)