Skip to content

Commit 5358dce

Browse files
committed
✅ Cache is now a named tuple across all methods
1 parent 08d973f commit 5358dce

File tree

19 files changed

+414
-256
lines changed

19 files changed

+414
-256
lines changed

src/encoders/contrast_encoder/contrast_encoder.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -134,10 +134,9 @@ function contrast_encoder_fit(
134134
X, features; ignore = ignore, ordered_factor = ordered_factor,
135135
feature_mapper = feature_mapper,
136136
)
137-
138-
cache = Dict(
139-
:vector_given_value_given_feature => vector_given_value_given_feature,
140-
:encoded_features => encoded_features,
137+
cache = (
138+
vector_given_value_given_feature = vector_given_value_given_feature,
139+
encoded_features = encoded_features,
141140
)
142141

143142
return cache
@@ -157,7 +156,7 @@ Use a fitted contrast encoder to encode the levels of selected categorical varia
157156
158157
- `X_tr`: The table with selected features after the selected features are encoded by contrast encoding.
159158
"""
160-
function contrast_encoder_transform(X, cache::Dict)
161-
vector_given_value_given_feature = cache[:vector_given_value_given_feature]
159+
function contrast_encoder_transform(X, cache::NamedTuple)
160+
vector_given_value_given_feature = cache.vector_given_value_given_feature
162161
return generic_transform(X, vector_given_value_given_feature, single_feat = false; use_levelnames = true)
163162
end

src/encoders/contrast_encoder/interface_mlj.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,19 +36,18 @@ function MMI.fit(transformer::ContrastEncoder, verbosity::Int, X)
3636
buildmatrix = transformer.buildmatrix,
3737
ordered_factor = transformer.ordered_factor,
3838
)
39-
fitresult = generic_cache[:vector_given_value_given_feature]
39+
fitresult = generic_cache.vector_given_value_given_feature
4040

41-
report = (encoded_features = generic_cache[:encoded_features],) # report only has list of encoded features
41+
report = (encoded_features = generic_cache.encoded_features,) # report only has list of encoded features
4242
cache = nothing
4343
return fitresult, cache, report
4444
end;
4545

4646

4747
# 6. Transform method
4848
function MMI.transform(transformer::ContrastEncoder, fitresult, Xnew)
49-
generic_cache = Dict(
50-
:vector_given_value_given_feature =>
51-
fitresult,
49+
generic_cache = (
50+
vector_given_value_given_feature = fitresult,
5251
)
5352
Xnew_transf = contrast_encoder_transform(Xnew, generic_cache)
5453
return Xnew_transf

src/encoders/frequency_encoding/frequency_encoding.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,11 @@ function frequency_encoder_fit(
3939
# 2. Pass it to generic_fit
4040
statistic_given_feat_val, encoded_features = generic_fit(
4141
X, features; ignore = ignore, ordered_factor = ordered_factor,
42-
feature_mapper = feature_mapper,
43-
)
44-
cache = Dict(
45-
:statistic_given_feat_val => statistic_given_feat_val,
46-
:encoded_features => encoded_features,
42+
feature_mapper = feature_mapper,)
43+
44+
cache = (
45+
statistic_given_feat_val = statistic_given_feat_val,
46+
encoded_features = encoded_features,
4747
)
4848
return cache
4949
end
@@ -62,7 +62,7 @@ Encode the levels of a categorical variable in a given table with their (normali
6262
6363
- `X_tr`: The table with selected features after the selected features are encoded by frequency encoding.
6464
"""
65-
function frequency_encoder_transform(X, cache::Dict)
66-
statistic_given_feat_val = cache[:statistic_given_feat_val]
65+
function frequency_encoder_transform(X, cache::NamedTuple)
66+
statistic_given_feat_val = cache.statistic_given_feat_val
6767
return generic_transform(X, statistic_given_feat_val)
6868
end

src/encoders/frequency_encoding/interface_mlj.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,19 +36,18 @@ function MMI.fit(transformer::FrequencyEncoder, verbosity::Int, X)
3636
normalize = transformer.normalize,
3737
output_type = transformer.output_type,
3838
)
39-
fitresult = generic_cache[:statistic_given_feat_val]
39+
fitresult = generic_cache.statistic_given_feat_val
4040

41-
report = (encoded_features = generic_cache[:encoded_features],) # report only has list of encoded features
41+
report = (encoded_features = generic_cache.encoded_features,) # report only has list of encoded features
4242
cache = nothing
4343
return fitresult, cache, report
4444
end;
4545

4646

4747
# 6. Transform method
4848
function MMI.transform(transformer::FrequencyEncoder, fitresult, Xnew)
49-
generic_cache = Dict(
50-
:statistic_given_feat_val =>
51-
fitresult,
49+
generic_cache = (
50+
statistic_given_feat_val = fitresult,
5251
)
5352
Xnew_transf = frequency_encoder_transform(Xnew, generic_cache)
5453
return Xnew_transf

src/encoders/missingness_encoding/interface_mlj.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,19 +39,18 @@ function MMI.fit(transformer::MissingnessEncoder, verbosity::Int, X)
3939
ordered_factor = transformer.ordered_factor,
4040
label_for_missing = transformer.label_for_missing,
4141
)
42-
fitresult = generic_cache[:label_for_missing_given_feature]
42+
fitresult = generic_cache.label_for_missing_given_feature
4343

44-
report = (encoded_features = generic_cache[:encoded_features],) # report only has list of encoded features
44+
report = (encoded_features = generic_cache.encoded_features,) # report only has list of encoded features
4545
cache = nothing
4646
return fitresult, cache, report
4747
end;
4848

4949

5050
# 6. Transform method
5151
function MMI.transform(transformer::MissingnessEncoder, fitresult, Xnew)
52-
generic_cache = Dict(
53-
:label_for_missing_given_feature =>
54-
fitresult,
52+
generic_cache = (
53+
label_for_missing_given_feature = fitresult,
5554
)
5655
Xnew_transf = missingness_encoder_transform(Xnew, generic_cache)
5756
return Xnew_transf

src/encoders/missingness_encoding/missingness_encoding.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,9 @@ function missingness_encoder_fit(
9494
X, features; ignore = ignore, ordered_factor = ordered_factor,
9595
feature_mapper = feature_mapper,
9696
)
97-
cache = Dict(
98-
:label_for_missing_given_feature => label_for_missing_given_feature,
99-
:encoded_features => encoded_features,
97+
cache = (
98+
label_for_missing_given_feature = label_for_missing_given_feature,
99+
encoded_features = encoded_features,
100100
)
101101
return cache
102102
end
@@ -116,8 +116,8 @@ Apply a fitted missingness encoder to a table given the output of `missingness_e
116116
117117
- `X_tr`: The table with selected features after the selected features are transformed by missingness encoder
118118
"""
119-
function missingness_encoder_transform(X, cache::Dict)
120-
label_for_missing_given_feature = cache[:label_for_missing_given_feature]
119+
function missingness_encoder_transform(X, cache::NamedTuple)
120+
label_for_missing_given_feature = cache.label_for_missing_given_feature
121121
return generic_transform(
122122
X,
123123
label_for_missing_given_feature;

src/encoders/ordinal_encoding/interface_mlj.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,18 +34,16 @@ function MMI.fit(transformer::OrdinalEncoder, verbosity::Int, X)
3434
output_type = transformer.output_type,
3535
)
3636
fitresult =
37-
generic_cache[:index_given_feat_level]
38-
report = (encoded_features = generic_cache[:encoded_features],) # report only has list of encoded features
37+
generic_cache.index_given_feat_level
38+
report = (encoded_features = generic_cache.encoded_features,) # report only has list of encoded features
3939
cache = nothing
4040
return fitresult, cache, report
4141
end;
4242

4343

4444
# 6. Transform method
4545
function MMI.transform(transformer::OrdinalEncoder, fitresult, Xnew)
46-
generic_cache = Dict(
47-
:index_given_feat_level => fitresult,
48-
)
46+
generic_cache = (index_given_feat_level = fitresult,)
4947
Xnew_transf = ordinal_encoder_transform(Xnew, generic_cache)
5048
return Xnew_transf
5149
end

src/encoders/ordinal_encoding/ordinal_encoding.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ Fit an encoder to encode the levels of categorical variables in a given table as
1111
- `ignore=true`: Whether to exclude or includes the features given in `features`
1212
- `ordered_factor=false`: Whether to encode `OrderedFactor` or ignore them
1313
- `dtype`: The numerical concrete type of the encoded features. Default is `Float32`.
14+
1415
# Returns (in a dict)
1516
1617
- `index_given_feat_level`: Maps each level for each column in a subset of the categorical features of X into an integer.
@@ -27,18 +28,19 @@ function ordinal_encoder_fit(
2728
function feature_mapper(col, name)
2829
feat_levels = levels(col)
2930
index_given_feat_val =
30-
Dict{eltype(feat_levels), output_type}(value => index for (index, value) in enumerate(feat_levels))
31+
Dict{eltype(feat_levels), output_type}(
32+
value => index for (index, value) in enumerate(feat_levels)
33+
)
3134
return index_given_feat_val
3235
end
3336

3437
# 2. Pass it to generic_fit
3538
index_given_feat_level, encoded_features = generic_fit(
3639
X, features; ignore = ignore, ordered_factor = ordered_factor,
37-
feature_mapper = feature_mapper,
38-
)
39-
cache = Dict(
40-
:index_given_feat_level => index_given_feat_level,
41-
:encoded_features => encoded_features,
40+
feature_mapper = feature_mapper,)
41+
cache = (
42+
index_given_feat_level = index_given_feat_level,
43+
encoded_features = encoded_features,
4244
)
4345
return cache
4446
end
@@ -58,7 +60,7 @@ Encode the levels of a categorical variable in a given table as integers.
5860
5961
- `X_tr`: The table with selected features after the selected features are encoded by ordinal encoding.
6062
"""
61-
function ordinal_encoder_transform(X, cache::Dict)
62-
index_given_feat_level = cache[:index_given_feat_level]
63+
function ordinal_encoder_transform(X, cache::NamedTuple)
64+
index_given_feat_level = cache.index_given_feat_level
6365
return generic_transform(X, index_given_feat_level)
6466
end

src/encoders/target_encoding/interface_mlj.jl

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ struct TargetEncoderResult{
5252
task::S # "Regression", "Classification"
5353
num_classes::I # num_classes in case of classification
5454
y_classes::A # y_classes in case of classification
55-
55+
5656
end
5757

5858

@@ -75,25 +75,24 @@ function MMI.fit(transformer::TargetEncoder, verbosity::Int, X, y)
7575
m = transformer.m,
7676
)
7777
fitresult = TargetEncoderResult(
78-
generic_cache[:y_stat_given_feat_level],
79-
generic_cache[:task],
80-
generic_cache[:num_classes],
81-
generic_cache[:y_classes],
78+
generic_cache.y_stat_given_feat_level,
79+
generic_cache.task,
80+
generic_cache.num_classes,
81+
generic_cache.y_classes,
8282
)
83-
report = (encoded_features = generic_cache[:encoded_features],) # report only has list of encoded features
83+
report = (encoded_features = generic_cache.encoded_features,) # report only has list of encoded features
8484
cache = nothing
8585
return fitresult, cache, report
8686
end;
8787

8888

8989
# 7. Transform method
9090
function MMI.transform(transformer::TargetEncoder, fitresult, Xnew)
91-
generic_cache = Dict(
92-
:y_stat_given_feat_level =>
93-
fitresult.y_stat_given_feat_level,
94-
:num_classes => fitresult.num_classes,
95-
:task => fitresult.task,
96-
:y_classes => fitresult.y_classes,
91+
generic_cache = (
92+
y_stat_given_feat_level = fitresult.y_stat_given_feat_level,
93+
num_classes = fitresult.num_classes,
94+
task = fitresult.task,
95+
y_classes = fitresult.y_classes,
9796
)
9897
Xnew_transf = target_encoder_transform(Xnew, generic_cache)
9998
return Xnew_transf

src/encoders/target_encoding/target_encoding.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -211,12 +211,12 @@ function target_encoder_fit(
211211
feature_mapper = feature_mapper,
212212
)
213213

214-
cache = Dict(
215-
:task => task,
216-
:num_classes => (task == "Regression") ? -1 : length(y_classes),
217-
:y_stat_given_feat_level => y_stat_given_feat_level,
218-
:encoded_features => encoded_features,
219-
:y_classes => (task == "Regression") ? nothing : y_classes,
214+
cache = (
215+
task = task,
216+
num_classes = (task == "Regression") ? -1 : length(y_classes),
217+
y_stat_given_feat_level = y_stat_given_feat_level,
218+
encoded_features = encoded_features,
219+
y_classes = (task == "Regression") ? nothing : y_classes,
220220
)
221221
return cache
222222
end
@@ -242,16 +242,16 @@ every categorical feature as well as other metadata needed for transform
242242
"""
243243

244244
function target_encoder_transform(X, cache)
245-
task = cache[:task]
246-
y_stat_given_feat_level = cache[:y_stat_given_feat_level]
247-
num_classes = cache[:num_classes]
248-
y_classes = cache[:y_classes]
245+
task = cache.task
246+
y_stat_given_feat_level = cache.y_stat_given_feat_level
247+
num_classes = cache.num_classes
248+
y_classes = cache.y_classes
249249

250250
return generic_transform(
251251
X,
252252
y_stat_given_feat_level;
253253
single_feat = task == "Regression" || (task == "Classification" && num_classes < 3),
254254
use_levelnames = true,
255-
custom_levels = y_classes,)
255+
custom_levels = y_classes)
256256
end
257257

0 commit comments

Comments
 (0)