Skip to content

Commit 2630f08

Browse files
committed
✨ Fix high cardinality dataset
1 parent 8bd0dc9 commit 2630f08

20 files changed

+103580
-3342
lines changed

.gitignore

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,8 @@ meh/*.ipynb
2828
/*.jl
2929
scratchpad/
3030
examples/test.jl
31+
catboost_info/**
32+
/catboost_info
33+
/catboost_info
34+
/docs/src/tutorials/adult_example/.CondaPkg
35+
/docs/src/tutorials/adult_example/catboost_info

docs/make.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,10 @@ makedocs(
3535
],
3636
],
3737
"Extended Examples" => Any[
38-
"Standardization Impact"=>"tutorials/standardization/notebook.md",
39-
"Milk Quality Classification"=>"tutorials/classic_comparison/notebook.md",
40-
"Wine Quality Prediction"=>"tutorials/wine_example/notebook.md",
41-
"Entity Embeddings Tutorial"=>"tutorials/entity_embeddings/notebook.md",
38+
"Standardization Impact" => "tutorials/standardization/notebook.md",
39+
"Milk Quality Classification" => "tutorials/classic_comparison/notebook.md",
40+
"Adult Income Classification" => "tutorials/adult_example/notebook.md",
41+
"Entity Embeddings Tutorial" => "tutorials/entity_embeddings/notebook.md",
4242
],
4343
"Contributing" => "contributing.md",
4444
"About" => "about.md",

docs/src/generate.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
function generate(dir; execute = true, pluto = false)
22
quote
33
using Pkg
4-
Pkg.activate(temp = true)
4+
# Activate the specific tutorial directory instead of temp environment
5+
Pkg.activate($dir)
6+
Pkg.instantiate()
57
Pkg.add("Literate")
68
using Literate
79

docs/src/tutorials/wine_example/Manifest.toml renamed to docs/src/tutorials/adult_example/Manifest.toml

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
julia_version = "1.11.5"
44
manifest_format = "2.0"
5-
project_hash = "5a14cc2e68cb2e8e5e7b95aca3553ebdf3e9929e"
5+
project_hash = "2bc58e2f1d5ca6172834e6bda17b630aa9b5ac28"
66

77
[[deps.ADTypes]]
88
git-tree-sha1 = "be7ae030256b8ef14a441726c4c37766b90b93a3"
@@ -165,6 +165,12 @@ git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e"
165165
uuid = "9718e550-a3fa-408a-8086-8db961cd8217"
166166
version = "0.1.1"
167167

168+
[[deps.BenchmarkTools]]
169+
deps = ["Compat", "JSON", "Logging", "Printf", "Profile", "Statistics", "UUIDs"]
170+
git-tree-sha1 = "e38fbc49a620f5d0b660d7f543db1009fe0f8336"
171+
uuid = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
172+
version = "1.6.0"
173+
168174
[[deps.BitBasis]]
169175
deps = ["LinearAlgebra", "StaticArrays"]
170176
git-tree-sha1 = "89dc08420d4f593ff30f02611d136b475a5eb43d"
@@ -766,6 +772,12 @@ git-tree-sha1 = "68c173f4f449de5b438ee67ed0c9c748dc31a2ec"
766772
uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a"
767773
version = "0.3.28"
768774

775+
[[deps.IOCapture]]
776+
deps = ["Logging", "Random"]
777+
git-tree-sha1 = "b6d6bfdd7ce25b0f9b2f6b3dd56b2673a66c8770"
778+
uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89"
779+
version = "0.2.5"
780+
769781
[[deps.InitialValues]]
770782
git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3"
771783
uuid = "22cec73e-a1b8-11e9-2c92-598750a2cf9c"
@@ -1059,6 +1071,12 @@ weakdeps = ["ChainRulesCore", "SparseArrays", "Statistics"]
10591071
LinearMapsSparseArraysExt = "SparseArrays"
10601072
LinearMapsStatisticsExt = "Statistics"
10611073

1074+
[[deps.Literate]]
1075+
deps = ["Base64", "IOCapture", "JSON", "REPL"]
1076+
git-tree-sha1 = "da046be6d63304f7ba9c1bb04820fb306ba1ab12"
1077+
uuid = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
1078+
version = "2.20.1"
1079+
10621080
[[deps.LogExpFunctions]]
10631081
deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"]
10641082
git-tree-sha1 = "13ca9e2586b89836fd20cccf56e57e2b9ae7f38f"
@@ -1162,10 +1180,10 @@ uuid = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
11621180
version = "0.17.9"
11631181

11641182
[[deps.MLJTransforms]]
1165-
deps = ["BitBasis", "CategoricalArrays", "Combinatorics", "Dates", "Distributions", "LinearAlgebra", "MLJModelInterface", "OrderedCollections", "Parameters", "ScientificTypes", "Statistics", "StatsBase", "TableOperations", "Tables"]
1183+
deps = ["BitBasis", "CategoricalArrays", "Combinatorics", "Dates", "Distributions", "LinearAlgebra", "MLJModelInterface", "OrderedCollections", "Parameters", "ScientificTypes", "ScientificTypesBase", "Statistics", "StatsBase", "TableOperations", "Tables"]
11661184
path = "/Users/essamwisam/Documents/GitHub/MLJTransforms"
11671185
uuid = "23777cdb-d90c-4eb0-a694-7c2b83d5c1d6"
1168-
version = "0.1.6"
1186+
version = "0.1.1"
11691187

11701188
[[deps.MLJTuning]]
11711189
deps = ["ComputationalResources", "Distributed", "Distributions", "LatinHypercubeSampling", "MLJBase", "ProgressMeter", "Random", "RecipesBase", "StatisticalMeasuresBase"]
@@ -1502,6 +1520,10 @@ deps = ["Unicode"]
15021520
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
15031521
version = "1.11.0"
15041522

1523+
[[deps.Profile]]
1524+
uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"
1525+
version = "1.11.0"
1526+
15051527
[[deps.ProgressMeter]]
15061528
deps = ["Distributed", "Printf"]
15071529
git-tree-sha1 = "13c5103482a8ed1536a54c08d0e742ae3dca2d42"

docs/src/tutorials/wine_example/Project.toml renamed to docs/src/tutorials/adult_example/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
[deps]
2+
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
23
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
34
CatBoost = "e2e10f9a-a85d-4fa9-b6b2-639a32100a12"
45
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
56
GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
67
HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
78
LIBSVM = "b1bec4e5-fd48-53fe-b0cb-9723c09d164b"
89
LightGBM = "7acf609c-83a4-11e9-1ffb-b912bcd3b04a"
10+
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
911
MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
1012
MLJGLMInterface = "caf8df21-4939-456d-ac9c-5fefbfb04c0c"
1113
MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692"
31.2 KB
Loading
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Use the per-tutorial environment defined by `Project.toml` in this folder
2+
joinpath(@__DIR__, "..", "..", "generate.jl") |> include
3+
generate(@__DIR__, execute = true)
File renamed without changes.
Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
# # Adult Income Prediction: Comparing Categorical Encoders
2+
3+
# **Julia version** is assumed to be 1.11.* (the bundled Manifest was resolved with Julia 1.11.5)
4+
5+
# This demonstration is available as a Jupyter notebook or julia script (as well as the dataset)
6+
# [here](https://github.com/essamwisam/MLJTransforms.jl/tree/main/docs/src/tutorials/adult_example).
7+
#
8+
# This tutorial compares different categorical encoding approaches on adult income prediction.
9+
# We'll test OneHot, Frequency, and Cardinality Reduction encoders with CatBoost classification.
10+
#
11+
# **Why compare encoders?** Categorical variables with many levels (like occupation, education)
12+
# can create high-dimensional sparse features. Different encoding strategies handle this
13+
# challenge differently, affecting both model performance and training speed.
14+
#
15+
# **High Cardinality Challenge:** We've added a synthetic feature with 500 categories to
16+
# demonstrate how encoders handle extreme cardinality - a common real-world scenario with
17+
# features like customer IDs, product codes, or geographical subdivisions.
18+
19+
# packages are already activated by generate.jl
20+
21+
using MLJ, MLJTransforms, DataFrames, ScientificTypes
22+
using Random, CSV, StatsBase, Plots, BenchmarkTools
23+
24+
# Import scitypes from MLJ to avoid any package version skew
25+
using MLJ: OrderedFactor, Continuous, Multiclass
26+
27+
# ## Load and Prepare Data
28+
# Load the Adult Income dataset. This dataset contains demographic information
29+
# and the task is to predict whether a person makes over $50K per year.
30+
31+
# Load data with header and rename columns to the expected symbols
32+
df = CSV.read("./adult.csv", DataFrame; header = true)
33+
rename!(
34+
df,
35+
[
36+
:age,
37+
:workclass,
38+
:fnlwgt,
39+
:education,
40+
:education_num,
41+
:marital_status,
42+
:occupation,
43+
:relationship,
44+
:race,
45+
:sex,
46+
:capital_gain,
47+
:capital_loss,
48+
:hours_per_week,
49+
:native_country,
50+
:income,
51+
],
52+
)
53+
54+
first(df, 5)
55+
56+
57+
# Clean the data by removing leading/trailing spaces and converting income to binary:
58+
for col in [:workclass, :education, :marital_status, :occupation, :relationship,
59+
:race, :sex, :native_country, :income]
60+
df[!, col] = strip.(string.(df[!, col]))
61+
end
62+
63+
# Convert income to binary (0 for <=50K, 1 for >50K)
64+
df.income = ifelse.(df.income .== ">50K", 1, 0)
65+
66+
# Let's add a high-cardinality categorical feature to showcase encoder handling
67+
# Create a realistic frequency distribution: A1-A3 make up 90% of data, A4-A500 make up 10%
68+
Random.seed!(42)
69+
high_card_categories = ["A$i" for i in 1:500]
70+
71+
n_rows = nrow(df)
72+
n_frequent = Int(round(0.9 * n_rows)) # 90% for A1, A2, A3
73+
n_rare = n_rows - n_frequent # 10% for A4-A500
74+
75+
frequent_samples = rand(["A1", "A2", "A3"], n_frequent)
76+
77+
rare_categories = ["A$i" for i in 4:500]
78+
rare_samples = rand(rare_categories, n_rare)
79+
80+
# Combine and shuffle
81+
all_samples = vcat(frequent_samples, rare_samples)
82+
df.high_cardinality_feature = all_samples[randperm(n_rows)]
83+
84+
# Coerce categorical columns to appropriate scientific types.
85+
# Apply explicit type coercions using fully qualified names
86+
type_dict = Dict(
87+
:income => OrderedFactor,
88+
:age => Continuous,
89+
:fnlwgt => Continuous,
90+
:education_num => Continuous,
91+
:capital_gain => Continuous,
92+
:capital_loss => Continuous,
93+
:hours_per_week => Continuous,
94+
:workclass => Multiclass,
95+
:education => Multiclass,
96+
:marital_status => Multiclass,
97+
:occupation => Multiclass,
98+
:relationship => Multiclass,
99+
:race => Multiclass,
100+
:sex => Multiclass,
101+
:native_country => Multiclass,
102+
:high_cardinality_feature => Multiclass,
103+
)
104+
df = coerce(df, type_dict)
105+
106+
# Let's examine the cardinality of our categorical features:
107+
categorical_cols = [:workclass, :education, :marital_status, :occupation,
108+
:relationship, :race, :sex, :native_country, :high_cardinality_feature]
109+
println("Cardinality of categorical features:")
110+
for col in categorical_cols
111+
n_unique = length(unique(df[!, col]))
112+
println(" $col: $n_unique unique values")
113+
end
114+
115+
116+
117+
# ## Split Data
118+
# Separate features (X) from target (y), then split into train/test sets:
119+
120+
y, X = unpack(df, ==(:income); rng = 123);
121+
train, test = partition(eachindex(y), 0.8, shuffle = true, rng = 100);
122+
123+
# ## Setup Encoders and Model
124+
# Load the required models and create different encoding strategies:
125+
126+
OneHot = @load OneHotEncoder pkg = MLJModels verbosity = 0
127+
CatBoostClassifier = @load CatBoostClassifier pkg = CatBoost
128+
129+
130+
# **Encoding Strategies:**
131+
# 1. **OneHotEncoder**: Creates binary columns for each category
132+
# 2. **FrequencyEncoder**: Replaces categories with their frequency counts
133+
# In case of the one-hot-encoder, we worry when categories have high cardinality as that would lead to an explosion in the number of features.
134+
135+
card_reducer = MLJTransforms.CardinalityReducer(
136+
min_frequency = 0.15,
137+
ordered_factor = true,
138+
label_for_infrequent = Dict(
139+
AbstractString => "OtherItems",
140+
Char => 'O',
141+
),
142+
)
143+
onehot_model = OneHot(drop_last = true, ordered_factor = true)
144+
freq_model = MLJTransforms.FrequencyEncoder(normalize = false, ordered_factor = true)
145+
cat = CatBoostClassifier();
146+
147+
# Create three different pipelines to compare:
148+
pipelines = [
149+
("CardRed + OneHot + CAT", card_reducer |> onehot_model |> cat),
150+
("OneHot + CAT", onehot_model |> cat),
151+
("FreqEnc + CAT", freq_model |> cat),
152+
]
153+
154+
# ## Evaluate Pipelines with Proper Benchmarking
155+
# Train each pipeline and measure both performance (accuracy) and training time using @btime:
156+
157+
results = DataFrame(pipeline = String[], accuracy = Float64[], training_time = Float64[]);
158+
159+
# Prepare results DataFrame
160+
161+
for (name, pipe) in pipelines
162+
println("Training and benchmarking: $name")
163+
164+
## Train once to compute accuracy
165+
mach = machine(pipe, X, y)
166+
MLJ.fit!(mach, rows = train)
167+
predictions = MLJ.predict_mode(mach, rows = test)
168+
accuracy_value = MLJ.accuracy(predictions, y[test])
169+
170+
## Measure training time using @belapsed (returns Float64 seconds) with 5 samples
171+
## Create a fresh machine inside the benchmark to avoid state sharing
172+
training_time =
173+
@belapsed MLJ.fit!(machine($pipe, $X, $y), rows = $train, force = true) samples = 5
174+
175+
println(" Training time (min over 5 samples): $(training_time) s")
176+
println(" Accuracy: $(round(accuracy_value, digits=4))\n")
177+
178+
push!(results, (string(name), accuracy_value, training_time))
179+
end
180+
181+
182+
# Sort by accuracy (higher is better) and display results:
183+
sort!(results, :accuracy, rev = true)
184+
results
185+
186+
# ## Visualization
187+
# Create side-by-side bar charts to compare both training time and model performance:
188+
189+
n = nrow(results)
190+
191+
# Create a simple timing visualization (training times from @belapsed are already Float64 seconds, so they plot directly)
192+
# Sort by accuracy (higher is better)
193+
sort!(results, :accuracy, rev = true)
194+
results # show table
195+
196+
# -------------------------
197+
# Visualization (side-by-side)
198+
# -------------------------
199+
n = nrow(results)
200+
# training time plot (seconds)
201+
time_plot = bar(1:n, results.training_time;
202+
xticks = (1:n, results.pipeline),
203+
title = "Training Time (s)",
204+
xlabel = "Pipeline", ylabel = "Time (s)",
205+
xrotation = 8,
206+
legend = false,
207+
color = :lightblue,
208+
)
209+
210+
# accuracy plot
211+
accuracy_plot = bar(1:n, results.accuracy;
212+
xticks = (1:n, results.pipeline),
213+
title = "Classification Accuracy",
214+
xlabel = "Pipeline", ylabel = "Accuracy",
215+
xrotation = 8,
216+
legend = false,
217+
ylim = (0.0, 1.0),
218+
color = :lightcoral,
219+
)
220+
221+
222+
combined_plot = plot(time_plot, accuracy_plot; layout = (1, 2), size = (1200, 500))
223+
224+
# Save the plot
225+
savefig(combined_plot, "adult_encoding_comparison.png"); #hide
226+
227+
#md # ![Adult Encoding Comparison](adult_encoding_comparison.png)
228+
229+
# ## Conclusion
230+
#
231+
# **Key Findings from Results:**
232+
#
233+
# **Training Time Performance (dramatic differences!):**
234+
# - **FreqEnc + CAT**: 0.32 seconds - **fastest approach**
235+
# - **CardRed + OneHot + CAT**: 0.57 seconds - **10x faster than pure OneHot**
236+
# - **OneHot + CAT**: 5.85 seconds - **significantly slower due to high cardinality**
237+
#
238+
# **Accuracy:** In this example, we don't see a difference in accuracy but the savings in time are big.
239+
240+
# Note that we still observe a speed improvement with the cardinality reducer if we omit the high-cardinality feature we added, but it's much smaller, as the Adult dataset is not that high in cardinality.

0 commit comments

Comments
 (0)