Skip to content

Commit 09de647

Browse files
committed
✨ Add CV analysis
1 parent 9f9d28e commit 09de647

File tree

3 files changed

+177
-25
lines changed

3 files changed

+177
-25
lines changed
24.8 KB
Loading

docs/src/tutorials/classic_comparison/notebook.jl

Lines changed: 73 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ Pkg.activate(@__DIR__);
1414
Pkg.instantiate(); #src
1515

1616
using MLJ, MLJTransforms, LIBSVM, DataFrames, ScientificTypes
17-
using Random, CSV
17+
using Random, CSV, Plots
1818

1919
# ## Load and Prepare Data
2020
# Load the milk quality dataset which contains categorical features for quality prediction:
@@ -70,28 +70,89 @@ pipelines = [
7070
# ## Evaluate Pipelines
7171
# Use 10-fold cross-validation to robustly estimate each pipeline's accuracy:
7272

73-
results = DataFrame(pipeline = String[], accuracy = Float64[])
73+
results = DataFrame(
74+
pipeline = String[],
75+
accuracy = Float64[],
76+
std_error = Float64[],
77+
ci_lower = Float64[],
78+
ci_upper = Float64[],
79+
)
7480

7581
for (name, pipe) in pipelines
7682
println("Evaluating: $name")
77-
mach = machine(pipe, X, y)
78-
eval_results = evaluate!(
79-
mach,
80-
resampling = CV(nfolds = 10, rng = 123),
83+
eval_results = evaluate(
84+
pipe,
85+
X,
86+
y,
87+
resampling = CV(nfolds = 5, rng = 123),
8188
measure = accuracy,
8289
rows = train,
8390
verbosity = 0,
8491
)
85-
acc = mean(eval_results.measurement)
86-
push!(results, (name, acc))
92+
acc = eval_results.measurement[1] # scalar mean
93+
per_fold = eval_results.per_fold[1] # vector of fold results
94+
se = std(per_fold) / sqrt(length(per_fold))
95+
ci = 1.96 * se
96+
push!(
97+
results,
98+
(
99+
pipeline = name,
100+
accuracy = acc,
101+
std_error = se,
102+
ci_lower = acc - ci,
103+
ci_upper = acc + ci,
104+
),
105+
)
106+
println(" Mean accuracy: $(round(acc, digits=4)) ± $(round(ci, digits=4))")
87107
end
88108

89109
# Sort results by accuracy (highest first) and display:
90110
sort!(results, :accuracy, rev = true)
111+
112+
# Display results with confidence intervals
113+
println("\nResults with 95% Confidence Intervals (see caveats below):")
114+
println("="^60)
115+
for row in eachrow(results)
116+
pipeline = row.pipeline
117+
acc = round(row.accuracy, digits = 4)
118+
ci_lower = round(row.ci_lower, digits = 4)
119+
ci_upper = round(row.ci_upper, digits = 4)
120+
println("$pipeline: $acc (95% CI: [$ci_lower, $ci_upper])")
121+
end
122+
91123
results
92124

93125
# ## Results Analysis
94-
# We notice that one-hot-encoding was the most performant here followed by target encoding.
95-
# Ordinal encoding also produced decent results because we can perceive all the categorical variables to be ordered
96-
# On the other hand, frequency encoding lagged behind. Observe that this method doesn't distinguish categories from one another if they occur with similar frequencies.
97-
#
126+
#
127+
# ### Performance Summary
128+
# The results show OneHot encoding performing best, followed by Target encoding, with Ordinal and Frequency encoders showing lower performance.
129+
#
130+
# The confidence intervals should be interpreted with caution and primarily serve to illustrate uncertainty rather than provide definitive statistical significance tests.
131+
# See Bengio & Grandvalet, 2004: "No Unbiased Estimator of the Variance of K-Fold Cross-Validation"). That said, reporting the interval is still more informative than reporting only the mean.
132+
133+
# Prepare data for plotting
134+
labels = results.pipeline
135+
mean_acc = results.accuracy
136+
ci_lower = results.ci_lower
137+
ci_upper = results.ci_upper
138+
139+
# Error bars: distance from mean to CI bounds
140+
lower_err = mean_acc .- ci_lower
141+
upper_err = ci_upper .- mean_acc
142+
143+
bar(
144+
labels,
145+
mean_acc,
146+
yerror = (lower_err, upper_err),
147+
legend = false,
148+
xlabel = "Encoder + SVM",
149+
ylabel = "Accuracy",
150+
title = "Mean Accuracy with 95% Confidence Intervals",
151+
ylim = (0, 1.05),
152+
color = :skyblue,
153+
size = (700, 400),
154+
);
155+
156+
# save the figure and load it
157+
savefig("encoder_comparison.png");
158+
# ![`encoder_comparison.png`](encoder_comparison.png)

0 commit comments

Comments
 (0)