Skip to content

Commit 7badc9e

Browse files
authored
Make levels return a CategoricalArray (#425)
Having `levels` preserve the eltype of the input is sometimes useful for writing generic code. This is only slightly breaking, as the result still compares equal to the unwrapped values returned by the previous behavior.
1 parent 13a9bad commit 7badc9e

15 files changed

+135
-89
lines changed

benchmark/benchmarks.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,11 @@ SUITE["many levels"]["CategoricalArray(::Vector{String})"] =
5555
a = rand([@sprintf("id%010d", k) for k in 1:1000], 10000)
5656
ca = CategoricalArray(a)
5757

58-
levs = levels(ca)
58+
levs = unwrap.(levels(ca))
5959
SUITE["many levels"]["levels! with original levels"] =
6060
@benchmarkable levels!(ca, levs)
6161

62-
levs = reverse(levels(ca))
62+
levs = reverse(unwrap.(levels(ca)))
6363
SUITE["many levels"]["levels! with resorted levels"] =
6464
@benchmarkable levels!(ca, levs)
6565

docs/src/using.md

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ By default, the levels are lexically sorted, which is clearly not correct in our
2020

2121
```jldoctest using
2222
julia> levels(x)
23-
3-element Vector{String}:
23+
3-element CategoricalArray{String,1,UInt32}:
2424
"Middle"
2525
"Old"
2626
"Young"
@@ -68,7 +68,7 @@ To get rid of the `"Old"` group, just call the [`droplevels!`](@ref) function:
6868

6969
```jldoctest using
7070
julia> levels(x)
71-
3-element Vector{String}:
71+
3-element CategoricalArray{String,1,UInt32}:
7272
"Young"
7373
"Middle"
7474
"Old"
@@ -81,7 +81,7 @@ julia> droplevels!(x)
8181
"Young"
8282
8383
julia> levels(x)
84-
2-element Vector{String}:
84+
2-element CategoricalArray{String,1,UInt32}:
8585
"Young"
8686
"Middle"
8787
@@ -139,7 +139,7 @@ Levels still need to be reordered manually:
139139

140140
```jldoctest using
141141
julia> levels(y)
142-
3-element Vector{String}:
142+
3-element CategoricalArray{String,1,UInt32}:
143143
"Middle"
144144
"Old"
145145
"Young"
@@ -251,7 +251,7 @@ julia> xy = vcat(x, y)
251251
"Middle"
252252
253253
julia> levels(xy)
254-
3-element Vector{String}:
254+
3-element CategoricalArray{String,1,UInt32}:
255255
"Young"
256256
"Middle"
257257
"Old"
@@ -263,15 +263,15 @@ true
263263
Likewise, assigning a `CategoricalValue` from `y` to an entry in `x` expands the levels of `x` with all levels from `y`, *respecting the ordering of levels of both vectors if possible*:
264264
```jldoctest using
265265
julia> levels(x)
266-
2-element Vector{String}:
266+
2-element CategoricalArray{String,1,UInt32}:
267267
"Middle"
268268
"Old"
269269
270270
julia> x[1] = y[1]
271271
CategoricalValue{String, UInt32} "Young" (1/2)
272272
273273
julia> levels(x)
274-
3-element Vector{String}:
274+
3-element CategoricalArray{String,1,UInt32}:
275275
"Young"
276276
"Middle"
277277
"Old"
@@ -296,7 +296,7 @@ julia> ab = vcat(a, b)
296296
"c"
297297
298298
julia> levels(ab)
299-
3-element Vector{String}:
299+
3-element CategoricalArray{String,1,UInt32}:
300300
"a"
301301
"b"
302302
"c"
@@ -320,7 +320,7 @@ julia> ab2 = vcat(a, b)
320320
"c"
321321
322322
julia> levels(ab2)
323-
3-element Vector{String}:
323+
3-element CategoricalArray{String,1,UInt32}:
324324
"a"
325325
"b"
326326
"c"

ext/CategoricalArraysArrowExt.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ import Arrow: ArrowTypes
77
const CATARRAY_ARROWNAME = Symbol("JuliaLang.CategoricalArrays.CategoricalArray")
88
ArrowTypes.arrowname(::Type{<:CategoricalValue}) = CATARRAY_ARROWNAME
99
ArrowTypes.arrowmetadata(::Type{CategoricalValue{T, R}}) where {T, R} = string(R)
10+
ArrowTypes.ArrowType(::Type{<:CategoricalValue{T}}) where {T} = T
11+
ArrowTypes.toarrow(x::CategoricalValue) = unwrap(x)
1012

1113
ArrowTypes.arrowname(::Type{Union{<:CategoricalValue, Missing}}) = CATARRAY_ARROWNAME
1214
ArrowTypes.arrowmetadata(::Type{Union{CategoricalValue{T, R}, Missing}}) where {T, R} =

ext/CategoricalArraysRecipesBaseExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ else
99
end
1010

1111
RecipesBase.@recipe function f(::Type{T}, v::T) where T <: CategoricalValue
12-
level_strings = [map(string, levels(v)); missing]
12+
level_strings = [map(string, CategoricalArrays._levels(v)); missing]
1313
ticks --> eachindex(level_strings)
1414
v -> ismissing(v) ? length(level_strings) : Int(CategoricalArrays.refcode(v)),
1515
i -> level_strings[Int(i)]

src/array.jl

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ function CategoricalArray{T, N, R}(A::CategoricalArray{S, N, Q};
240240
catch err
241241
err isa LevelsException || rethrow(err)
242242
throw(ArgumentError("encountered value(s) not in specified `levels`: " *
243-
"$(setdiff(CategoricalArrays.levels(res), levels))"))
243+
"$(setdiff(_levels(res), levels))"))
244244
end
245245
end
246246
return res
@@ -359,18 +359,18 @@ function _convert(::Type{CategoricalArray{T, N, R}}, A::AbstractArray{S, N};
359359
copyto!(res, A)
360360

361361
if levels !== nothing
362-
CategoricalArrays.levels(res) == levels ||
362+
_levels(res) == levels ||
363363
throw(ArgumentError("encountered value(s) not in specified `levels`: " *
364-
"$(setdiff(CategoricalArrays.levels(res), levels))"))
364+
"$(setdiff(_levels(res), levels))"))
365365
else
366366
# if order is defined for level type, automatically apply it
367367
L = leveltype(res)
368368
if Base.OrderStyle(L) isa Base.Ordered
369-
levels!(res, sort(CategoricalArrays.levels(res)))
369+
levels!(res, sort(_levels(res)))
370370
elseif hasmethod(isless, (L, L))
371371
# isless may throw an error, e.g. for AbstractArray{T} of unordered T
372372
try
373-
levels!(res, sort(CategoricalArrays.levels(res)))
373+
levels!(res, sort(_levels(res)))
374374
catch e
375375
e isa MethodError || rethrow(e)
376376
end
@@ -383,7 +383,7 @@ end
383383
# From CategoricalArray (preserve levels, ordering and R)
384384
function convert(::Type{CategoricalArray{T, N, R}}, A::CategoricalArray{S, N}) where {S, T, N, R}
385385
if length(A.pool) > typemax(R)
386-
throw(LevelsException{T, R}(levels(A)[typemax(R)+1:end]))
386+
throw(LevelsException{T, R}(_levels(A)[typemax(R)+1:end]))
387387
end
388388

389389
if !(T >: Missing) && S >: Missing && any(iszero, A.refs)
@@ -467,7 +467,7 @@ size(A::CategoricalArray) = size(A.refs)
467467
Base.IndexStyle(::Type{<:CategoricalArray}) = IndexLinear()
468468

469469
function update_refs!(A::CategoricalArray, newlevels::AbstractVector)
470-
oldlevels = levels(A)
470+
oldlevels = _levels(A)
471471
levelsmap = similar(A.refs, length(oldlevels)+1)
472472
# 0 maps to a missing value
473473
levelsmap[1] = 0
@@ -485,7 +485,7 @@ function merge_pools!(A::CatArrOrSub,
485485
updaterefs::Bool=true,
486486
updatepool::Bool=true)
487487
newlevels, ordered = merge_pools(pool(A), pool(B))
488-
oldlevels = levels(A)
488+
oldlevels = _levels(A)
489489
pA = A isa SubArray ? parent(A) : A
490490
ordered!(pA, ordered)
491491
# If A's levels are an ordered superset of new (merged) pool, no need to recompute refs
@@ -544,8 +544,8 @@ function copyto!(dest::CatArrOrSub{T, N, R}, dstart::Integer,
544544

545545
# try converting src to dest type to avoid partial copy corruption of dest
546546
# in the event that the src cannot be copied into dest
547-
slevs = convert(Vector{T}, levels(src))
548-
dlevs = levels(dest)
547+
slevs = convert(Vector{T}, _levels(src))
548+
dlevs = _levels(dest)
549549
if eltype(src) >: Missing && !(eltype(dest) >: Missing) && !all(x -> x > 0, srefs)
550550
throw(MissingException("cannot copy array with missing values to an array with element type $T"))
551551
end
@@ -598,7 +598,7 @@ function copyto!(dest::CatArrOrSub{T1, N, R}, dstart::Integer,
598598
return invoke(copyto!, Tuple{AbstractArray, Integer, AbstractArray, Integer, Integer},
599599
dest, dstart, src, sstart, n)
600600
end
601-
newdestlevs = destlevs = copy(levels(dest)) # copy since we need original levels below
601+
newdestlevs = destlevs = copy(_levels(dest)) # copy since we need original levels below
602602
srclevsnm = T2 >: Missing ? setdiff(srclevs, [missing]) : srclevs
603603
if !(srclevsnm ⊆ destlevs)
604604
# if order is defined for level type, automatically apply it
@@ -708,7 +708,7 @@ While this will reduce memory use, this function is type-unstable, which can aff
708708
performance inside the function where the call is made. Therefore, use it with caution.
709709
"""
710710
function compress(A::CategoricalArray{T, N}) where {T, N}
711-
R = reftype(length(levels(A.pool)))
711+
R = reftype(length(_levels(A.pool)))
712712
convert(CategoricalArray{T, N, R}, A)
713713
end
714714

@@ -726,11 +726,11 @@ decompress(A::CategoricalArray{T, N}) where {T, N} =
726726
convert(CategoricalArray{T, N, DefaultRefType}, A)
727727

728728
function vcat(A::CategoricalArray...)
729-
ordered = any(isordered, A) && all(a->isordered(a) || isempty(levels(a)), A)
730-
newlevels, ordered = mergelevels(ordered, map(levels, A)...)
729+
ordered = any(isordered, A) && all(a->isordered(a) || isempty(_levels(a)), A)
730+
newlevels, ordered = mergelevels(ordered, map(_levels, A)...)
731731

732732
refsvec = map(A) do a
733-
ii = convert(Vector{Int}, indexin(levels(a.pool), newlevels))
733+
ii = convert(Vector{Int}, indexin(_levels(a.pool), newlevels))
734734
[x==0 ? 0 : ii[x] for x in a.refs]::Array{Int,ndims(a)}
735735
end
736736

@@ -768,23 +768,25 @@ This may include levels which do not actually appear in the data
768768
`missing` will be included only if it appears in the data and
769769
`skipmissing=false` is passed.
770770
771-
The returned vector is an internal field of `x` which must not be mutated
771+
The returned vector is owned by `x` and must not be mutated
772772
as doing so would corrupt it.
773773
"""
774-
@inline function DataAPI.levels(A::CatArrOrSub{T}; skipmissing::Bool=true) where T
774+
@inline function DataAPI.levels(A::CatArrOrSub; skipmissing::Bool=true)
775775
if eltype(A) >: Missing && !skipmissing
776776
if any(==(0), refs(A))
777-
T[levels(pool(A)); missing]
777+
eltype(A)[levels(pool(A)); missing]
778778
else
779-
convert(Vector{T}, levels(pool(A)))
779+
levels_missing(pool(A))
780780
end
781781
else
782782
levels(pool(A))
783783
end
784784
end
785785

786+
_levels(A::CatArrOrSub) = _levels(pool(A))
787+
786788
"""
787-
levels!(A::CategoricalArray, newlevels::Vector; allowmissing::Bool=false)
789+
levels!(A::CategoricalArray, newlevels::AbstractVector; allowmissing::Bool=false)
788790
789791
Set the levels of categorical array `A`. The order of appearance of levels will be respected
790792
by [`levels`](@ref DataAPI.levels), which may affect display of results in some operations; if `A` is
@@ -798,7 +800,7 @@ Else, `newlevels` must include all levels which appear in the data.
798800
"""
799801
function levels!(A::CategoricalArray{T, N, R}, newlevels::AbstractVector;
800802
allowmissing::Bool=false) where {T, N, R}
801-
(levels(A) == newlevels) && return A # nothing to do
803+
(_levels(A) == newlevels) && return A # nothing to do
802804

803805
# map each new level to its ref code
804806
newlv2ref = Dict{eltype(newlevels), Int}()
@@ -813,7 +815,7 @@ function levels!(A::CategoricalArray{T, N, R}, newlevels::AbstractVector;
813815
end
814816

815817
# map each old ref code to new ref code (or 0 if no such level)
816-
oldlevels = levels(pool(A))
818+
oldlevels = _levels(pool(A))
817819
oldref2newref = fill(0, length(oldlevels) + 1)
818820
for (i, lv) in enumerate(oldlevels)
819821
oldref2newref[i + 1] = get(newlv2ref, lv, 0)
@@ -874,7 +876,7 @@ end
874876
function _uniquerefs(A::CatArrOrSub{T}) where T
875877
arefs = refs(A)
876878
res = similar(arefs, 0)
877-
nlevels = length(levels(A))
879+
nlevels = length(_levels(A))
878880
maxunique = nlevels + (T >: Missing ? 1 : 0)
879881
seen = fill(false, nlevels + 1) # always +1 for 0 (missing ref)
880882
@inbounds for ref in arefs
@@ -907,7 +909,7 @@ returned by [`levels`](@ref DataAPI.levels)).
907909
"""
908910
function droplevels!(A::CategoricalArray)
909911
arefs = refs(A)
910-
nlevels = length(levels(A)) + 1 # +1 for missing
912+
nlevels = length(_levels(A)) + 1 # +1 for missing
911913
seen = fill(false, nlevels)
912914
seen[1] = true # assume that missing is always observed to simplify checks
913915
nseen = 1
@@ -920,7 +922,7 @@ function droplevels!(A::CategoricalArray)
920922
end
921923

922924
# replace the pool
923-
A.pool = typeof(pool(A))(@inbounds(levels(A)[view(seen, 2:nlevels)]), isordered(A))
925+
A.pool = typeof(pool(A))(@inbounds(_levels(A)[view(seen, 2:nlevels)]), isordered(A))
924926
# recode refs to keep only the seen ones (optimized version of update_refs!())
925927
seen[1] = false # to start levelsmap from 0
926928
levelsmap = cumsum(seen)
@@ -1037,7 +1039,7 @@ end
10371039
ordered=_isordered(A),
10381040
compress::Bool=false) where {T, N, R}
10391041
# @inline is needed so that return type is inferred when compress is not provided
1040-
RefType = compress ? reftype(length(CategoricalArrays.levels(A))) : R
1042+
RefType = compress ? reftype(length(_levels(A))) : R
10411043
CategoricalArray{T, N, RefType}(A, levels=levels, ordered=ordered)
10421044
end
10431045

@@ -1050,7 +1052,7 @@ function in(x::CategoricalValue, y::CategoricalArray{T, N, R}) where {T, N, R}
10501052
if x.pool === y.pool
10511053
return refcode(x) in y.refs
10521054
else
1053-
ref = get(y.pool, levels(x.pool)[refcode(x)], zero(R))
1055+
ref = get(y.pool, _levels(x.pool)[refcode(x)], zero(R))
10541056
return ref != 0 ? ref in y.refs : false
10551057
end
10561058
end

0 commit comments

Comments
 (0)