diff --git a/src/encoders/missingness_encoding/missingness_encoding.jl b/src/encoders/missingness_encoding/missingness_encoding.jl index 848916c..4d3bf2e 100644 --- a/src/encoders/missingness_encoding/missingness_encoding.jl +++ b/src/encoders/missingness_encoding/missingness_encoding.jl @@ -35,7 +35,8 @@ function missingness_encoder_fit( Char => 'm', ), ) - supportedtypes = Union{Char, AbstractString, Number} + supportedtypes_list = [Char, AbstractString, Number] + supportedtypes = Union{supportedtypes_list...} # 1. Define feature mapper function feature_mapper(col, name) @@ -50,7 +51,7 @@ function missingness_encoder_fit( # Ensure label_for_missing keys are valid types for possible_col_type in keys(label_for_missing) - if !(possible_col_type in union_types(supportedtypes)) + if !(possible_col_type in supportedtypes_list) throw(ArgumentError(VALID_TYPES_NEW_VAL_ME(possible_col_type))) end end @@ -66,7 +67,7 @@ function missingness_encoder_fit( # Get ancestor type of column elgrandtype = nothing - for allowed_type in union_types(supportedtypes) + for allowed_type in supportedtypes_list if col_type <: allowed_type elgrandtype = allowed_type break diff --git a/src/transformers/cardinality_reducer/cardinality_reducer.jl b/src/transformers/cardinality_reducer/cardinality_reducer.jl index 1d8f531..18ca84c 100644 --- a/src/transformers/cardinality_reducer/cardinality_reducer.jl +++ b/src/transformers/cardinality_reducer/cardinality_reducer.jl @@ -41,7 +41,8 @@ function cardinality_reducer_fit( Char => 'O', ), ) - supportedtypes = Union{Char, AbstractString, Number} + supportedtypes_list = [Char, AbstractString, Number] + supportedtypes = Union{supportedtypes_list...} # 1. Define feature mapper function feature_mapper(col, name) @@ -57,7 +58,7 @@ function cardinality_reducer_fit( # Ensure label_for_infrequent keys are valid types for possible_col_type in keys(label_for_infrequent) - if !(possible_col_type in union_types(supportedtypes)) + if !(possible_col_type in supportedtypes_list) throw(ArgumentError(VALID_TYPES_NEW_VAL(possible_col_type))) end end @@ -71,7 +72,7 @@ function cardinality_reducer_fit( # Get ancestor type of column elgrandtype = nothing - for allowed_type in union_types(supportedtypes) + for allowed_type in supportedtypes_list if col_type <: allowed_type elgrandtype = allowed_type break diff --git a/src/utils.jl b/src/utils.jl index de21a05..8ac976c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,3 +1 @@ -# To go from e.g., Union{Integer, String} to (Integer, String) -union_types(x::Union) = (x.a, union_types(x.b)...) -union_types(x::Type) = (x,) \ No newline at end of file +# add utility functions here \ No newline at end of file diff --git a/test/transformers/cardinality_reducer.jl b/test/transformers/cardinality_reducer.jl index 6670306..4763683 100644 --- a/test/transformers/cardinality_reducer.jl +++ b/test/transformers/cardinality_reducer.jl @@ -1,8 +1,6 @@ -using MLJTransforms: union_types, cardinality_reducer_fit, cardinality_reducer_transform +using MLJTransforms: cardinality_reducer_fit, cardinality_reducer_transform + -@testset "Union_types" begin - @test union_types(Union{Integer, String}) == (Integer, String) -end @testset "Throws errors when needed" begin @test_throws ArgumentError begin