Skip to content

Commit f60f640

Browse files
committed
✨ Fix row-table input bug (use MMI.selectcols instead of Tables.getcolumn)
1 parent cbabb50 commit f60f640

File tree

3 files changed

+72
-2
lines changed

3 files changed

+72
-2
lines changed

src/generic.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ function generic_fit(X,
5959
# 4. Use feature mapper to compute the mapping of each level in each column
6060
encoded_features = Symbol[]# to store the columns that were actually encoded
6161
for feat_name in feat_names
62-
feat_col = Tables.getcolumn(X, feat_name)
62+
feat_col = MMI.selectcols(X, feat_name)
6363
feat_type = elscitype(feat_col)
6464
feat_has_allowed_type =
6565
feat_type <: Union{Missing, Multiclass} ||
@@ -149,7 +149,7 @@ function generic_transform(
149149
new_feat_names = Symbol[]
150150
new_cols = []
151151
for feat_name in feat_names
152-
col = Tables.getcolumn(X, feat_name)
152+
col = MMI.selectcols(X, feat_name)
153153
# Create the transformation function for each column
154154
if feat_name in keys(mapping_per_feat_level)
155155
if !ignore_unknown

test/generic_table_types.jl

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
using Test
2+
using Tables
3+
using CategoricalArrays
4+
5+
6+
7+
@testset "Generic Table Types Support" begin
8+
9+
# Build the data from the issue: 'a', 'b', 'c' once each plus 100 copies of 'd', coerced to Multiclass
10+
x = vcat(collect("abc"), fill('d', 100))
11+
x = coerce(x, Multiclass)
12+
13+
# Column table (NamedTuple of vectors) - this form already worked before the fix
14+
coltable = (; x)
15+
16+
# Row table (Vector of NamedTuples) built from the same data - this form was failing before the fix
17+
rowtable = Tables.rowtable(coltable)
18+
19+
# Models that were affected by the issue; each one is run through every nested testset below
20+
models_to_test = [
21+
CardinalityReducer(),
22+
FrequencyEncoder(),
23+
MissingnessEncoder(),
24+
OrdinalEncoder(),
25+
]
26+
27+
@testset "Model: $(string(typeof(model)))" for model in models_to_test
28+
29+
@testset "Column Table Support" begin
30+
mach_col = machine(model, coltable)
31+
MLJBase.fit!(mach_col, verbosity=0)
32+
result_col = MLJBase.transform(mach_col, coltable)
33+
34+
@test !isempty(Tables.columntable(result_col))
35+
end
36+
37+
@testset "Row Table Support" begin
38+
# Fit and transform on the row table; this path is expected to work now that the fix is in
39+
mach_row = machine(model, rowtable)
40+
MLJBase.fit!(mach_row, verbosity=0)
41+
result_row = MLJBase.transform(mach_row, rowtable)
42+
43+
@test !isempty(Tables.columntable(result_row))
44+
end
45+
46+
@testset "Consistency Between Table Types" begin
47+
# Fit a fresh machine on each table type; the two results should be equivalent
48+
mach_col = machine(model, coltable)
49+
MLJBase.fit!(mach_col, verbosity=0)
50+
result_col = MLJBase.transform(mach_col, coltable)
51+
52+
mach_row = machine(model, rowtable)
53+
MLJBase.fit!(mach_row, verbosity=0)
54+
result_row = MLJBase.transform(mach_row, rowtable)
55+
56+
# Convert both results to column tables so they can be compared column by column
57+
result_col_ct = Tables.columntable(result_col)
58+
result_row_ct = Tables.columntable(result_row)
59+
60+
# Both results should expose the same column names, in the same order
61+
@test keys(result_col_ct) == keys(result_row_ct)
62+
63+
# Compare each column as a Set of values. NOTE(review): a Set ignores row order and duplicate counts, so reordered rows would not be caught - confirm this is intended
64+
for col_name in keys(result_col_ct)
65+
@test Set(result_col_ct[col_name]) == Set(result_row_ct[col_name])
66+
end
67+
end
68+
end
69+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ _get(x) = CategoricalArrays.DataAPI.unwrap(x)
2020

2121
include("utils.jl")
2222
include("generic.jl")
23+
include("generic_table_types.jl") # Test for issue #42 fix
2324
include("encoders/target_encoding.jl")
2425
include("encoders/ordinal_encoding.jl")
2526
include("encoders/frequency_encoder.jl")

0 commit comments

Comments
 (0)