Skip to content

Commit 7b577d7

Browse files
committed
✨ Fix frequency encoder output types
1 parent 14a5671 commit 7b577d7

File tree

5 files changed

+81
-22
lines changed

5 files changed

+81
-22
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,4 @@ meh/*.ipynb
2727
.DS_Store
2828
/*.jl
2929
scratchpad/
30+
examples/test.jl

src/encoders/frequency_encoding/frequency_encoding.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,10 @@ function frequency_encoder_fit(
2828
# 1. Define feature mapper
2929
function feature_mapper(col, name)
3030
frequency_map = (!normalize) ? countmap(col) : proportionmap(col)
31-
statistic_given_feat_val = Dict{Any, Real}(level=>frequency_map[level] for level in levels(col))
31+
feat_levels = levels(col)
32+
statistic_given_feat_val = Dict{eltype(feat_levels), Float32}(
33+
level => frequency_map[level] for level in feat_levels
34+
)
3235
return statistic_given_feat_val
3336
end
3437

src/generic.jl

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,13 @@ function generic_fit(X,
4949
feat_col = Tables.getcolumn(X, feat_name)
5050
feat_type = elscitype(feat_col)
5151
feat_has_allowed_type =
52-
feat_type <: Union{Missing, Multiclass} || (ordered_factor && feat_type <: Union{Missing, OrderedFactor})
52+
feat_type <: Union{Missing, Multiclass} ||
53+
(ordered_factor && feat_type <: Union{Missing, OrderedFactor})
5354
if feat_has_allowed_type # then should be encoded
5455
push!(encoded_features, feat_name)
5556
# Compute the dict using the given feature_mapper function
56-
mapping_per_feat_level[feat_name] = feature_mapper(feat_col, feat_name, args...; kwargs...)
57+
mapping_per_feat_level[feat_name] =
58+
feature_mapper(feat_col, feat_name, args...; kwargs...)
5759
end
5860
end
5961
return mapping_per_feat_level, encoded_features
@@ -72,7 +74,7 @@ function generate_new_feat_names(feat_name, num_inds, existing_names)
7274

7375
new_column_names = []
7476
while conflict
75-
suffix = repeat("_", count)
77+
suffix = repeat("_", count)
7678
new_column_names = [Symbol("$(feat_name)$(suffix)$i") for i in 1:num_inds]
7779
conflict = any(name -> name in existing_names, new_column_names)
7880
count += 1
@@ -85,22 +87,29 @@ end
8587
"""
8688
**Private method.**
8789
88-
Given a table `X` and a dictionary `mapping_per_feat_level` which maps each level for each column in
90+
Given a table `X` and a dictionary `mapping_per_feat_level` which maps each level for each column in
8991
a subset of categorical features of X into a scalar or a vector (as specified in single_feat)
9092
91-
- transforms each value (some level) in each column in `X` using the function in `mapping_per_feat_level`
92-
into a scalar (single_feat=true)
93+
- transforms each value (some level) in each column in `X` using the function in `mapping_per_feat_level`
94+
into a scalar (single_feat=true)
9395
94-
- transforms each value (some level) in each column in `X` using the function in `mapping_per_feat_level`
95-
into a set of k features where k is the length of the vector (single_feat=false)
96+
- transforms each value (some level) in each column in `X` using the function in `mapping_per_feat_level`
97+
into a set of k features where k is the length of the vector (single_feat=false)
9698
- In both cases it attempts to preserve the type of the table.
9799
- In the latter case, it assumes that all levels under the same category are mapped to vectors of the same length. This
98-
assumption is necessary because any column in X must correspond to a constant number of features
100+
assumption is necessary because any column in X must correspond to a constant number of features
99101
in the output table (which is equal to k).
100102
- Features not in the dictionary are mapped to themselves (i.e., not changed).
101-
- Levels not in the nested dictionary are mapped to themselves if `identity_map_unknown` is true else raise an error.
103+
- Levels not in the nested dictionary are mapped to themselves if `ignore_unknown` is true; otherwise an error is raised.
104+
- If `ensure_categorical` is true, then any input categorical column will remain categorical
102105
"""
103-
function generic_transform(X, mapping_per_feat_level; single_feat = true, ignore_unknown = false)
106+
function generic_transform(
107+
X,
108+
mapping_per_feat_level;
109+
single_feat = true,
110+
ignore_unknown = false,
111+
ensure_categorical = false,
112+
)
104113
feat_names = Tables.schema(X).names
105114
new_feat_names = Symbol[]
106115
new_cols = []
@@ -115,18 +124,25 @@ function generic_transform(X, mapping_per_feat_level; single_feat = true, ignore
115124
if !issubset(test_levels, train_levels)
116125
# get the levels in test that are not in train
117126
lost_levels = setdiff(test_levels, train_levels)
118-
error("While transforming, found novel levels for the column $(feat_name): $(lost_levels) that were not seen while training.")
127+
error(
128+
"While transforming, found novel levels for the column $(feat_name): $(lost_levels) that were not seen while training.",
129+
)
119130
end
120131
end
121-
132+
122133
if single_feat
123134
level2scalar = mapping_per_feat_level[feat_name]
124-
new_col = !isempty(level2scalar) ? recode(col, level2scalar...) : col
135+
if ensure_categorical
136+
new_col = !isempty(level2scalar) ? recode(col, level2scalar...) : col
137+
else
138+
new_col = !isempty(level2scalar) ? unwrap.(recode(col, level2scalar...)) : col
139+
end
140+
125141
push!(new_cols, new_col)
126142
push!(new_feat_names, feat_name)
127143
else
128144
level2vector = mapping_per_feat_level[feat_name]
129-
new_multi_col = map(x->get(level2vector, x, x), col)
145+
new_multi_col = map(x -> get(level2vector, x, x), col)
130146
new_multi_col = [col for col in eachrow(hcat(new_multi_col...))]
131147
push!(new_cols, new_multi_col...)
132148

@@ -144,8 +160,8 @@ function generic_transform(X, mapping_per_feat_level; single_feat = true, ignore
144160
end
145161
end
146162

147-
transformed_X= NamedTuple{tuple(new_feat_names...)}(tuple(new_cols)...)
163+
transformed_X = NamedTuple{tuple(new_feat_names...)}(tuple(new_cols)...)
148164
# Attempt to preserve table type
149165
transformed_X = Tables.materializer(X)(transformed_X)
150166
return transformed_X
151-
end
167+
end

test/encoders/frequency_encoder.jl

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ using MLJTransforms: frequency_encoder_fit, frequency_encoder_transform
99
for norm in normalize
1010
result = frequency_encoder_fit(X; normalize = norm)[:statistic_given_feat_val]
1111
enc =
12-
(col, level) -> ((norm) ? sum(col .== level) / length(col) : sum(col .== level))
12+
(col, level) ->
13+
Float32((norm) ? sum(col .== level) / length(col) : sum(col .== level))
1314
true_output = Dict{Symbol, Dict{Any, Any}}(
1415
:F => Dict(
1516
"m" => enc(F_col, "m"),
@@ -44,7 +45,7 @@ end
4445
X_tr = frequency_encoder_transform(X, cache)
4546
enc =
4647
(col, level) ->
47-
((norm) ? sum(X[col] .== level) / length(X[col]) : sum(X[col] .== level))
48+
Float32((norm) ? sum(X[col] .== level) / length(X[col]) : sum(X[col] .== level))
4849

4950
target = (
5051
A = [enc(:A, X[:A][i]) for i in 1:10],
@@ -81,4 +82,42 @@ end
8182
# Test report
8283
@test report(mach) == (encoded_features = generic_cache[:encoded_features],)
8384
end
84-
end
85+
end
86+
87+
@testset "Test Frequency Encoding Output Types" begin
88+
# Define categorical features
89+
A = ["g", "b", "g", "r", "r"]
90+
B = [1.0, 2.0, 3.0, 4.0, 5.0]
91+
C = ["f", "f", "f", "m", "f"]
92+
D = [true, false, true, false, true]
93+
E = [1, 2, 3, 4, 5]
94+
95+
# Combine into a named tuple
96+
X = (A = A, B = B, C = C, D = D, E = E)
97+
98+
# Coerce A, C, D to Multiclass, B to Continuous, and E to OrderedFactor
99+
X = coerce(X,
100+
:A => Multiclass,
101+
:B => Continuous,
102+
:C => Multiclass,
103+
:D => Multiclass,
104+
:E => OrderedFactor,
105+
)
106+
107+
# Check scitype coercions:
108+
schema(X)
109+
110+
encoder = FrequencyEncoder(ordered_factor = false, normalize = false)
111+
mach = fit!(machine(encoder, X))
112+
Xnew = MMI.transform(mach, X)
113+
114+
115+
scs = schema(Xnew).scitypes
116+
ts = schema(Xnew).types
117+
# Check scitypes correctness
118+
@test all(scs[1:end-1] .== Continuous)
119+
@test all(t -> (t <: AbstractFloat) && isconcretetype(t), ts[1:end-1])
120+
# Ordinal column should be intact
121+
@test scs[end] === schema(X).scitypes[end]
122+
@test ts[end] == schema(X).types[end]
123+
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ using StatsModels
1515

1616
# Other transformers
1717
using Tables, CategoricalArrays
18-
using ScientificTypes: scitype
18+
using ScientificTypes: scitype, schema
1919
using Statistics
2020
using StableRNGs
2121
stable_rng = StableRNGs.StableRNG(123)

0 commit comments

Comments
 (0)