@@ -35,20 +35,20 @@ function cardinality_reducer_fit(
35
35
features:: AbstractVector{Symbol} = Symbol[];
36
36
ignore:: Bool = true ,
37
37
ordered_factor:: Bool = false ,
38
- min_frequency:: Real = 3 ,
39
- label_for_infrequent:: Dict{<:Type, <:Any} = Dict (
38
+ min_frequency:: Real = 3 ,
39
+ label_for_infrequent:: Dict{<:Type, <:Any} = Dict (
40
40
AbstractString => " Other" ,
41
41
Char => ' O' ,
42
42
),
43
- )
43
+ )
44
44
supportedtypes_list = [Char, AbstractString, Number]
45
45
supportedtypes = Union{supportedtypes_list... }
46
46
47
47
# 1. Define feature mapper
48
48
function feature_mapper (col, name)
49
49
val_to_freq = (min_frequency isa AbstractFloat) ? proportionmap (col) : countmap (col)
50
- col_type = eltype (col). parameters[1 ]
51
50
feat_levels = levels (col)
51
+ col_type = eltype (feat_levels)
52
52
53
53
# Ensure column type is valid (can't test because never occurs)
54
54
# Converting array elements to strings before wrapping in a `CategoricalArray`, as...
@@ -88,7 +88,11 @@ function cardinality_reducer_fit(
88
88
elseif elgrandtype == Number
89
89
new_cat_given_col_val[level] = minimum (feat_levels) - 1
90
90
else
91
- throw (ArgumentError (UNSPECIFIED_COL_TYPE (col_type, label_for_infrequent)))
91
+ throw (
92
+ ArgumentError (
93
+ UNSPECIFIED_COL_TYPE (col_type, label_for_infrequent),
94
+ ),
95
+ )
92
96
end
93
97
end
94
98
end
@@ -98,7 +102,8 @@ function cardinality_reducer_fit(
98
102
99
103
# 2. Pass it to generic_fit
100
104
new_cat_given_col_val, encoded_features = generic_fit (
101
- X, features; ignore = ignore, ordered_factor = ordered_factor, feature_mapper = feature_mapper,
105
+ X, features; ignore = ignore, ordered_factor = ordered_factor,
106
+ feature_mapper = feature_mapper,
102
107
)
103
108
cache = Dict (
104
109
:new_cat_given_col_val => new_cat_given_col_val,
@@ -125,5 +130,5 @@ Apply a fitted cardinality reducer to a table given the output of `cardinality_r
125
130
"""
126
131
function cardinality_reducer_transform (X, cache:: Dict )
127
132
new_cat_given_col_val = cache[:new_cat_given_col_val ]
128
- return generic_transform (X, new_cat_given_col_val; ignore_unknown = true )
133
+ return generic_transform (X, new_cat_given_col_val; ignore_unknown = true , ensure_categorical = true )
129
134
end
0 commit comments