Skip to content

Commit f4ae7bd

Browse files
committed
✨ Fix target encoder output types
1 parent 7b577d7 commit f4ae7bd

File tree

2 files changed

+43
-2
lines changed

2 files changed

+43
-2
lines changed

src/encoders/target_encoding/target_encoding.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,9 @@ function target_encoder_fit(
166166

167167
# 3. Define function to compute the new value(s) for each level given a column
168168
function feature_mapper(col, name)
169+
feat_levels = levels(col)
169170
y_stat_given_feat_level_for_col =
170-
Dict{Any, Union{AbstractFloat, AbstractVector{<:AbstractFloat}}}()
171+
Dict{eltype(feat_levels), Any}()
171172
for level in levels(col)
172173
# Get the targets of an example that belong to this level
173174
targets_for_level = y[col.==level]

test/encoders/target_encoding.jl

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,4 +345,44 @@ end
345345
# Test report
346346
@test report(mach) == (encoded_features = generic_cache[:encoded_features],)
347347
end
348-
end
348+
end
349+
350+
351+
352+
@testset "Test Target Encoding Types" begin
353+
# Define categorical features
354+
A = ["g", "b", "g", "r", "r"]
355+
B = [1.0, 2.0, 3.0, 4.0, 5.0]
356+
C = ["f", "f", "f", "m", "f"]
357+
D = [true, false, true, false, true]
358+
E = [1, 2, 3, 4, 5]
359+
360+
# Define the target variable
361+
y = ["c1", "c2", "c3", "c1", "c2"]
362+
363+
# Combine into a named tuple
364+
X = (A = A, B = B, C = C, D = D, E = E)
365+
366+
# Coerce A, C, D to multiclass and B to continuous and E to ordinal
367+
X = coerce(X,
368+
:A => Multiclass,
369+
:B => Continuous,
370+
:C => Multiclass,
371+
:D => Multiclass,
372+
:E => OrderedFactor,
373+
)
374+
y = coerce(y, Multiclass)
375+
376+
encoder = TargetEncoder(ordered_factor = false, lambda = 1.0, m = 0)
377+
mach = fit!(machine(encoder, X, y))
378+
Xnew = MMI.transform(mach, X)
379+
380+
scs = schema(Xnew).scitypes
381+
ts = schema(Xnew).types
382+
# Check scitypes for previously continuos or categorical features
383+
@test all(scs[1:end-1] .== Continuous)
384+
@test all(t -> (t <: AbstractFloat) && isconcretetype(t), ts[1:end-1])
385+
@test scs[end] === schema(X).scitypes[end]
386+
@test ts[end] == schema(X).types[end]
387+
end
388+

0 commit comments

Comments
 (0)