Skip to content

Commit dd22fe5

Browse files
committed
✨ Class names to be used as level names for target encoding
1 parent a815b76 commit dd22fe5

File tree

3 files changed

+20
-14
lines changed

3 files changed

+20
-14
lines changed

src/encoders/target_encoding/interface_mlj.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ struct TargetEncoderResult{
5151
y_stat_given_feat_level::Dict{A, A}
5252
task::S # "Regression", "Classification"
5353
num_classes::I # num_classes in case of classification
54+
y_classes::A # y_classes in case of classification
55+
5456
end
5557

5658

@@ -76,6 +78,7 @@ function MMI.fit(transformer::TargetEncoder, verbosity::Int, X, y)
7678
generic_cache[:y_stat_given_feat_level],
7779
generic_cache[:task],
7880
generic_cache[:num_classes],
81+
generic_cache[:y_classes],
7982
)
8083
report = (encoded_features = generic_cache[:encoded_features],) # report only has list of encoded features
8184
cache = nothing
@@ -90,6 +93,7 @@ function MMI.transform(transformer::TargetEncoder, fitresult, Xnew)
9093
fitresult.y_stat_given_feat_level,
9194
:num_classes => fitresult.num_classes,
9295
:task => fitresult.task,
96+
:y_classes => fitresult.y_classes,
9397
)
9498
Xnew_transf = target_encoder_transform(Xnew, generic_cache)
9599
return Xnew_transf

src/encoders/target_encoding/target_encoding.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ function target_encoder_fit(
215215
:num_classes => (task == "Regression") ? -1 : length(y_classes),
216216
:y_stat_given_feat_level => y_stat_given_feat_level,
217217
:encoded_features => encoded_features,
218+
:y_classes => (task == "Regression") ? nothing : y_classes,
218219
)
219220
return cache
220221
end
@@ -243,11 +244,13 @@ function target_encoder_transform(X, cache)
243244
task = cache[:task]
244245
y_stat_given_feat_level = cache[:y_stat_given_feat_level]
245246
num_classes = cache[:num_classes]
247+
y_classes = cache[:y_classes]
246248

247249
return generic_transform(
248250
X,
249251
y_stat_given_feat_level;
250252
single_feat = task == "Regression" || (task == "Classification" && num_classes < 3),
251-
)
253+
use_levelnames = true,
254+
custom_levels = y_classes,)
252255
end
253256

test/encoders/target_encoding.jl

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -277,22 +277,21 @@ end
277277
X_tr = target_encoder_transform(X, cache)
278278

279279
enc = (col, level) -> cache[:y_stat_given_feat_level][col][level]
280-
281280
target = (
282-
A_1 = [enc(:A, X[:A][i])[1] for i in 1:10],
283-
A_2 = [enc(:A, X[:A][i])[2] for i in 1:10],
284-
A_3 = [enc(:A, X[:A][i])[3] for i in 1:10],
281+
A_0 = [enc(:A, X[:A][i])[1] for i in 1:10],
282+
A_1 = [enc(:A, X[:A][i])[2] for i in 1:10],
283+
A_2 = [enc(:A, X[:A][i])[3] for i in 1:10],
285284
B = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
286-
C_1 = [enc(:C, X[:C][i])[1] for i in 1:10],
287-
C_2 = [enc(:C, X[:C][i])[2] for i in 1:10],
288-
C_3 = [enc(:C, X[:C][i])[3] for i in 1:10],
289-
D_1 = [enc(:D, X[:D][i])[1] for i in 1:10],
290-
D_2 = [enc(:D, X[:D][i])[2] for i in 1:10],
291-
D_3 = [enc(:D, X[:D][i])[3] for i in 1:10],
285+
C_0 = [enc(:C, X[:C][i])[1] for i in 1:10],
286+
C_1 = [enc(:C, X[:C][i])[2] for i in 1:10],
287+
C_2 = [enc(:C, X[:C][i])[3] for i in 1:10],
288+
D_0 = [enc(:D, X[:D][i])[1] for i in 1:10],
289+
D_1 = [enc(:D, X[:D][i])[2] for i in 1:10],
290+
D_2 = [enc(:D, X[:D][i])[3] for i in 1:10],
292291
E = [1, 2, 3, 4, 5, 6, 6, 3, 2, 1],
293-
F_1 = [enc(:F, X[:F][i])[1] for i in 1:10],
294-
F_2 = [enc(:F, X[:F][i])[2] for i in 1:10],
295-
F_3 = [enc(:F, X[:F][i])[3] for i in 1:10],
292+
F_0 = [enc(:F, X[:F][i])[1] for i in 1:10],
293+
F_1 = [enc(:F, X[:F][i])[2] for i in 1:10],
294+
F_2 = [enc(:F, X[:F][i])[3] for i in 1:10],
296295
)
297296
for col in keys(target)
298297
@test all(X_tr[col] .== target[col])

0 commit comments

Comments
 (0)