Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 21 additions & 10 deletions src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,22 @@ logic?"

# Arguments

- `X`: A table where the elements of the categorical features have [scitypes](https://juliaai.github.io/ScientificTypes.jl/dev/)
`Multiclass` or `OrderedFactor`
- `features=[]`: A list of names of categorical features given as symbols to exclude or include from encoding
- `ignore=true`: Whether to exclude or include the features given in `features`
- `ordered_factor=false`: Whether to encode `OrderedFactor` or ignore them
- `feature_mapper`: Defined above.
- X: A table where the elements of the categorical features have [scitypes](https://juliaai.github.io/ScientificTypes.jl/dev/)
Multiclass or OrderedFactor
- features=[]: A list of names of categorical features given as symbols to exclude or include from encoding,
or a callable that returns true for features to be included/excluded
- ignore=true: Whether to exclude or include the features given in features
- ordered_factor=false: Whether to encode OrderedFactor or ignore them
- feature_mapper: Defined above.

# Returns

- `mapping_per_feat_level`: Maps each level for each feature in a subset of the categorical features of
- mapping_per_feat_level: Maps each level for each feature in a subset of the categorical features of
X into a scalar or a vector.
- `encoded_features`: The subset of the categorical features of X that were encoded
- encoded_features: The subset of the categorical features of X that were encoded
"""
function generic_fit(X,
features::AbstractVector{Symbol} = Symbol[],
features::Union{AbstractVector{Symbol}, Function} = Symbol[],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Callables do not need to be functions. You can make instances of any new struct a callable. Unless you need it for type dispatch, there is no need to annotate a type here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 194c53e

args...;
ignore::Bool = true,
ordered_factor::Bool = false,
Expand All @@ -38,7 +39,17 @@ function generic_fit(X,
feat_names = Tables.schema(X).names

#2. Modify column_names based on features
feat_names = (ignore) ? setdiff(feat_names, features) : intersect(feat_names, features)
if features isa Function
# If features is a callable, apply it to each feature name
if ignore
feat_names = filter(name -> !features(name), feat_names)
else
feat_names = filter(features, feat_names)
end
else
# Original behavior for vector of symbols
feat_names = (ignore) ? setdiff(feat_names, features) : intersect(feat_names, features)
end

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In Julia it is unfortunately difficult to recognise callability of an object (at least last time I researched this). So, reverse your logic here: if feature names is a vector of symbols then do X, otherwise do Y.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here's an example of a callable that is not a function:

struct Foo end
(::Foo)(x) = 2x

f = Foo()
f(4) # 8

f isa Function # false

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 5e0af90

# 3. Define mapping per column per level dictionary
mapping_per_feat_level = Dict()
Expand Down
25 changes: 19 additions & 6 deletions test/encoders/contrast_encoder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ age = [23, 23, 14, 23])


@testset "Contrast Encoder Error Handling" begin

# Example definitions to allow the test to run
function dummy_buildmatrix(colname, k)
# Simple dummy function to generate a matrix of correct size
Expand All @@ -23,21 +22,35 @@ age = [23, 23, 14, 23])
)

# Test IGNORE_MUST_FALSE_VEC_MODE error
@test_throws ArgumentError contrast_encoder_fit(data, [:A], mode=[:contrast], ignore=true)
@test_throws MLJTransforms.IGNORE_MUST_FALSE_VEC_MODE begin
contrast_encoder_fit(data, [:A], mode=[:contrast], ignore=true)
end

# Test LENGTH_MISMATCH_VEC_MODE error
@test_throws ArgumentError contrast_encoder_fit(data, [:A], mode=[:contrast, :dummy], buildmatrix=dummy_buildmatrix, ignore=false)
@test_throws MLJTransforms.LENGTH_MISMATCH_VEC_MODE(2, 1) begin
contrast_encoder_fit(data, [:A], mode=[:contrast, :dummy], buildmatrix=dummy_buildmatrix, ignore=false)
end

# Test BUILDFUNC_MUST_BE_SPECIFIED error
@test_throws ArgumentError contrast_encoder_fit(data, [:A], mode=:contrast, ignore=false)
@test_throws MLJTransforms.BUILDFUNC_MUST_BE_SPECIFIED begin
contrast_encoder_fit(data, [:A], mode=:contrast, ignore=false)
end

# Test MATRIX_SIZE_ERROR
wrong_buildmatrix = (levels, k) -> randn(k, k) # Incorrect dimensions
@test_throws ArgumentError contrast_encoder_fit(data, [:A], mode=:contrast, buildmatrix=wrong_buildmatrix, ignore=false)
k = 3 # Number of levels in data[:A]
wrong_size = (k, k)
@test_throws MLJTransforms.MATRIX_SIZE_ERROR(k, wrong_size, :A) begin
contrast_encoder_fit(data, [:A], mode=:contrast, buildmatrix=wrong_buildmatrix, ignore=false)
end

# Test MATRIX_SIZE_ERROR_HYP
wrong_buildmatrix_hyp = (levels, k) -> randn(k, k+1) # Incorrect dimensions for hypothesis matrix
@test_throws ArgumentError contrast_encoder_fit(data, [:A], mode=:hypothesis, buildmatrix=wrong_buildmatrix_hyp, ignore=false)
wrong_size_hyp = (k, k+1)
@test_throws MLJTransforms.MATRIX_SIZE_ERROR_HYP(k, wrong_size_hyp, :A) begin
contrast_encoder_fit(data, [:A], mode=:hypothesis, buildmatrix=wrong_buildmatrix_hyp, ignore=false)
end

end

@testset "Dummy Coding Tests" begin
Expand Down
13 changes: 9 additions & 4 deletions test/encoders/missingness_encoding.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@
using MLJTransforms: missingness_encoder_fit, missingness_encoder_transform

@testset "Throws errors when needed" begin
@test_throws ArgumentError begin
@testset "Missingness Encoder Error Handling" begin
# Test COLLISION_NEW_VAL_ME error - when label_for_missing value already exists in levels
@test_throws MLJTransforms.COLLISION_NEW_VAL_ME("missing") begin
X = generate_X_with_missingness(;john_name="missing")
cache = missingness_encoder_fit(
X;
label_for_missing = Dict(AbstractString => "missing", Char => 'm'),
)
end
@test_throws ArgumentError begin

# Test VALID_TYPES_NEW_VAL_ME error - when label_for_missing key is not a supported type
@test_throws MLJTransforms.VALID_TYPES_NEW_VAL_ME(Bool) begin
X = generate_X_with_missingness()
cache = missingness_encoder_fit(
X;
label_for_missing = Dict(AbstractString => "Other", Bool => 'X'),
)
end
@test_throws ArgumentError begin

# Test UNSPECIFIED_COL_TYPE_ME error - when column type isn't in label_for_missing
@test_throws MLJTransforms.UNSPECIFIED_COL_TYPE_ME(Char, Dict(AbstractString => "X")) begin
X = generate_X_with_missingness()
cache = missingness_encoder_fit(
X;
Expand Down
14 changes: 8 additions & 6 deletions test/encoders/target_encoding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -333,13 +333,15 @@ end
@test fitresult.task == generic_cache[:task]

# Test invalid `m`
@test_throws ArgumentError begin
t = TargetEncoder(ignore = true, ordered_factor = false, lambda = 0.5, m = -5)
invalid_m = -5
@test_throws MLJTransforms.NON_NEGATIVE_m(invalid_m) begin
t = TargetEncoder(ignore = true, ordered_factor = false, lambda = 0.5, m = invalid_m)
end

# Test invalid `lambda`
@test_throws ArgumentError begin
t = TargetEncoder(ignore = true, ordered_factor = false, lambda = 1.1, m = 1)

# Test invalid `lambda` (value > 1)
invalid_lambda = 1.1
@test_throws MLJTransforms.INVALID_lambda(invalid_lambda) begin
t = TargetEncoder(ignore = true, ordered_factor = false, lambda = invalid_lambda, m = 1)
end

# Test report
Expand Down
23 changes: 22 additions & 1 deletion test/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ end
# Dummy encoder that maps each level to its hash (some arbitrary function)
function dummy_encoder_fit(
X,
features::AbstractVector{Symbol} = Symbol[];
features = Symbol[];
ignore::Bool = true,
ordered_factor::Bool = false,
)
Expand All @@ -64,6 +64,7 @@ function dummy_encoder_fit(
)
cache = Dict(
:hash_given_feat_val => hash_given_feat_val,
:encoded => encoded_features,
)
return cache
end
Expand Down Expand Up @@ -144,4 +145,24 @@ end
F = [enc(:F, X[:F][i]) for i in 1:10]
)
@test X_tr == target
end

@testset "Callable feature functionality tests" begin
    # Exercises `dummy_encoder_fit` when `features` is a callable predicate on
    # column names rather than a vector of symbols.
    X = dataset_forms[1]

    # Columns selected by the predicate. Per the expectations below, :E is only
    # encoded when `ordered_factor=true`.
    targeted = [:A, :C, :E]
    predicate = name -> name in targeted

    # Test 1: ignore=true should exclude every column the predicate selects.
    # One @test per column so a failure reports exactly which column leaked through.
    cache1 = dummy_encoder_fit(X, predicate; ignore=true, ordered_factor=false)
    for name in targeted
        @test !(name in cache1[:encoded])
    end

    # Test 2: ignore=false should include only predicate columns
    # (:E is not expected here because ordered_factor=false).
    cache2 = dummy_encoder_fit(X, predicate; ignore=false, ordered_factor=false)
    @test Set(cache2[:encoded]) == Set([:A, :C])

    # Test 3: predicate with ordered_factor=true picks up ordered factors (e.g., :E)
    cache3 = dummy_encoder_fit(X, predicate; ignore=false, ordered_factor=true)
    @test Set(cache3[:encoded]) == Set([:A, :C, :E])
end
16 changes: 11 additions & 5 deletions test/transformers/cardinality_reducer.jl
Original file line number Diff line number Diff line change
@@ -1,30 +1,36 @@
using MLJTransforms: cardinality_reducer_fit, cardinality_reducer_transform



@testset "Throws errors when needed" begin
@test_throws ArgumentError begin
@testset "Cardinality Reducer Error Handling" begin
# Test COLLISION_NEW_VAL error - when label_for_infrequent value already exists in data
@test_throws MLJTransforms.COLLISION_NEW_VAL('X') begin
X = generate_high_cardinality_table(1000; obj = false, special_cat = 'X')
cache = cardinality_reducer_fit(
X;
label_for_infrequent = Dict(AbstractString => "Other", Char => 'X'),
)
end
@test_throws ArgumentError begin

# Test VALID_TYPES_NEW_VAL error - when label_for_infrequent key is not a supported type
@test_throws MLJTransforms.VALID_TYPES_NEW_VAL(Bool) begin
X = generate_high_cardinality_table(1000; obj = false, special_cat = 'O')
cache = cardinality_reducer_fit(
X;
label_for_infrequent = Dict(AbstractString => "Other", Bool => 'X'),
)
end
@test_throws ArgumentError begin

# Test UNSPECIFIED_COL_TYPE error - when column type isn't in label_for_infrequent
@test_throws MLJTransforms.UNSPECIFIED_COL_TYPE(Char, Dict(AbstractString => "X")) begin
X = generate_high_cardinality_table(1000)
cache = cardinality_reducer_fit(
X;
min_frequency = 30,
label_for_infrequent = Dict(AbstractString => "X"),
# Missing Char type in label_for_infrequent, which should be present in X
)
end

end


Expand Down
Loading