Skip to content

Commit e9a0c44

Browse files
committed
✨ Fix frequency encoder output types
1 parent f4ae7bd commit e9a0c44

File tree

3 files changed

+60
-4
lines changed

3 files changed

+60
-4
lines changed

src/encoders/ordinal_encoding/interface_mlj.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,17 @@ mutable struct OrdinalEncoder{AS <: AbstractVector{Symbol}} <: Unsupervised
55
features::AS
66
ignore::Bool
77
ordered_factor::Bool
8+
op_dtype::Type
89
end;
910

1011
# 2. Constructor
1112
function OrdinalEncoder(;
1213
features = Symbol[],
1314
ignore = true,
1415
ordered_factor = false,
16+
op_dtype = Float32,
1517
)
16-
return OrdinalEncoder(features, ignore, ordered_factor)
18+
return OrdinalEncoder(features, ignore, ordered_factor, op_dtype)
1719
end;
1820

1921

@@ -29,6 +31,7 @@ function MMI.fit(transformer::OrdinalEncoder, verbosity::Int, X)
2931
transformer.features;
3032
ignore = transformer.ignore,
3133
ordered_factor = transformer.ordered_factor,
34+
op_dtype = transformer.op_dtype,
3235
)
3336
fitresult =
3437
generic_cache[:index_given_feat_level]
@@ -92,6 +95,7 @@ Train the machine using `fit!(mach, rows=...)`.
9295
- `features=[]`: A list of names of categorical features given as symbols to exclude or include from encoding
9396
- `ignore=true`: Whether to exclude or includes the features given in `features`
9497
- `ordered_factor=false`: Whether to encode `OrderedFactor` or ignore them
98+
- `op_dtype`: The numerical concrete type of the encoded features. Default is `Float32`.
9599
96100
# Operations
97101

src/encoders/ordinal_encoding/ordinal_encoding.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ Fit an encoder to encode the levels of categorical variables in a given table as
1010
- `features=[]`: A list of names of categorical features given as symbols to exclude or include from encoding
1111
- `ignore=true`: Whether to exclude or includes the features given in `features`
1212
- `ordered_factor=false`: Whether to encode `OrderedFactor` or ignore them
13-
13+
- `dtype`: The numerical concrete type of the encoded features. Default is `Float32`.
1414
# Returns (in a dict)
1515
1616
- `index_given_feat_level`: Maps each level for each column in a subset of the categorical features of X into an integer.
@@ -21,12 +21,13 @@ function ordinal_encoder_fit(
2121
features::AbstractVector{Symbol} = Symbol[];
2222
ignore::Bool = true,
2323
ordered_factor::Bool = false,
24+
op_dtype::Type = Float32,
2425
)
2526
# 1. Define feature mapper
2627
function feature_mapper(col, name)
2728
feat_levels = levels(col)
2829
index_given_feat_val =
29-
Dict{Any, Integer}(value => index for (index, value) in enumerate(feat_levels))
30+
Dict{eltype(feat_levels), op_dtype}(value => index for (index, value) in enumerate(feat_levels))
3031
return index_given_feat_val
3132
end
3233

test/encoders/ordinal_encoding.jl

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,4 +82,55 @@ end
8282
# Test report
8383
@test report(mach) == (encoded_features = generic_cache[:encoded_features],)
8484
end
85-
end
85+
end
86+
87+
88+
@testset "Test Ordinal Encoding Types" begin
89+
# Define categorical features
90+
A = ["g", "b", "g", "r", "r"]
91+
B = [1.0, 2.0, 3.0, 4.0, 5.0]
92+
C = ["f", "f", "f", "m", "f"]
93+
D = [true, false, true, false, true]
94+
E = [1, 2, 3, 4, 5]
95+
96+
# Combine into a named tuple
97+
X = (A = A, B = B, C = C, D = D, E = E)
98+
99+
# Coerce A, C, D to multiclass and B to continuous and E to ordinal
100+
X = coerce(X,
101+
:A => Multiclass,
102+
:B => Multiclass,
103+
:C => Multiclass,
104+
:D => Continuous,
105+
:E => OrderedFactor,
106+
)
107+
108+
109+
encoder = OrdinalEncoder(ordered_factor = false)
110+
mach = fit!(machine(encoder, X))
111+
Xnew = MMI.transform(mach, X)
112+
113+
scs = schema(Xnew).scitypes
114+
ts = schema(Xnew).types
115+
# Check scitypes for previously continuos or categorical features
116+
@test all(scs[1:end-1] .== Continuous)
117+
@test all(t -> (t <: AbstractFloat) && isconcretetype(t), ts[1:end-1])
118+
# Check that for last column it did not changed
119+
scs[end] === schema(X).scitypes[end]
120+
scs[end]
121+
schema(X).scitypes[end]
122+
123+
## Int32 case
124+
encoder = OrdinalEncoder(ordered_factor = false, op_dtype = Int32)
125+
mach = fit!(machine(encoder, X))
126+
Xnew = MMI.transform(mach, X)
127+
scs = schema(Xnew).scitypes
128+
ts = schema(Xnew).types
129+
# Check scitypes for previously categorical features
130+
@test all(scs[1:end-2] .== Count)
131+
@test all(t -> (t <: Integer) && isconcretetype(t), ts[1:end-2])
132+
# Check rest of the types
133+
scs[end-1:end]
134+
@test scs[end-1:end] == schema(X).scitypes[end-1:end]
135+
@test ts[end-1:end] == schema(X).types[end-1:end]
136+
end

0 commit comments

Comments
 (0)