Skip to content

Commit 2f7bebb

Browse files
committed
✨ Contrast encoding should use level names as it generates columns
1 parent dd22fe5 commit 2f7bebb

File tree

3 files changed

+8
-8
lines changed

3 files changed

+8
-8
lines changed

src/MLJTransforms.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using MLJModelInterface
77
using TableOperations
88
using StatsBase
99
using LinearAlgebra
10-
10+
using OrderedCollections: OrderedDict
1111
# Other transformers
1212
using Combinatorics
1313
import Distributions

src/encoders/contrast_encoder/contrast_encoder.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ function contrast_encoder_fit(
125125
throw(ArgumentError("Mode $feat_mode is not supported."))
126126
end
127127

128-
vector_given_value_given_feature = Dict(level=>contrastmatrix[l, :] for (l, level) in enumerate(feat_levels))
128+
vector_given_value_given_feature = OrderedDict(level=>contrastmatrix[l, :] for (l, level) in enumerate(feat_levels))
129129
return vector_given_value_given_feature
130130
end
131131

@@ -159,5 +159,5 @@ Use a fitted contrast encoder to encode the levels of selected categorical varia
159159
"""
160160
function contrast_encoder_transform(X, cache::Dict)
161161
vector_given_value_given_feature = cache[:vector_given_value_given_feature]
162-
return generic_transform(X, vector_given_value_given_feature, single_feat = false)
162+
return generic_transform(X, vector_given_value_given_feature, single_feat = false; use_levelnames = true)
163163
end

src/encoders/contrast_encoder/interface_mlj.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -148,12 +148,12 @@ mach = fit!(machine(encoder, X))
148148
Xnew = transform(mach, X)
149149
150150
julia> Xnew
151-
(name_1 = [1.0, 0.0, 0.0, 0.0],
152-
name_2 = [0.0, 1.0, 0.0, 1.0],
151+
(name_John = [1.0, 0.0, 0.0, 0.0],
152+
name_Mary = [0.0, 1.0, 0.0, 1.0],
153153
height = [1.85, 1.67, 1.5, 1.67],
154-
favnum_1 = [0.0, 1.0, 0.0, -1.0],
155-
favnum_2 = [2.0, -1.0, 0.0, -1.0],
156-
favnum_3 = [-1.0, -1.0, 3.0, -1.0],
154+
favnum_5 = [0.0, 1.0, 0.0, -1.0],
155+
favnum_7 = [2.0, -1.0, 0.0, -1.0],
156+
favnum_10 = [-1.0, -1.0, 3.0, -1.0],
157157
age = [23, 23, 14, 23],)
158158
```
159159

0 commit comments

Comments
 (0)