Skip to content

Commit 6c4589b

Browse files
committed
✨ Fix cardinality reducer output types
1 parent e9a0c44 commit 6c4589b

File tree

2 files changed

+37
-7
lines changed

2 files changed

+37
-7
lines changed

src/transformers/cardinality_reducer/cardinality_reducer.jl

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,20 +35,20 @@ function cardinality_reducer_fit(
3535
features::AbstractVector{Symbol} = Symbol[];
3636
ignore::Bool = true,
3737
ordered_factor::Bool = false,
38-
min_frequency::Real = 3,
39-
label_for_infrequent::Dict{<:Type, <:Any} = Dict(
38+
min_frequency::Real = 3,
39+
label_for_infrequent::Dict{<:Type, <:Any} = Dict(
4040
AbstractString => "Other",
4141
Char => 'O',
4242
),
43-
)
43+
)
4444
supportedtypes_list = [Char, AbstractString, Number]
4545
supportedtypes = Union{supportedtypes_list...}
4646

4747
# 1. Define feature mapper
4848
function feature_mapper(col, name)
4949
val_to_freq = (min_frequency isa AbstractFloat) ? proportionmap(col) : countmap(col)
50-
col_type = eltype(col).parameters[1]
5150
feat_levels = levels(col)
51+
col_type = eltype(feat_levels)
5252

5353
# Ensure column type is valid (can't test because never occurs)
5454
# Converting array elements to strings before wrapping in a `CategoricalArray`, as...
@@ -88,7 +88,11 @@ function cardinality_reducer_fit(
8888
elseif elgrandtype == Number
8989
new_cat_given_col_val[level] = minimum(feat_levels) - 1
9090
else
91-
throw(ArgumentError(UNSPECIFIED_COL_TYPE(col_type, label_for_infrequent)))
91+
throw(
92+
ArgumentError(
93+
UNSPECIFIED_COL_TYPE(col_type, label_for_infrequent),
94+
),
95+
)
9296
end
9397
end
9498
end
@@ -98,7 +102,8 @@ function cardinality_reducer_fit(
98102

99103
# 2. Pass it to generic_fit
100104
new_cat_given_col_val, encoded_features = generic_fit(
101-
X, features; ignore = ignore, ordered_factor = ordered_factor, feature_mapper = feature_mapper,
105+
X, features; ignore = ignore, ordered_factor = ordered_factor,
106+
feature_mapper = feature_mapper,
102107
)
103108
cache = Dict(
104109
:new_cat_given_col_val => new_cat_given_col_val,
@@ -125,5 +130,5 @@ Apply a fitted cardinality reducer to a table given the output of `cardinality_r
125130
"""
126131
function cardinality_reducer_transform(X, cache::Dict)
127132
new_cat_given_col_val = cache[:new_cat_given_col_val]
128-
return generic_transform(X, new_cat_given_col_val; ignore_unknown = true)
133+
return generic_transform(X, new_cat_given_col_val; ignore_unknown = true, ensure_categorical = true)
129134
end

test/transformers/cardinality_reducer.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,5 +190,30 @@ end
190190
@test report(mach) == (encoded_features = generic_cache[:encoded_features],)
191191
end
192192

193+
194+
@testset "Test Cardinality Reducer Output Types" begin
195+
# Define categorical features
196+
A = [["a" for i in 1:100]..., "b", "b", "b", "c", "d"]
197+
B = [[0 for i in 1:100]..., 1, 2, 3, 4, 4]
198+
199+
# Combine into a named tuple
200+
X = (A = A, B = B)
201+
202+
# Coerce both A and B to multiclass
203+
X = coerce(X,
204+
:A => Multiclass,
205+
:B => Multiclass,
206+
)
207+
208+
levels(X.A)
209+
210+
encoder = CardinalityReducer(ordered_factor = false, min_frequency = 3)
211+
mach = fit!(machine(encoder, X))
212+
Xnew = MMI.transform(mach, X)
213+
@test schema(X).types == schema(Xnew).types
214+
@test all(s -> (s <: Multiclass), schema(Xnew).scitypes)
215+
end
216+
217+
193218
# Look into MLJModelInterfaceTest
194219
# Add tests to ensure categorical feature properties are as expected

0 commit comments

Comments
 (0)