1+ using Test
2+ using Tables
3+ using CategoricalArrays
4+
5+
6+
7+ @testset " Generic Table Types Support" begin
8+
9+ # Create test data as in the issue
10+ x = vcat (collect (" abc" ), fill (' d' , 100 ))
11+ x = coerce (x, Multiclass)
12+
13+ # Column table (NamedTuple of vectors) - this already works
14+ coltable = (; x)
15+
16+ # Row table (Vector of NamedTuples) - this was failing
17+ rowtable = Tables. rowtable (coltable)
18+
19+ # List of models that were affected by the issue
20+ models_to_test = [
21+ CardinalityReducer (),
22+ FrequencyEncoder (),
23+ MissingnessEncoder (),
24+ OrdinalEncoder (),
25+ ]
26+
27+ @testset " Model: $(string (typeof (model))) " for model in models_to_test
28+
29+ @testset " Column Table Support" begin
30+ mach_col = machine (model, coltable)
31+ MLJBase. fit! (mach_col, verbosity= 0 )
32+ result_col = MLJBase. transform (mach_col, coltable)
33+
34+ @test ! isempty (Tables. columntable (result_col))
35+ end
36+
37+ @testset " Row Table Support" begin
38+ # This should now work after the fix
39+ mach_row = machine (model, rowtable)
40+ MLJBase. fit! (mach_row, verbosity= 0 )
41+ result_row = MLJBase. transform (mach_row, rowtable)
42+
43+ @test ! isempty (Tables. columntable (result_row))
44+ end
45+
46+ @testset " Consistency Between Table Types" begin
47+ # Results should be equivalent regardless of table type
48+ mach_col = machine (model, coltable)
49+ MLJBase. fit! (mach_col, verbosity= 0 )
50+ result_col = MLJBase. transform (mach_col, coltable)
51+
52+ mach_row = machine (model, rowtable)
53+ MLJBase. fit! (mach_row, verbosity= 0 )
54+ result_row = MLJBase. transform (mach_row, rowtable)
55+
56+ # Convert both to column tables for comparison
57+ result_col_ct = Tables. columntable (result_col)
58+ result_row_ct = Tables. columntable (result_row)
59+
60+ # Should have same column names
61+ @test keys (result_col_ct) == keys (result_row_ct)
62+
63+ # Should have same values (allowing for potential ordering differences in table types)
64+ for col_name in keys (result_col_ct)
65+ @test Set (result_col_ct[col_name]) == Set (result_row_ct[col_name])
66+ end
67+ end
68+ end
69+ end
0 commit comments