Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions src/MLJTransforms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using Tables
# https://github.com/JuliaAI/MLJBase.jl/issues/1002
import ScientificTypes: elscitype, schema, coerce, ScientificTimeType
using MLJModelInterface # exports `scitype`, which will call `ScientificTypes.scitype`,
# once MLJBase is loaded (but this is not a dependency!)
# once MLJBase is loaded (but this is not a dependency!)
using CategoricalArrays
using TableOperations
using StatsBase
Expand All @@ -29,27 +29,27 @@ include("utils.jl")
include("encoders/target_encoding/errors.jl")
include("encoders/target_encoding/target_encoding.jl")
include("encoders/target_encoding/interface_mlj.jl")
export TargetEncoder
export TargetEncoder

# Ordinal encoding
include("encoders/ordinal_encoding/ordinal_encoding.jl")
include("encoders/ordinal_encoding/interface_mlj.jl")
export OrdinalEncoder
export OrdinalEncoder

# Frequency encoding
include("encoders/frequency_encoding/frequency_encoding.jl")
include("encoders/frequency_encoding/interface_mlj.jl")
export frequency_encoder_fit, frequency_encoder_transform, FrequencyEncoder
export FrequencyEncoder
export FrequencyEncoder

# Cardinality reduction
include("transformers/cardinality_reducer/cardinality_reducer.jl")
include("transformers/cardinality_reducer/interface_mlj.jl")
export cardinality_reducer_fit, cardinality_reducer_transform, CardinalityReducer
export CardinalityReducer
export CardinalityReducer
include("encoders/missingness_encoding/missingness_encoding.jl")
include("encoders/missingness_encoding/interface_mlj.jl")
export MissingnessEncoder
export MissingnessEncoder

# Contrast encoder
include("encoders/contrast_encoder/contrast_encoder.jl")
Expand All @@ -65,7 +65,6 @@ include("transformers/other_transformers/one_hot_encoder.jl")
include("transformers/other_transformers/standardizer.jl")
include("transformers/other_transformers/univariate_boxcox_transformer.jl")
include("transformers/other_transformers/univariate_discretizer.jl")
include("transformers/other_transformers/metadata_shared.jl")

export UnivariateDiscretizer,
UnivariateStandardizer, Standardizer, UnivariateBoxCoxTransformer,
Expand Down
4 changes: 2 additions & 2 deletions src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ function generic_fit(X,
# 4. Use feature mapper to compute the mapping of each level in each column
encoded_features = Symbol[]# to store column that were actually encoded
for feat_name in feat_names
feat_col = Tables.getcolumn(X, feat_name)
feat_col = MMI.selectcols(X, feat_name)
feat_type = elscitype(feat_col)
feat_has_allowed_type =
feat_type <: Union{Missing, Multiclass} ||
Expand Down Expand Up @@ -149,7 +149,7 @@ function generic_transform(
new_feat_names = Symbol[]
new_cols = []
for feat_name in feat_names
col = Tables.getcolumn(X, feat_name)
col = MMI.selectcols(X, feat_name)
# Create the transformation function for each column
if feat_name in keys(mapping_per_feat_level)
if !ignore_unknown
Expand Down
8 changes: 8 additions & 0 deletions src/transformers/other_transformers/continuous_encoder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,14 @@ metadata_model(ContinuousEncoder,
output_scitype = Table(Continuous),
load_path = "MLJTransforms.ContinuousEncoder")

# Package-level traits for `ContinuousEncoder`, declared ahead of the docstring
# below so they are already set when `doc_header` is interpolated into it.
metadata_pkg(
    ContinuousEncoder;
    package_name = "MLJTransforms",
    package_url = "https://github.com/JuliaAI/MLJTransforms.jl",
    package_uuid = "23777cdb-d90c-4eb0-a694-7c2b83d5c1d6",
    package_license = "MIT",
    is_pure_julia = true,
)

"""
$(MLJModelInterface.doc_header(ContinuousEncoder))

Expand Down
16 changes: 15 additions & 1 deletion src/transformers/other_transformers/fill_imputer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,14 @@ metadata_model(FillImputer,
output_scitype = Table,
load_path = "MLJTransforms.FillImputer")

# Package-level traits for `UnivariateFillImputer`, declared ahead of the docstring
# below so they are already set when `doc_header` is interpolated into it.
metadata_pkg(
    UnivariateFillImputer;
    package_name = "MLJTransforms",
    package_url = "https://github.com/JuliaAI/MLJTransforms.jl",
    package_uuid = "23777cdb-d90c-4eb0-a694-7c2b83d5c1d6",
    package_license = "MIT",
    is_pure_julia = true,
)

"""
$(MLJModelInterface.doc_header(UnivariateFillImputer))

Expand Down Expand Up @@ -294,7 +302,13 @@ For imputing tabular data, use [`FillImputer`](@ref).
"""
UnivariateFillImputer


# Package-level traits for `FillImputer`, declared ahead of the docstring
# below so they are already set when `doc_header` is interpolated into it.
metadata_pkg(
    FillImputer;
    package_name = "MLJTransforms",
    package_url = "https://github.com/JuliaAI/MLJTransforms.jl",
    package_uuid = "23777cdb-d90c-4eb0-a694-7c2b83d5c1d6",
    package_license = "MIT",
    is_pure_julia = true,
)

"""
$(MLJModelInterface.doc_header(FillImputer))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ metadata_model(InteractionTransformer,
human_name = "interaction transformer",
load_path = "MLJTransforms.InteractionTransformer")

# Package-level traits for `InteractionTransformer`, declared ahead of the docstring
# below so they are already set when `doc_header` is interpolated into it.
metadata_pkg(
    InteractionTransformer;
    package_name = "MLJTransforms",
    package_url = "https://github.com/JuliaAI/MLJTransforms.jl",
    package_uuid = "23777cdb-d90c-4eb0-a694-7c2b83d5c1d6",
    package_license = "MIT",
    is_pure_julia = true,
)

"""
$(MLJModelInterface.doc_header(InteractionTransformer))

Expand Down
20 changes: 0 additions & 20 deletions src/transformers/other_transformers/metadata_shared.jl

This file was deleted.

8 changes: 8 additions & 0 deletions src/transformers/other_transformers/one_hot_encoder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,14 @@ metadata_model(OneHotEncoder,
human_name = "one-hot encoder",
load_path = "MLJTransforms.OneHotEncoder")

# Package-level traits for `OneHotEncoder`, declared ahead of the docstring
# below so they are already set when `doc_header` is interpolated into it.
metadata_pkg(
    OneHotEncoder;
    package_name = "MLJTransforms",
    package_url = "https://github.com/JuliaAI/MLJTransforms.jl",
    package_uuid = "23777cdb-d90c-4eb0-a694-7c2b83d5c1d6",
    package_license = "MIT",
    is_pure_julia = true,
)

"""
$(MLJModelInterface.doc_header(OneHotEncoder))
Expand Down
8 changes: 8 additions & 0 deletions src/transformers/other_transformers/standardizer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,14 @@ metadata_model(Standardizer,
output_scitype = Union{Table, AbstractVector{<:Continuous}},
load_path = "MLJTransforms.Standardizer")

# Package-level traits for `Standardizer`, declared ahead of the docstring
# below so they are already set when `doc_header` is interpolated into it.
metadata_pkg(
    Standardizer;
    package_name = "MLJTransforms",
    package_url = "https://github.com/JuliaAI/MLJTransforms.jl",
    package_uuid = "23777cdb-d90c-4eb0-a694-7c2b83d5c1d6",
    package_license = "MIT",
    is_pure_julia = true,
)

"""
$(MLJModelInterface.doc_header(Standardizer))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,14 @@ metadata_model(UnivariateBoxCoxTransformer,
human_name = "single variable Box-Cox transformer",
load_path = "MLJTransforms.UnivariateBoxCoxTransformer")

# Package-level traits for `UnivariateBoxCoxTransformer`, declared ahead of the
# docstring below so they are already set when `doc_header` is interpolated into it.
metadata_pkg(
    UnivariateBoxCoxTransformer;
    package_name = "MLJTransforms",
    package_url = "https://github.com/JuliaAI/MLJTransforms.jl",
    package_uuid = "23777cdb-d90c-4eb0-a694-7c2b83d5c1d6",
    package_license = "MIT",
    is_pure_julia = true,
)

"""
$(MLJModelInterface.doc_header(UnivariateBoxCoxTransformer))

Expand Down
7 changes: 7 additions & 0 deletions src/transformers/other_transformers/univariate_discretizer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,13 @@ metadata_model(UnivariateDiscretizer,
human_name = "single variable discretizer",
load_path = "MLJTransforms.UnivariateDiscretizer")

# Package-level traits for `UnivariateDiscretizer`, declared ahead of the docstring
# below so they are already set when `doc_header` is interpolated into it.
metadata_pkg(
    UnivariateDiscretizer;
    package_name = "MLJTransforms",
    package_url = "https://github.com/JuliaAI/MLJTransforms.jl",
    package_uuid = "23777cdb-d90c-4eb0-a694-7c2b83d5c1d6",
    package_license = "MIT",
    is_pure_julia = true,
)

"""
$(MLJModelInterface.doc_header(UnivariateDiscretizer))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,14 @@ metadata_model(UnivariateTimeTypeToContinuous,
"continuous representations of temporally typed data",
load_path = "MLJTransforms.UnivariateTimeTypeToContinuous")

# Package-level traits for `UnivariateTimeTypeToContinuous`, declared ahead of the
# docstring below so they are already set when `doc_header` is interpolated into it.
metadata_pkg(
    UnivariateTimeTypeToContinuous;
    package_name = "MLJTransforms",
    package_url = "https://github.com/JuliaAI/MLJTransforms.jl",
    package_uuid = "23777cdb-d90c-4eb0-a694-7c2b83d5c1d6",
    package_license = "MIT",
    is_pure_julia = true,
)

"""
$(MLJModelInterface.doc_header(UnivariateTimeTypeToContinuous))

Expand Down
69 changes: 69 additions & 0 deletions test/generic_table_types.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
using Test
using Tables
using CategoricalArrays



# Checks that the encoders built on the generic fit/transform machinery accept
# both column-oriented and row-oriented Tables.jl sources and produce the same
# output for each. Relies on `coerce`, `Multiclass`, `machine`, `MLJBase` and the
# model constructors already being in scope where this file is included.
@testset "Generic Table Types Support" begin

    # Test data as in the issue: three rare levels plus one dominant level.
    x = coerce(vcat(collect("abc"), fill('d', 100)), Multiclass)

    # Column table (NamedTuple of vectors) - this already worked.
    coltable = (; x)

    # Row table (Vector of NamedTuples) - this was failing.
    rowtable = Tables.rowtable(coltable)

    # Models that were affected by the issue.
    models_to_test = [
        CardinalityReducer(),
        FrequencyEncoder(),
        MissingnessEncoder(),
        OrdinalEncoder(),
    ]

    # Fit `model` on `table` and return the transform of that same table.
    function fit_transform(model, table)
        mach = machine(model, table)
        MLJBase.fit!(mach, verbosity=0)
        return MLJBase.transform(mach, table)
    end

    @testset "Model: $(string(typeof(model)))" for model in models_to_test

        @testset "Column Table Support" begin
            result_col = fit_transform(model, coltable)
            @test !isempty(Tables.columntable(result_col))
        end

        @testset "Row Table Support" begin
            # This should now work after the fix.
            result_row = fit_transform(model, rowtable)
            @test !isempty(Tables.columntable(result_row))
        end

        @testset "Consistency Between Table Types" begin
            # Convert both results to column tables so they can be compared by name.
            result_col_ct = Tables.columntable(fit_transform(model, coltable))
            result_row_ct = Tables.columntable(fit_transform(model, rowtable))

            # Should have the same column names.
            @test keys(result_col_ct) == keys(result_row_ct)

            # Should have the same values row for row. Comparing `Set`s would let
            # reordered, dropped or duplicated rows slip through, so compare the
            # lengths and then the elements in order (`isequal` treats `missing`
            # as equal to itself).
            for col_name in keys(result_col_ct)
                col_from_cols = result_col_ct[col_name]
                col_from_rows = result_row_ct[col_name]
                @test length(col_from_cols) == length(col_from_rows)
                @test isequal(col_from_cols, col_from_rows)
            end
        end
    end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ _get(x) = CategoricalArrays.DataAPI.unwrap(x)

include("utils.jl")
include("generic.jl")
include("generic_table_types.jl") # Test for issue #42 fix
include("encoders/target_encoding.jl")
include("encoders/ordinal_encoding.jl")
include("encoders/frequency_encoder.jl")
Expand Down