Skip to content

Commit 5808f8e

Browse files
committed
✨ Fix missingness encoder output types
1 parent 6c4589b commit 5808f8e

File tree

2 files changed

+45
-8
lines changed

2 files changed

+45
-8
lines changed

src/encoders/missingness_encoding/missingness_encoding.jl

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ function missingness_encoder_fit(
3030
features::AbstractVector{Symbol} = Symbol[];
3131
ignore::Bool = true,
3232
ordered_factor::Bool = false,
33-
label_for_missing::Dict{<:Type, <:Any} = Dict(
33+
label_for_missing::Dict{<:Type, <:Any} = Dict(
3434
AbstractString => "missing",
3535
Char => 'm',
3636
),
@@ -40,8 +40,8 @@ function missingness_encoder_fit(
4040

4141
# 1. Define feature mapper
4242
function feature_mapper(col, name)
43-
col_type = nonmissingtype(eltype(col)).parameters[1]
44-
feat_levels = levels(col; skipmissing=true)
43+
feat_levels = levels(col; skipmissing = true)
44+
col_type = nonmissingtype(eltype(feat_levels))
4545

4646
# Ensure column type is valid (can't test because never occurs)
4747
# Converting array elements to strings before wrapping in a `CategoricalArray`, as...
@@ -58,7 +58,7 @@ function missingness_encoder_fit(
5858

5959
# Check no collision between keys(label_for_missing) and feat_levels
6060
for value in values(label_for_missing)
61-
if !ismissing(value)
61+
if !ismissing(value)
6262
if value in feat_levels
6363
throw(ArgumentError(COLLISION_NEW_VAL_ME(value)))
6464
end
@@ -73,7 +73,7 @@ function missingness_encoder_fit(
7373
break
7474
end
7575
end
76-
76+
7777
# Nonmissing levels remain as is
7878
label_for_missing_given_feature = Dict{Missing, col_type}()
7979

@@ -91,7 +91,8 @@ function missingness_encoder_fit(
9191

9292
# 2. Pass it to generic_fit
9393
label_for_missing_given_feature, encoded_features = generic_fit(
94-
X, features; ignore = ignore, ordered_factor = ordered_factor, feature_mapper = feature_mapper,
94+
X, features; ignore = ignore, ordered_factor = ordered_factor,
95+
feature_mapper = feature_mapper,
9596
)
9697
cache = Dict(
9798
:label_for_missing_given_feature => label_for_missing_given_feature,
@@ -117,6 +118,11 @@ Apply a fitted missingness encoder to a table given the output of `missingness_e
117118
"""
118119
function missingness_encoder_transform(X, cache::Dict)
119120
label_for_missing_given_feature = cache[:label_for_missing_given_feature]
120-
return generic_transform(X, label_for_missing_given_feature; ignore_unknown = true)
121+
return generic_transform(
122+
X,
123+
label_for_missing_given_feature;
124+
ignore_unknown = true,
125+
ensure_categorical = true,
126+
)
121127
end
122128

test/encoders/missingness_encoding.jl

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,4 +153,35 @@ end
153153

154154
# Test report
155155
@test report(mach) == (encoded_features = generic_cache[:encoded_features],)
156-
end
156+
end
157+
158+
159+
160+
@testset "Test Missingness Encoder Output Types" begin
161+
# Define a table with missing values
162+
Xm = (
163+
A = categorical(["Ben", "John", missing, missing, "Mary", "John", missing]),
164+
B = [1.85, 1.67, missing, missing, 1.5, 1.67, missing],
165+
C = categorical([7, 5, missing, missing, 10, 0, missing]),
166+
D = categorical([23, 23, 44, 66, 14, 23, missing], ordered = true),
167+
E = categorical([missing, 'g', 'r', missing, 'r', 'g', 'p']),
168+
)
169+
170+
encoder = MissingnessEncoder()
171+
mach = fit!(machine(encoder, Xm))
172+
Xnew = MMI.transform(mach, Xm)
173+
174+
schema(Xm)
175+
schema(Xnew)
176+
Xnew.B
177+
178+
scs = schema(Xnew).scitypes
179+
for (i, type) in enumerate(schema(Xm).scitypes)
180+
print(nonmissingtype(type))
181+
if nonmissingtype(type) <: Multiclass
182+
@test scs[i] <: Multiclass
183+
else
184+
scs[i] == type
185+
end
186+
end
187+
end

0 commit comments

Comments
 (0)