Skip to content

Commit 0ede6dc

Browse files
greimel authored and nalimilan committed
Provide formatter for labeling categories in cut function (#202)
1 parent 327bef7 commit 0ede6dc

File tree

2 files changed

+69
-14
lines changed

2 files changed

+69
-14
lines changed

src/extras.jl

Lines changed: 59 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,16 @@ function fill_refs!(refs::AbstractArray, X::AbstractArray{>: Missing},
4242
end
4343

4444
"""
45+
default_formatter(from, to, i; closed=false)
46+
47+
Provide the default label format for the `cut` function.
48+
"""
49+
function default_formatter(from, to, i; closed::Bool=false)
    # `closed` defaults to `false`, matching the documented signature, so the
    # formatter can be called without keywords (previously this raised
    # `UndefKeywordError`).
    # `i` (the interval index) is not used by the default label; it is accepted
    # because `cut` calls every formatter as `f(from, to, i; closed)`.
    closing = closed ? "]" : ")"
    return string("[", from, ", ", to, closing)
end
50+
51+
@doc raw"""
4552
cut(x::AbstractArray, breaks::AbstractVector;
46-
extend::Bool=false, labels::AbstractVector=[], allow_missing::Bool=false)
53+
labels::Union{AbstractVector{<:AbstractString},Function},
54+
extend::Bool=false, allow_missing::Bool=false)
4755
4856
Cut a numeric array into intervals and return an ordered `CategoricalArray` indicating
4957
the interval into which each entry falls. Intervals are of the form `[lower, upper)`,
@@ -56,14 +64,55 @@ also accept them.
5664
* `extend::Bool=false`: when `false`, an error is raised if some values in `x` fall
5765
outside of the breaks; when `true`, breaks are automatically added to include all
5866
values in `x`, and the upper bound is included in the last interval.
59-
* `labels::AbstractVector=[]`: a vector of strings giving the names to use for the
60-
intervals; if empty, default labels are used.
67+
* `labels::Union{AbstractVector,Function}`: a vector of strings giving the names to use for
68+
the intervals; or a function `f(from, to, i; closed)` that generates the labels from the
69+
left and right interval boundaries and the group index. Defaults to
70+
`"[from, to)"` (or `"[from, to]"` for the rightmost interval if `extend == true`).
6171
* `allow_missing::Bool=false`: when `true`, values outside of breaks result in missing values.
6272
This is only supported when `x` accepts missing values.
73+
74+
# Examples
75+
```jldoctest
76+
julia> cut(-1:0.5:1, [0, 1], extend=true)
77+
5-element CategoricalArray{String,1,UInt32}:
78+
"[-1.0, 0.0)"
79+
"[-1.0, 0.0)"
80+
"[0.0, 1.0]"
81+
"[0.0, 1.0]"
82+
"[0.0, 1.0]"
83+
84+
julia> cut(-1:0.5:1, 2)
85+
5-element CategoricalArray{String,1,UInt32}:
86+
"[-1.0, 0.0)"
87+
"[-1.0, 0.0)"
88+
"[0.0, 1.0]"
89+
"[0.0, 1.0]"
90+
"[0.0, 1.0]"
91+
92+
julia> cut(-1:0.5:1, 2, labels=["A", "B"])
93+
5-element CategoricalArray{String,1,UInt32}:
94+
"A"
95+
"A"
96+
"B"
97+
"B"
98+
"B"
99+
100+
julia> fmt(from, to, i; closed) = "grp $i ($from//$to)"
101+
fmt (generic function with 1 method)
102+
103+
julia> cut(-1:0.5:1, 3, labels=fmt)
104+
5-element CategoricalArray{String,1,UInt32}:
105+
"grp 1 (-1.0//-0.333333)"
106+
"grp 1 (-1.0//-0.333333)"
107+
"grp 2 (-0.333333//0.333333)"
108+
"grp 3 (0.333333//1.0)"
109+
"grp 3 (0.333333//1.0)"
110+
```
63111
"""
64112
function cut(x::AbstractArray{T, N}, breaks::AbstractVector;
65-
extend::Bool=false, labels::AbstractVector{U}=String[],
66-
allow_missing::Bool=false) where {T, N, U<:AbstractString}
113+
extend::Bool=false,
114+
labels::Union{AbstractVector{<:AbstractString},Function}=default_formatter,
115+
allow_missing::Bool=false) where {T, N}
67116
if !issorted(breaks)
68117
breaks = sort(breaks)
69118
end
@@ -92,7 +141,7 @@ function cut(x::AbstractArray{T, N}, breaks::AbstractVector;
92141
end
93142

94143
n = length(breaks)
95-
if isempty(labels)
144+
if labels isa Function
96145
@static if VERSION >= v"0.7.0-DEV.4524"
97146
from = map(x -> sprint(show, x, context=:compact=>true), breaks[1:n-1])
98147
to = map(x -> sprint(show, x, context=:compact=>true), breaks[2:n])
@@ -102,13 +151,9 @@ function cut(x::AbstractArray{T, N}, breaks::AbstractVector;
102151
end
103152
levs = Vector{String}(undef, n-1)
104153
for i in 1:n-2
105-
levs[i] = string("[", from[i], ", ", to[i], ")")
106-
end
107-
if extend
108-
levs[end] = string("[", from[end], ", ", to[end], "]")
109-
else
110-
levs[end] = string("[", from[end], ", ", to[end], ")")
154+
levs[i] = labels(from[i], to[i], i, closed=false)
111155
end
156+
levs[end] = labels(from[end], to[end], n-1, closed=extend)
112157
else
113158
length(labels) == n-1 || throw(ArgumentError("labels must be of length $(n-1), but got length $(length(labels))"))
114159
# Levels must have element type String for type stability of the result
@@ -122,11 +167,11 @@ end
122167

123168
"""
124169
cut(x::AbstractArray, ngroups::Integer;
125-
labels::AbstractVector=String[])
170+
labels::Union{AbstractVector{<:AbstractString},Function})
126171
127172
Cut a numeric array into `ngroups` quantiles, determined using
128173
[`quantile`](@ref).
129174
"""
130175
cut(x::AbstractArray, ngroups::Integer;
131-
labels::AbstractVector{U}=String[]) where {U<:AbstractString} =
176+
labels::Union{AbstractVector{<:AbstractString},Function}=default_formatter) =
132177
cut(x, Statistics.quantile(x, (1:ngroups-1)/ngroups); extend=true, labels=labels)

test/15_extras.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,4 +108,14 @@ end
108108
@test levels(x) == ["[2.0, 3.5)", "[3.5, 5.0]"]
109109
end
110110

111+
@testset "cut with formatter function" begin
112+
my_formatter(from, to, i; closed) = "$i: $from -- $to"
113+
114+
x = 0.15:0.20:0.95
115+
p = [0, 0.4, 0.8, 1.0]
116+
117+
@test cut(x, p, labels=my_formatter) ==
118+
["1: 0.0 -- 0.4", "1: 0.0 -- 0.4", "2: 0.4 -- 0.8", "2: 0.4 -- 0.8", "3: 0.8 -- 1.0"]
119+
end
120+
111121
end

0 commit comments

Comments
 (0)