Skip to content

Commit 3e0d056

Browse files
Make recode! type stable (#407)
Varargs appear to be type-stable according to `@code_warntype` but in practice that's not the case.
1 parent 341de70 commit 3e0d056

File tree

3 files changed

+43
-44
lines changed

3 files changed

+43
-44
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ uuid = "324d7699-5711-5eae-9e2f-1d82baa6b597"
33
version = "0.10.8"
44

55
[deps]
6+
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
67
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
78
Future = "9fa8497b-333b-5362-9e8d-4d0656e87820"
89
Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
@@ -24,6 +25,7 @@ CategoricalArraysSentinelArraysExt = "SentinelArrays"
2425
CategoricalArraysStructTypesExt = "StructTypes"
2526

2627
[compat]
28+
Compat = "3.37, 4"
2729
DataAPI = "1.6"
2830
JSON = "0.15, 0.16, 0.17, 0.18, 0.19, 0.20, 0.21"
2931
JSON3 = "1.1.2"

src/CategoricalArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ module CategoricalArrays
1414
using DataAPI
1515
using Missings
1616
using Printf
17+
import Compat
1718

1819
# JuliaLang/julia#36810
1920
if VERSION < v"1.5.2"

src/recode.jl

Lines changed: 40 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -52,27 +52,34 @@ A user defined type could override this method to define an appropriate test fun
5252
optimize_pair(pair::Pair) = pair
5353
optimize_pair(pair::Pair{<:AbstractArray}) = Set(pair.first) => pair.second
5454

55-
function recode!(dest::AbstractArray{T}, src::AbstractArray, default::Any, pairs::Pair...) where {T}
55+
function recode!(dest::AbstractArray, src::AbstractArray, default::Any, pairs::Pair...)
5656
if length(dest) != length(src)
5757
throw(DimensionMismatch("dest and src must be of the same length (got $(length(dest)) and $(length(src)))"))
5858
end
5959

60-
opt_pairs = map(optimize_pair, pairs)
60+
opt_pairs = optimize_pair.(pairs)
6161

62+
_recode!(dest, src, default, opt_pairs)
63+
end
64+
65+
function _recode!(dest::AbstractArray{T}, src::AbstractArray, default,
66+
pairs::NTuple{<:Any, Pair}) where {T}
67+
recode_to = last.(pairs)
68+
recode_from = first.(pairs)
69+
6270
@inbounds for i in eachindex(dest, src)
6371
x = src[i]
6472

65-
for j in 1:length(opt_pairs)
66-
p = opt_pairs[j]
67-
# we use isequal and recode_in because we cannot really distinguish scalars from collections
68-
if x p.first || recode_in(x, p.first)
69-
dest[i] = p.second
70-
@goto nextitem
71-
end
72-
end
73-
73+
# @inline is needed for type stability and Compat for compatibility before julia v1.8
74+
# we use isequal and recode_in because we cannot really
75+
# distinguish scalars from collections
76+
j = Compat.@inline findfirst(y -> isequal(x, y) || recode_in(x,y), recode_from)
77+
78+
# Value in one of the pairs
79+
if j !== nothing
80+
dest[i] = recode_to[j]
7481
# Value not in any of the pairs
75-
if ismissing(x)
82+
elseif ismissing(x)
7683
eltype(dest) >: Missing ||
7784
throw(MissingException("missing value found, but dest does not support them: " *
7885
"recode them to a supported value"))
@@ -89,21 +96,16 @@ function recode!(dest::AbstractArray{T}, src::AbstractArray, default::Any, pairs
8996
else
9097
dest[i] = default
9198
end
92-
93-
@label nextitem
9499
end
95100

96101
dest
97102
end
98103

99-
function recode!(dest::CategoricalArray{T}, src::AbstractArray, default::Any, pairs::Pair...) where {T}
100-
if length(dest) != length(src)
101-
throw(DimensionMismatch("dest and src must be of the same length (got $(length(dest)) and $(length(src)))"))
102-
end
103-
104-
opt_pairs = map(optimize_pair, pairs)
104+
function _recode!(dest::CategoricalArray{T, <:Any, R}, src::AbstractArray, default::Any,
105+
pairs::NTuple{<:Any, Pair}) where {T, R}
106+
recode_from = first.(pairs)
107+
vals = T[p.second for p in pairs]
105108

106-
vals = T[p.second for p in opt_pairs]
107109
default !== nothing && push!(vals, default)
108110

109111
levels!(dest.pool, filter!(!ismissing, unique(vals)))
@@ -112,22 +114,22 @@ function recode!(dest::CategoricalArray{T}, src::AbstractArray, default::Any, pa
112114
dupvals = length(vals) != length(levels(dest.pool))
113115

114116
drefs = dest.refs
115-
pairmap = [ismissing(v) ? 0 : get(dest.pool, v) for v in vals]
116-
defaultref = default === nothing || ismissing(default) ? 0 : get(dest.pool, default)
117+
pairmap = [ismissing(v) ? zero(R) : get(dest.pool, v) for v in vals]
118+
defaultref = default === nothing || ismissing(default) ? zero(R) : get(dest.pool, default)
119+
117120
@inbounds for i in eachindex(drefs, src)
118121
x = src[i]
119122

120-
for j in 1:length(opt_pairs)
121-
p = opt_pairs[j]
122-
# we use isequal and recode_in because we cannot really distinguish scalars from collections
123-
if x p.first || recode_in(x, p.first)
124-
drefs[i] = dupvals ? pairmap[j] : j
125-
@goto nextitem
126-
end
127-
end
123+
# @inline is needed for type stability and Compat for compatibility before julia v1.8
124+
# we use isequal and recode_in because we cannot really
125+
# distinguish scalars from collections
126+
j = Compat.@inline findfirst(y -> isequal(x, y) || recode_in(x, y), recode_from)
128127

128+
# Value in one of the pairs
129+
if j !== nothing
130+
drefs[i] = dupvals ? pairmap[j] : j
129131
# Value not in any of the pairs
130-
if ismissing(x)
132+
elseif ismissing(x)
131133
eltype(dest) >: Missing ||
132134
throw(MissingException("missing value found, but dest does not support them: " *
133135
"recode them to a supported value"))
@@ -144,8 +146,6 @@ function recode!(dest::CategoricalArray{T}, src::AbstractArray, default::Any, pa
144146
else
145147
drefs[i] = defaultref
146148
end
147-
148-
@label nextitem
149149
end
150150

151151
# Put existing levels first, and sort them if possible
@@ -168,25 +168,21 @@ function recode!(dest::CategoricalArray{T}, src::AbstractArray, default::Any, pa
168168
dest
169169
end
170170

171-
function recode!(dest::CategoricalArray{T, N, R}, src::CategoricalArray,
172-
default::Any, pairs::Pair...) where {T, N, R<:Integer}
173-
if length(dest) != length(src)
174-
throw(DimensionMismatch("dest and src must be of the same length " *
175-
"(got $(length(dest)) and $(length(src)))"))
176-
end
177-
171+
function _recode!(dest::CategoricalArray{T, N, R}, src::CategoricalArray,
172+
default::Any, pairs::NTuple{<:Any, Pair}) where {T, N, R<:Integer}
173+
recode_from = first.(pairs)
178174
vals = T[p.second for p in pairs]
175+
179176
if default === nothing
180177
srclevels = levels(src)
181178

182179
# Remove recoded levels as they won't appear in result
183-
firsts = (p.first for p in pairs)
184180
keptlevels = Vector{T}(undef, 0)
185181
sizehint!(keptlevels, length(srclevels))
186182

187183
for l in srclevels
188-
if !(any(x -> x l, firsts) ||
189-
any(f -> recode_in(l, f), firsts))
184+
if !(any(x -> x l, recode_from) ||
185+
any(f -> recode_in(l, f), recode_from))
190186
try
191187
push!(keptlevels, l)
192188
catch err

0 commit comments

Comments
 (0)