Skip to content

Commit 8bd0dc9

Browse files
committed
✨ Improve standardization
1 parent f7ac80e commit 8bd0dc9

File tree

2 files changed

+20
-12
lines changed

2 files changed

+20
-12
lines changed

docs/src/tutorials/standardization/notebook.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,9 @@ stand = Standardizer() # This is our standardization transformer
104104

105105
## Create pipelines for each model variant
106106
logreg_pipe = logreg() # Plain logistic regression
107-
logreg_std_pipe = Pipeline(stand, logreg()) # Logistic regression with standardization
107+
logreg_std_pipe = stand |> logreg() # Logistic regression with standardization
108108
svm_pipe = svm() # Plain SVM
109-
svm_std_pipe = Pipeline(stand, svm()) # SVM with standardization
109+
svm_std_pipe = stand |> svm() # SVM with standardization
110110

111111
# ## Model Evaluation
112112
#
@@ -123,6 +123,10 @@ models = [
123123

124124
# Now we'll loop through each model, train it, make predictions, and calculate accuracy.
125125
# This will help us compare how standardization affects each model's performance.
126+
#
127+
# Note: As an alternative to the explicit fit!/predict workflow below, we could use:
128+
# `evaluate(model, X, y, resampling=[(train, test),], measure=accuracy)`
129+
# This shortcut handles the training, prediction, and evaluation in one step.
126130

127131
## Train and evaluate each model
128132
results = DataFrame(model = String[], accuracy = Float64[])

docs/src/tutorials/standardization/notebook.md

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ first(df, 5)
6464
````
6565

6666
```@raw html
67-
<div><div style = "float: left;"><span>5×8 DataFrame</span></div><div style = "clear: both;"></div></div><div class = "data-frame" style = "overflow-x: scroll;"><table class = "data-frame" style = "margin-bottom: 6px;"><thead><tr class = "header"><th class = "rowNumber" style = "font-weight: bold; text-align: right;">Row</th><th style = "text-align: left;">NPreg</th><th style = "text-align: left;">Glu</th><th style = "text-align: left;">BP</th><th style = "text-align: left;">Skin</th><th style = "text-align: left;">BMI</th><th style = "text-align: left;">Ped</th><th style = "text-align: left;">Age</th><th style = "text-align: left;">Type</th></tr><tr class = "subheader headerLastRow"><th class = "rowNumber" style = "font-weight: bold; text-align: right;"></th><th title = "Int32" style = "text-align: left;">Int32</th><th title = "Float64" style = "text-align: left;">Float64</th><th title = "Int32" style = "text-align: left;">Int32</th><th title = "Int32" style = "text-align: left;">Int32</th><th title = "Float64" style = "text-align: left;">Float64</th><th title = "Float64" style = "text-align: left;">Float64</th><th title = "Int32" style = "text-align: left;">Int32</th><th title = "CategoricalArrays.CategoricalValue{String, UInt8}" style = "text-align: left;">Cat…</th></tr></thead><tbody><tr><td class = "rowNumber" style = "font-weight: bold; text-align: right;">1</td><td style = "text-align: right;">5</td><td style = "text-align: right;">10086.0</td><td style = "text-align: right;">68</td><td style = "text-align: right;">28</td><td style = "text-align: right;">30.2</td><td style = "text-align: right;">0.364</td><td style = "text-align: right;">24</td><td style = "text-align: left;">No</td></tr><tr><td class = "rowNumber" style = "font-weight: bold; text-align: right;">2</td><td style = "text-align: right;">7</td><td style = "text-align: right;">10195.0</td><td style = "text-align: right;">70</td><td style = "text-align: right;">33</td><td style = "text-align: 
right;">25.1</td><td style = "text-align: right;">0.163</td><td style = "text-align: right;">55</td><td style = "text-align: left;">Yes</td></tr><tr><td class = "rowNumber" style = "font-weight: bold; text-align: right;">3</td><td style = "text-align: right;">5</td><td style = "text-align: right;">10077.0</td><td style = "text-align: right;">82</td><td style = "text-align: right;">41</td><td style = "text-align: right;">35.8</td><td style = "text-align: right;">0.156</td><td style = "text-align: right;">35</td><td style = "text-align: left;">No</td></tr><tr><td class = "rowNumber" style = "font-weight: bold; text-align: right;">4</td><td style = "text-align: right;">0</td><td style = "text-align: right;">10165.0</td><td style = "text-align: right;">76</td><td style = "text-align: right;">43</td><td style = "text-align: right;">47.9</td><td style = "text-align: right;">0.259</td><td style = "text-align: right;">26</td><td style = "text-align: left;">No</td></tr><tr><td class = "rowNumber" style = "font-weight: bold; text-align: right;">5</td><td style = "text-align: right;">0</td><td style = "text-align: right;">10107.0</td><td style = "text-align: right;">60</td><td style = "text-align: right;">25</td><td style = "text-align: right;">26.4</td><td style = "text-align: right;">0.133</td><td style = "text-align: right;">23</td><td style = "text-align: left;">No</td></tr></tbody></table></div>
67+
<div><div style = "float: left;"><span>5×8 DataFrame</span></div><div style = "clear: both;"></div></div><div class = "data-frame" style = "overflow-x: scroll;"><table class = "data-frame" style = "margin-bottom: 6px;"><thead><tr class = "header"><th class = "rowNumber" style = "font-weight: bold; text-align: right;">Row</th><th style = "text-align: left;">NPreg</th><th style = "text-align: left;">Glu</th><th style = "text-align: left;">BP</th><th style = "text-align: left;">Skin</th><th style = "text-align: left;">BMI</th><th style = "text-align: left;">Ped</th><th style = "text-align: left;">Age</th><th style = "text-align: left;">Type</th></tr><tr class = "subheader headerLastRow"><th class = "rowNumber" style = "font-weight: bold; text-align: right;"></th><th title = "Int32" style = "text-align: left;">Int32</th><th title = "Float64" style = "text-align: left;">Float64</th><th title = "Int32" style = "text-align: left;">Int32</th><th title = "Int32" style = "text-align: left;">Int32</th><th title = "Float64" style = "text-align: left;">Float64</th><th title = "Float64" style = "text-align: left;">Float64</th><th title = "Int32" style = "text-align: left;">Int32</th><th title = "CategoricalValue{String, UInt8}" style = "text-align: left;">Cat…</th></tr></thead><tbody><tr><td class = "rowNumber" style = "font-weight: bold; text-align: right;">1</td><td style = "text-align: right;">5</td><td style = "text-align: right;">10086.0</td><td style = "text-align: right;">68</td><td style = "text-align: right;">28</td><td style = "text-align: right;">30.2</td><td style = "text-align: right;">0.364</td><td style = "text-align: right;">24</td><td style = "text-align: left;">No</td></tr><tr><td class = "rowNumber" style = "font-weight: bold; text-align: right;">2</td><td style = "text-align: right;">7</td><td style = "text-align: right;">10195.0</td><td style = "text-align: right;">70</td><td style = "text-align: right;">33</td><td style = "text-align: right;">25.1</td><td 
style = "text-align: right;">0.163</td><td style = "text-align: right;">55</td><td style = "text-align: left;">Yes</td></tr><tr><td class = "rowNumber" style = "font-weight: bold; text-align: right;">3</td><td style = "text-align: right;">5</td><td style = "text-align: right;">10077.0</td><td style = "text-align: right;">82</td><td style = "text-align: right;">41</td><td style = "text-align: right;">35.8</td><td style = "text-align: right;">0.156</td><td style = "text-align: right;">35</td><td style = "text-align: left;">No</td></tr><tr><td class = "rowNumber" style = "font-weight: bold; text-align: right;">4</td><td style = "text-align: right;">0</td><td style = "text-align: right;">10165.0</td><td style = "text-align: right;">76</td><td style = "text-align: right;">43</td><td style = "text-align: right;">47.9</td><td style = "text-align: right;">0.259</td><td style = "text-align: right;">26</td><td style = "text-align: left;">No</td></tr><tr><td class = "rowNumber" style = "font-weight: bold; text-align: right;">5</td><td style = "text-align: right;">0</td><td style = "text-align: right;">10107.0</td><td style = "text-align: right;">60</td><td style = "text-align: right;">25</td><td style = "text-align: right;">26.4</td><td style = "text-align: right;">0.133</td><td style = "text-align: right;">23</td><td style = "text-align: left;">No</td></tr></tbody></table></div>
6868
```
6969

7070
### Data Type Conversion
@@ -151,9 +151,9 @@ stand = Standardizer() # This is our standardization transformer
151151

152152
# Create pipelines for each model variant
153153
logreg_pipe = logreg() # Plain logistic regression
154-
logreg_std_pipe = Pipeline(stand, logreg()) # Logistic regression with standardization
154+
logreg_std_pipe = stand |> logreg() # Logistic regression with standardization
155155
svm_pipe = svm() # Plain SVM
156-
svm_std_pipe = Pipeline(stand, svm()) # SVM with standardization
156+
svm_std_pipe = stand |> svm() # SVM with standardization
157157
````
158158

159159
````
@@ -191,7 +191,7 @@ models = [
191191
````
192192

193193
````
194-
4-element Vector{Tuple{String, MLJModelInterface.Supervised}}:
194+
4-element Vector{Tuple{String, Supervised}}:
195195
("Logistic Regression", LogisticClassifier(lambda = 2.220446049250313e-16, …))
196196
("Logistic Regression (standardized)", ProbabilisticPipeline(standardizer = Standardizer(features = Symbol[], …), …))
197197
("SVM", SVC(kernel = RadialBasis, …))
@@ -201,6 +201,10 @@ models = [
201201
Now we'll loop through each model, train it, make predictions, and calculate accuracy.
202202
This will help us compare how standardization affects each model's performance.
203203

204+
Note: As an alternative to the explicit fit!/predict workflow below, we could use:
205+
`evaluate(model, X, y, resampling=[(train, test),], measure=accuracy)`
206+
This shortcut handles the training, prediction, and evaluation in one step.
207+
204208
````julia
205209
# Train and evaluate each model
206210
results = DataFrame(model = String[], accuracy = Float64[])
@@ -243,10 +247,10 @@ end
243247
244248
│ In the present case:
245249
246-
│ scitype(data) = Tuple{ScientificTypesBase.Table{Union{AbstractVector{ScientificTypesBase.Continuous}, AbstractVector{ScientificTypesBase.Count}}}, AbstractVector{ScientificTypesBase.Multiclass{2}}}
250+
│ scitype(data) = Tuple{Table{Union{AbstractVector{Continuous}, AbstractVector{Count}}}, AbstractVector{Multiclass{2}}}
247251
248-
│ fit_data_scitype(model) = Tuple{ScientificTypesBase.Table{<:AbstractVector{<:ScientificTypesBase.Continuous}}, AbstractVector{<:ScientificTypesBase.Finite}}
249-
└ @ MLJBase ~/.julia/packages/MLJBase/F1Eh6/src/machines.jl:237
252+
│ fit_data_scitype(model) = Tuple{Table{<:AbstractVector{<:Continuous}}, AbstractVector{<:Finite}}
253+
└ @ MLJBase ~/.julia/packages/MLJBase/7nGJF/src/machines.jl:237
250254
[ Info: Training machine(LogisticClassifier(lambda = 2.220446049250313e-16, …), …).
251255
┌ Info: Solver: MLJLinearModels.LBFGS{Optim.Options{Float64, Nothing}, @NamedTuple{}}
252256
│ optim_options: Optim.Options{Float64, Nothing}
@@ -273,10 +277,10 @@ end
273277
274278
│ In the present case:
275279
276-
│ scitype(data) = Tuple{ScientificTypesBase.Table{Union{AbstractVector{ScientificTypesBase.Continuous}, AbstractVector{ScientificTypesBase.Count}}}, AbstractVector{ScientificTypesBase.Multiclass{2}}}
280+
│ scitype(data) = Tuple{Table{Union{AbstractVector{Continuous}, AbstractVector{Count}}}, AbstractVector{Multiclass{2}}}
277281
278-
│ fit_data_scitype(model) = Union{Tuple{ScientificTypesBase.Table{<:AbstractVector{<:ScientificTypesBase.Continuous}}, AbstractVector{<:ScientificTypesBase.Finite}}, Tuple{ScientificTypesBase.Table{<:AbstractVector{<:ScientificTypesBase.Continuous}}, AbstractVector{<:ScientificTypesBase.Finite}, Any}}
279-
└ @ MLJBase ~/.julia/packages/MLJBase/F1Eh6/src/machines.jl:237
282+
│ fit_data_scitype(model) = Union{Tuple{Table{<:AbstractVector{<:Continuous}}, AbstractVector{<:Finite}}, Tuple{Table{<:AbstractVector{<:Continuous}}, AbstractVector{<:Finite}, Any}}
283+
└ @ MLJBase ~/.julia/packages/MLJBase/7nGJF/src/machines.jl:237
280284
[ Info: Training machine(SVC(kernel = RadialBasis, …), …).
281285
[ Info: Training machine(DeterministicPipeline(standardizer = Standardizer(features = Symbol[], …), …), …).
282286
[ Info: Training machine(:standardizer, …).

0 commit comments

Comments
 (0)