Skip to content

Commit 7f12234

Browse files
committed
✨ Add callable features and better error testing
1 parent 14a5671 commit 7f12234

File tree

6 files changed

+90
-32
lines changed

6 files changed

+90
-32
lines changed

src/generic.jl

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,22 @@ logic?"
1313
1414
# Arguments
1515
16-
- `X`: A table where the elements of the categorical features have [scitypes](https://juliaai.github.io/ScientificTypes.jl/dev/)
17-
`Multiclass` or `OrderedFactor`
18-
- `features=[]`: A list of names of categorical features given as symbols to exclude or include from encoding
19-
- `ignore=true`: Whether to exclude or includes the features given in `features`
20-
- `ordered_factor=false`: Whether to encode `OrderedFactor` or ignore them
21-
- `feature_mapper`: Defined above.
16+
- X: A table where the elements of the categorical features have [scitypes](https://juliaai.github.io/ScientificTypes.jl/dev/)
17+
Multiclass or OrderedFactor
18+
- features=[]: A list of names of categorical features given as symbols to exclude or include from encoding,
19+
or a callable that takes a feature name (a symbol) and returns true for the features to be excluded or included
20+
- ignore=true: Whether to exclude (true) or include (false) the features given in features
21+
- ordered_factor=false: Whether to encode OrderedFactor or ignore them
22+
- feature_mapper: Defined above.
2223
2324
# Returns
2425
25-
- `mapping_per_feat_level`: Maps each level for each feature in a subset of the categorical features of
26+
- mapping_per_feat_level: Maps each level for each feature in a subset of the categorical features of
2627
X into a scalar or a vector.
27-
- `encoded_features`: The subset of the categorical features of X that were encoded
28+
- encoded_features: The subset of the categorical features of X that were encoded
2829
"""
2930
function generic_fit(X,
30-
features::AbstractVector{Symbol} = Symbol[],
31+
features::Union{AbstractVector{Symbol}, Function} = Symbol[],
3132
args...;
3233
ignore::Bool = true,
3334
ordered_factor::Bool = false,
@@ -38,7 +39,17 @@ function generic_fit(X,
3839
feat_names = Tables.schema(X).names
3940

4041
# 2. Modify feat_names based on features
41-
feat_names = (ignore) ? setdiff(feat_names, features) : intersect(feat_names, features)
42+
if features isa Function
43+
# If features is a callable, apply it to each feature name
44+
if ignore
45+
feat_names = filter(name -> !features(name), feat_names)
46+
else
47+
feat_names = filter(features, feat_names)
48+
end
49+
else
50+
# Original behavior for vector of symbols
51+
feat_names = (ignore) ? setdiff(feat_names, features) : intersect(feat_names, features)
52+
end
4253

4354
# 3. Define mapping per column per level dictionary
4455
mapping_per_feat_level = Dict()

test/encoders/contrast_encoder.jl

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ age = [23, 23, 14, 23])
99

1010

1111
@testset "Contrast Encoder Error Handling" begin
12-
1312
# Example definitions to allow the test to run
1413
function dummy_buildmatrix(colname, k)
1514
# Simple dummy function to generate a matrix of correct size
@@ -23,21 +22,35 @@ age = [23, 23, 14, 23])
2322
)
2423

2524
# Test IGNORE_MUST_FALSE_VEC_MODE error
26-
@test_throws ArgumentError contrast_encoder_fit(data, [:A], mode=[:contrast], ignore=true)
25+
@test_throws MLJTransforms.IGNORE_MUST_FALSE_VEC_MODE begin
26+
contrast_encoder_fit(data, [:A], mode=[:contrast], ignore=true)
27+
end
2728

2829
# Test LENGTH_MISMATCH_VEC_MODE error
29-
@test_throws ArgumentError contrast_encoder_fit(data, [:A], mode=[:contrast, :dummy], buildmatrix=dummy_buildmatrix, ignore=false)
30+
@test_throws MLJTransforms.LENGTH_MISMATCH_VEC_MODE(2, 1) begin
31+
contrast_encoder_fit(data, [:A], mode=[:contrast, :dummy], buildmatrix=dummy_buildmatrix, ignore=false)
32+
end
3033

3134
# Test BUILDFUNC_MUST_BE_SPECIFIED error
32-
@test_throws ArgumentError contrast_encoder_fit(data, [:A], mode=:contrast, ignore=false)
35+
@test_throws MLJTransforms.BUILDFUNC_MUST_BE_SPECIFIED begin
36+
contrast_encoder_fit(data, [:A], mode=:contrast, ignore=false)
37+
end
3338

3439
# Test MATRIX_SIZE_ERROR
3540
wrong_buildmatrix = (levels, k) -> randn(k, k) # Incorrect dimensions
36-
@test_throws ArgumentError contrast_encoder_fit(data, [:A], mode=:contrast, buildmatrix=wrong_buildmatrix, ignore=false)
41+
k = 3 # Number of levels in data[:A]
42+
wrong_size = (k, k)
43+
@test_throws MLJTransforms.MATRIX_SIZE_ERROR(k, wrong_size, :A) begin
44+
contrast_encoder_fit(data, [:A], mode=:contrast, buildmatrix=wrong_buildmatrix, ignore=false)
45+
end
3746

3847
# Test MATRIX_SIZE_ERROR_HYP
3948
wrong_buildmatrix_hyp = (levels, k) -> randn(k, k+1) # Incorrect dimensions for hypothesis matrix
40-
@test_throws ArgumentError contrast_encoder_fit(data, [:A], mode=:hypothesis, buildmatrix=wrong_buildmatrix_hyp, ignore=false)
49+
wrong_size_hyp = (k, k+1)
50+
@test_throws MLJTransforms.MATRIX_SIZE_ERROR_HYP(k, wrong_size_hyp, :A) begin
51+
contrast_encoder_fit(data, [:A], mode=:hypothesis, buildmatrix=wrong_buildmatrix_hyp, ignore=false)
52+
end
53+
4154
end
4255

4356
@testset "Dummy Coding Tests" begin

test/encoders/missingness_encoding.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,26 @@
11
using MLJTransforms: missingness_encoder_fit, missingness_encoder_transform
22

3-
@testset "Throws errors when needed" begin
4-
@test_throws ArgumentError begin
3+
@testset "Missingness Encoder Error Handling" begin
4+
# Test COLLISION_NEW_VAL_ME error - when label_for_missing value already exists in levels
5+
@test_throws MLJTransforms.COLLISION_NEW_VAL_ME("missing") begin
56
X = generate_X_with_missingness(;john_name="missing")
67
cache = missingness_encoder_fit(
78
X;
89
label_for_missing = Dict(AbstractString => "missing", Char => 'm'),
910
)
1011
end
11-
@test_throws ArgumentError begin
12+
13+
# Test VALID_TYPES_NEW_VAL_ME error - when label_for_missing key is not a supported type
14+
@test_throws MLJTransforms.VALID_TYPES_NEW_VAL_ME(Bool) begin
1215
X = generate_X_with_missingness()
1316
cache = missingness_encoder_fit(
1417
X;
1518
label_for_missing = Dict(AbstractString => "Other", Bool => 'X'),
1619
)
1720
end
18-
@test_throws ArgumentError begin
21+
22+
# Test UNSPECIFIED_COL_TYPE_ME error - when column type isn't in label_for_missing
23+
@test_throws MLJTransforms.UNSPECIFIED_COL_TYPE_ME(Char, Dict(AbstractString => "X")) begin
1924
X = generate_X_with_missingness()
2025
cache = missingness_encoder_fit(
2126
X;

test/encoders/target_encoding.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -333,13 +333,15 @@ end
333333
@test fitresult.task == generic_cache[:task]
334334

335335
# Test invalid `m`
336-
@test_throws ArgumentError begin
337-
t = TargetEncoder(ignore = true, ordered_factor = false, lambda = 0.5, m = -5)
336+
invalid_m = -5
337+
@test_throws MLJTransforms.NON_NEGATIVE_m(invalid_m) begin
338+
t = TargetEncoder(ignore = true, ordered_factor = false, lambda = 0.5, m = invalid_m)
338339
end
339-
340-
# Test invalid `lambda`
341-
@test_throws ArgumentError begin
342-
t = TargetEncoder(ignore = true, ordered_factor = false, lambda = 1.1, m = 1)
340+
341+
# Test invalid `lambda` (value > 1)
342+
invalid_lambda = 1.1
343+
@test_throws MLJTransforms.INVALID_lambda(invalid_lambda) begin
344+
t = TargetEncoder(ignore = true, ordered_factor = false, lambda = invalid_lambda, m = 1)
343345
end
344346

345347
# Test report

test/generic.jl

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ end
4545
# Dummy encoder that maps each level to its hash (some arbitrary function)
4646
function dummy_encoder_fit(
4747
X,
48-
features::AbstractVector{Symbol} = Symbol[];
48+
features = Symbol[];
4949
ignore::Bool = true,
5050
ordered_factor::Bool = false,
5151
)
@@ -64,6 +64,7 @@ function dummy_encoder_fit(
6464
)
6565
cache = Dict(
6666
:hash_given_feat_val => hash_given_feat_val,
67+
:encoded => encoded_features,
6768
)
6869
return cache
6970
end
@@ -144,4 +145,24 @@ end
144145
F = [enc(:F, X[:F][i]) for i in 1:10]
145146
)
146147
@test X_tr == target
148+
end
149+
150+
@testset "Callable feature functionality tests" begin
151+
X = dataset_forms[1]
152+
feat_names = Tables.schema(X).names
153+
154+
# Define a predicate that is true only for the columns named :A, :C or :E
155+
predicate = name -> name in [:A, :C, :E]
156+
157+
# Test 1: ignore=true should exclude predicate columns
158+
cache1 = dummy_encoder_fit(X, predicate; ignore=true, ordered_factor=false)
159+
@test !(:A in cache1[:encoded]) && !(:C in cache1[:encoded]) && !(:E in cache1[:encoded])
160+
161+
# Test 2: ignore=false should include only predicate columns
162+
cache2 = dummy_encoder_fit(X, predicate; ignore=false, ordered_factor=false)
163+
@test Set(cache2[:encoded]) == Set([:A, :C])
164+
165+
# Test 3: predicate with ordered_factor=true picks up ordered factors (e.g., :E)
166+
cache3 = dummy_encoder_fit(X, predicate; ignore=false, ordered_factor=true)
167+
@test Set(cache3[:encoded]) == Set([:A, :C, :E])
147168
end

test/transformers/cardinality_reducer.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,36 @@
11
using MLJTransforms: cardinality_reducer_fit, cardinality_reducer_transform
22

33

4-
5-
@testset "Throws errors when needed" begin
6-
@test_throws ArgumentError begin
4+
@testset "Cardinality Reducer Error Handling" begin
5+
# Test COLLISION_NEW_VAL error - when label_for_infrequent value already exists in data
6+
@test_throws MLJTransforms.COLLISION_NEW_VAL('X') begin
77
X = generate_high_cardinality_table(1000; obj = false, special_cat = 'X')
88
cache = cardinality_reducer_fit(
99
X;
1010
label_for_infrequent = Dict(AbstractString => "Other", Char => 'X'),
1111
)
1212
end
13-
@test_throws ArgumentError begin
13+
14+
# Test VALID_TYPES_NEW_VAL error - when label_for_infrequent key is not a supported type
15+
@test_throws MLJTransforms.VALID_TYPES_NEW_VAL(Bool) begin
1416
X = generate_high_cardinality_table(1000; obj = false, special_cat = 'O')
1517
cache = cardinality_reducer_fit(
1618
X;
1719
label_for_infrequent = Dict(AbstractString => "Other", Bool => 'X'),
1820
)
1921
end
20-
@test_throws ArgumentError begin
22+
23+
# Test UNSPECIFIED_COL_TYPE error - when column type isn't in label_for_infrequent
24+
@test_throws MLJTransforms.UNSPECIFIED_COL_TYPE(Char, Dict(AbstractString => "X")) begin
2125
X = generate_high_cardinality_table(1000)
2226
cache = cardinality_reducer_fit(
2327
X;
2428
min_frequency = 30,
2529
label_for_infrequent = Dict(AbstractString => "X"),
30+
# Missing Char type in label_for_infrequent, which should be present in X
2631
)
2732
end
33+
2834
end
2935

3036

0 commit comments

Comments
 (0)