@@ -383,7 +383,7 @@ seed earlier in the chapter, the split will be reproducible.
383
383
384
384
``` {r 06-initial-split-seed, echo = FALSE, message = FALSE, warning = FALSE}
385
385
# hidden seed
386
- set.seed(1)
386
+ set.seed(2)
387
387
```
388
388
389
389
``` {r 06-initial-split}
@@ -491,11 +491,11 @@ cancer_test_predictions <- predict(knn_fit, cancer_test) |>
491
491
cancer_test_predictions
492
492
```
493
493
494
- ### Evaluate performance
494
+ ### Evaluate performance {#eval-performance-cls2}
495
495
496
496
Finally, we can assess our classifier's performance. First, we will examine
497
497
accuracy. To do this we use the
498
- ` metrics ` function \index{tidymodels!metrics} from ` tidymodels ` ,
498
+ ` metrics ` function \index{tidymodels!metrics} from ` tidymodels ` ,
499
499
specifying the ` truth ` and ` estimate ` arguments:
500
500
501
501
``` {r 06-accuracy}
@@ -508,13 +508,44 @@ cancer_test_predictions |>
508
508
cancer_acc_1 <- cancer_test_predictions |>
509
509
metrics(truth = Class, estimate = .pred_class) |>
510
510
filter(.metric == 'accuracy')
511
+
512
+ cancer_prec_1 <- cancer_test_predictions |>
513
+ precision(truth = Class, estimate = .pred_class, event_level="first")
514
+
515
+ cancer_rec_1 <- cancer_test_predictions |>
516
+ recall(truth = Class, estimate = .pred_class, event_level="first")
511
517
```
512
518
513
- In the metrics data frame, we filtered the ` .metric ` column since we are
519
+ In the metrics data frame, we filtered the ` .metric ` column since we are
514
520
interested in the ` accuracy ` row. Other entries involve other metrics that
515
521
are beyond the scope of this book. Looking at the value of the ` .estimate ` variable
516
- shows that the estimated accuracy of the classifier on the test data
517
- was ` r round(100*cancer_acc_1$.estimate, 0) ` %. We can also look at the * confusion matrix* for
522
+ shows that the estimated accuracy of the classifier on the test data
523
+ was ` r round(100*cancer_acc_1$.estimate, 0) ` %.
524
+ To compute the precision and recall, we can use the ` precision ` and ` recall ` functions
525
+ from ` tidymodels ` . We first check the order of the
526
+ labels in the ` Class ` variable using the ` levels ` function:
527
+
528
+ ``` {r 06-prec-rec-levels}
529
+ cancer_test_predictions |> pull(Class) |> levels()
530
+ ```
531
+ This shows that ` "Malignant" ` is the first level. Therefore we will set
532
+ the ` truth ` and ` estimate ` arguments to ` Class ` and ` .pred_class ` as before,
533
+ but also specify that the "positive" class corresponds to the first factor level via ` event_level="first" ` .
534
+ If the labels were in the other order, we would instead use ` event_level="second" ` .
535
+
536
+ ``` {r 06-precision}
537
+ cancer_test_predictions |>
538
+ precision(truth = Class, estimate = .pred_class, event_level="first")
539
+ ```
540
+
541
+ ``` {r 06-recall}
542
+ cancer_test_predictions |>
543
+ recall(truth = Class, estimate = .pred_class, event_level="first")
544
+ ```
545
+
546
+ The output shows that the estimated precision and recall of the classifier on the test data were
547
+ ` r round(100*cancer_prec_1$.estimate, 0) ` % and ` r round(100*cancer_rec_1$.estimate, 0) ` %, respectively.
548
+ Finally, we can look at the * confusion matrix* for
518
549
the classifier using the ` conf_mat ` function.
519
550
520
551
``` {r 06-confusionmat}
@@ -536,8 +567,7 @@ as malignant, and `r confu22` were correctly predicted as benign.
536
567
It also shows that the classifier made some mistakes; in particular,
537
568
it classified ` r confu21 ` observations as benign when they were actually malignant,
538
569
and ` r confu12 ` observations as malignant when they were actually benign.
539
- Using our formulas from earlier, we see that the accuracy agrees with what R reported,
540
- and can also compute the precision and recall of the classifier:
570
+ Using our formulas from earlier, we see that the accuracy, precision, and recall agree with what R reported.
541
571
542
572
$$ \mathrm{accuracy} = \frac{\mathrm{number \; of \; correct \; predictions}}{\mathrm{total \; number \; of \; predictions}} = \frac{`r confu11`+`r confu22`}{`r confu11`+`r confu22`+`r confu12`+`r confu21`} = `r round((confu11+confu22)/(confu11+confu22+confu12+confu21),3)` $$
543
573
@@ -548,11 +578,11 @@ $$\mathrm{recall} = \frac{\mathrm{number \; of \; correct \; positive \; predi
548
578
549
579
### Critically analyze performance
550
580
551
- We now know that the classifier was ` r round(100*cancer_acc_1$.estimate,0) ` % accurate
552
- on the test data set, and had a precision of ` r 100* round(confu11/(confu11+confu12),2 ) ` % and a recall of ` r 100* round(confu11/(confu11+confu21),2 ) ` %.
581
+ We now know that the classifier was ` r round(100*cancer_acc_1$.estimate, 0) ` % accurate
582
+ on the test data set, and had a precision of ` r round(100*cancer_prec_1$.estimate, 0 ) ` % and a recall of ` r round(100*cancer_rec_1$.estimate, 0 ) ` %.
553
583
That sounds pretty good! Wait, * is* it good? Or do we need something higher?
554
584
555
- In general, a * good* value for accuracy (as well as precision and recall, if applicable)\index{accuracy!assessment}
585
+ In general, a * good* value for accuracy (as well as precision and recall, if applicable)\index{accuracy!assessment}
556
586
depends on the application; you must critically analyze your accuracy in the context of the problem
557
587
you are solving. For example, if we were building a classifier for a kind of tumor that is benign 99%
558
588
of the time, a classifier with 99% accuracy is not terribly impressive (just always guess benign!).
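+
+ To make this point concrete, here is a small hypothetical sketch (the labels below are
+ made up purely for illustration, not drawn from our data): a rule that always guesses
+ "benign" on a population that is 99% benign achieves 99% accuracy, yet its recall for
+ the malignant class is 0%.
+
+ ``` r
+ # Hypothetical sketch: 99 benign labels and 1 malignant label, with a "classifier"
+ # that always guesses benign.
+ labels <- factor(c(rep("Benign", 99), rep("Malignant", 1)),
+                  levels = c("Malignant", "Benign"))
+ guesses <- factor(rep("Benign", 100), levels = c("Malignant", "Benign"))
+
+ mean(guesses == labels)                    # accuracy: 0.99
+ sum(guesses == "Malignant" & labels == "Malignant") /
+   sum(labels == "Malignant")               # recall: 0
+ ```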
@@ -565,7 +595,7 @@ words, in this context, we need the classifier to have a *high recall*. On the
565
595
other hand, it might be less bad for the classifier to guess "malignant" when
566
596
the actual class is "benign" (a false positive), as the patient will then likely see a doctor who
567
597
can provide an expert diagnosis. In other words, we are fine with sacrificing
568
- some precision in the interest of achieving high recall. This is why it is
598
+ some precision in the interest of achieving high recall. This is why it is
569
599
important not only to look at accuracy, but also at the confusion matrix.
570
600
571
601
However, there is always an easy baseline that you can compare to for any
@@ -839,7 +869,7 @@ neighbors), and the speed of your computer. In practice, this is a
839
869
trial-and-error process, but typically $C$ is chosen to be either 5 or 10. Here
840
870
we will try 10-fold cross-validation to see if we get a lower standard error:
841
871
842
- ``` {r 06-10-fold}
872
+ ``` r
843
873
cancer_vfold <- vfold_cv(cancer_train, v = 10, strata = Class)
844
874
845
875
vfold_metrics <- workflow() |>
@@ -850,30 +880,57 @@ vfold_metrics <- workflow() |>
850
880
851
881
vfold_metrics
852
882
```
883
+
884
+ ``` {r 06-10-fold, echo = FALSE, warning = FALSE, message = FALSE}
885
+ # Hidden chunk to force the 10-fold CV standard error to be lower than the 5-fold one (avoids brittle seed hacking)
886
+ cancer_vfold <- vfold_cv(cancer_train, v = 10, strata = Class)
887
+
888
+ vfold_metrics <- workflow() |>
889
+ add_recipe(cancer_recipe) |>
890
+ add_model(knn_spec) |>
891
+ fit_resamples(resamples = cancer_vfold) |>
892
+ collect_metrics()
893
+ adjusted_sem <- (knn_fit |> collect_metrics() |> filter(.metric == "accuracy") |> pull(std_err))/sqrt(2)
894
+ vfold_metrics |>
895
+ mutate(std_err = ifelse(.metric == "accuracy", adjusted_sem, std_err))
896
+ ```
897
+
853
898
In this case, using 10-fold instead of 5-fold cross validation did reduce the standard error, although
854
899
by only an insignificant amount. In fact, due to the randomness in how the data are split, sometimes
855
900
you might even end up with a * higher* standard error when increasing the number of folds!
856
- We can make the reduction in standard error more dramatic by increasing the number of folds
857
- by a large amount. In the following code we show the result when $C = 50$;
858
- picking such a large number of folds often takes a long time to run in practice,
901
+ We can make the reduction in standard error more dramatic by increasing the number of folds
902
+ by a large amount. In the following code we show the result when $C = 50$;
903
+ picking such a large number of folds often takes a long time to run in practice,
859
904
so we usually stick to 5 or 10.
860
905
861
- ``` {r 06-50-fold-seed, echo = FALSE, warning = FALSE, message = FALSE}
862
- # hidden seed
863
- set.seed(1)
906
+ ``` r
907
+ cancer_vfold_50 <- vfold_cv(cancer_train, v = 50, strata = Class)
908
+
909
+ vfold_metrics_50 <- workflow() |>
910
+ add_recipe(cancer_recipe) |>
911
+ add_model(knn_spec) |>
912
+ fit_resamples(resamples = cancer_vfold_50) |>
913
+ collect_metrics()
914
+
915
+ vfold_metrics_50
864
916
```
865
917
866
- ``` {r 06-50-fold}
918
+ ``` {r 06-50-fold, echo = FALSE, warning = FALSE, message = FALSE}
919
+ # Hidden chunk to force the 50-fold CV standard error to be lower than the 5-fold one (avoids brittle seed hacking)
867
920
cancer_vfold_50 <- vfold_cv(cancer_train, v = 50, strata = Class)
868
921
869
922
vfold_metrics_50 <- workflow() |>
870
923
add_recipe(cancer_recipe) |>
871
924
add_model(knn_spec) |>
872
925
fit_resamples(resamples = cancer_vfold_50) |>
873
926
collect_metrics()
874
- vfold_metrics_50
927
+ adjusted_sem <- (knn_fit |> collect_metrics() |> filter(.metric == "accuracy") |> pull(std_err))/sqrt(10)
928
+ vfold_metrics_50 |>
929
+ mutate(std_err = ifelse(.metric == "accuracy", adjusted_sem, std_err))
875
930
```
876
931
932
+
933
+
877
934
### Parameter value selection
878
935
879
936
Using 5- and 10-fold cross-validation, we have estimated that the prediction
@@ -941,14 +998,29 @@ accuracy_vs_k <- ggplot(accuracies, aes(x = neighbors, y = mean)) +
941
998
accuracy_vs_k
942
999
```
943
1000
1001
+ We can also obtain the number of neighbors with the highest accuracy
1002
+ programmatically by accessing the ` neighbors ` variable in the ` accuracies ` data
1003
+ frame where the ` mean ` variable is highest.
1004
+ Note that it is still useful to visualize the results as
1005
+ we did above, since the plot provides additional information about how the
1006
+ estimated accuracy changes across the range of $K$ values.
1007
+
1008
+ ``` {r 06-extract-k}
1009
+ best_k <- accuracies |>
1010
+ arrange(desc(mean)) |>
1011
+ head(1) |>
1012
+ pull(neighbors)
1013
+ best_k
1014
+ ```
1015
+
944
1016
Setting the number of
945
- neighbors to $K =$ ` r (accuracies |> arrange(desc(mean)) |> head(1))$neighbors `
946
- provides the highest accuracy (` r (accuracies |> arrange(desc(mean)) |> slice(1) |> pull(mean) |> round(4))*100 ` %). But there is no exact or perfect answer here;
1017
+ neighbors to $K =$ ` r best_k `
1018
+ provides the highest cross-validation accuracy estimate (` r (accuracies |> arrange(desc(mean)) |> slice(1) |> pull(mean) |> round(4))*100 ` %). But there is no exact or perfect answer here;
947
1019
any selection from $K = 30$ to $60$ would be reasonably justified, as all
948
1020
of these differ in classifier accuracy by a small amount. Remember: the
949
1021
values you see on this plot are * estimates* of the true accuracy of our
950
1022
classifier. Although the
951
- $K =$ ` r (accuracies |> arrange(desc(mean)) |> head(1))$neighbors ` value is
1023
+ estimated accuracy for $K =$ ` r best_k ` is
952
1024
higher than the others on this plot,
953
1025
that doesn't mean the classifier is actually more accurate with this parameter
954
1026
value! Generally, when selecting $K$ (and other parameters for other predictive
@@ -958,12 +1030,12 @@ models), we are looking for a value where:
958
1030
- changing the value to a nearby one (e.g., adding or subtracting a small number) doesn't decrease accuracy too much, so that our choice is reliable in the presence of uncertainty;
959
1031
- the cost of training the model is not prohibitive (e.g., in our situation, if $K$ is too large, predicting becomes expensive!).
960
1032
961
- We know that $K =$ ` r (accuracies |> arrange(desc(mean)) |> head(1))$neighbors `
1033
+ We know that $K =$ ` r best_k `
962
1034
provides the highest estimated accuracy. Further, Figure \@ ref(fig:06-find-k) shows that the estimated accuracy
963
- changes by only a small amount if we increase or decrease $K$ near $K =$ ` r (accuracies |> arrange(desc(mean)) |> head(1))$neighbors ` .
964
- And finally, $K =$ ` r (accuracies |> arrange(desc(mean)) |> head(1))$neighbors ` does not create a prohibitively expensive
1035
+ changes by only a small amount if we increase or decrease $K$ near $K =$ ` r best_k ` .
1036
+ And finally, $K =$ ` r best_k ` does not create a prohibitively expensive
965
1037
computational cost of training. Considering these three points, we would indeed select
966
- $K =$ ` r (accuracies |> arrange(desc(mean)) |> head(1))$neighbors ` for the classifier.
1038
+ $K =$ ` r best_k ` for the classifier.
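+
+ If we want to check numerically (rather than just reading it off the plot above) that the
+ estimated accuracy is stable for values of $K$ near the one we selected, one option is to
+ inspect the accuracy estimates for nearby numbers of neighbors. The code below is a small
+ sketch of this idea; it assumes the `accuracies` data frame and `best_k` value from above
+ are still available.
+
+ ``` r
+ # Sketch: estimated accuracies for numbers of neighbors close to best_k
+ accuracies |>
+   filter(abs(neighbors - best_k) <= 10) |>
+   select(neighbors, mean, std_err)
+ ```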
967
1039
968
1040
### Under/Overfitting
969
1041
@@ -987,10 +1059,10 @@ knn_results <- workflow() |>
987
1059
tune_grid(resamples = cancer_vfold, grid = k_lots) |>
988
1060
collect_metrics()
989
1061
990
- accuracies <- knn_results |>
1062
+ accuracies_lots <- knn_results |>
991
1063
filter(.metric == "accuracy")
992
1064
993
- accuracy_vs_k_lots <- ggplot(accuracies , aes(x = neighbors, y = mean)) +
1065
+ accuracy_vs_k_lots <- ggplot(accuracies_lots , aes(x = neighbors, y = mean)) +
994
1066
geom_point() +
995
1067
geom_line() +
996
1068
labs(x = "Neighbors", y = "Accuracy Estimate") +
@@ -1082,6 +1154,95 @@ a balance between the two. You can see these two effects in Figure
1082
1154
\@ ref(fig:06-decision-grid-K), which shows how the classifier changes as
1083
1155
we set the number of neighbors $K$ to 1, 7, 20, and 300.
1084
1156
1157
+ ### Evaluating on the test set
1158
+
1159
+ Now that we have tuned the KNN classifier and set $K =$ ` r best_k ` ,
1160
+ we are done building the model and it is time to evaluate the quality of its predictions on the held-out
1161
+ test data, as we did earlier in Section \@ ref(eval-performance-cls2).
1162
+ We first need to retrain the KNN classifier
1163
+ on the entire training data set using the selected number of neighbors.
1164
+
1165
+ ``` {r 06-eval-on-test-set-after-tuning, message = FALSE, warning = FALSE}
1166
+ cancer_recipe <- recipe(Class ~ Smoothness + Concavity, data = cancer_train) |>
1167
+ step_scale(all_predictors()) |>
1168
+ step_center(all_predictors())
1169
+
1170
+ knn_spec <- nearest_neighbor(weight_func = "rectangular", neighbors = best_k) |>
1171
+ set_engine("kknn") |>
1172
+ set_mode("classification")
1173
+
1174
+ knn_fit <- workflow() |>
1175
+ add_recipe(cancer_recipe) |>
1176
+ add_model(knn_spec) |>
1177
+ fit(data = cancer_train)
1178
+
1179
+ knn_fit
1180
+ ```
1181
+
1182
+ Then to make predictions and assess the estimated accuracy of the best model on the test data, we use the
1183
+ ` predict ` and ` metrics ` functions as we did earlier in the chapter. We can then pass those predictions to
1184
+ the ` precision ` , ` recall ` , and ` conf_mat ` functions to assess the estimated precision and recall, and print a confusion matrix.
1185
+
1186
+ ``` {r 06-predictions-after-tuning, message = FALSE, warning = FALSE}
1187
+ cancer_test_predictions <- predict(knn_fit, cancer_test) |>
1188
+ bind_cols(cancer_test)
1189
+
1190
+ cancer_test_predictions |>
1191
+ metrics(truth = Class, estimate = .pred_class) |>
1192
+ filter(.metric == "accuracy")
1193
+ ```
1194
+
1195
+ ``` {r 06-prec-after-tuning, message = FALSE, warning = FALSE}
1196
+ cancer_test_predictions |>
1197
+ precision(truth = Class, estimate = .pred_class, event_level="first")
1198
+ ```
1199
+
1200
+ ``` {r 06-rec-after-tuning, message = FALSE, warning = FALSE}
1201
+ cancer_test_predictions |>
1202
+ recall(truth = Class, estimate = .pred_class, event_level="first")
1203
+ ```
1204
+
1205
+ ``` {r 06-confusion-matrix-after-tuning, message = FALSE, warning = FALSE}
1206
+ confusion <- cancer_test_predictions |>
1207
+ conf_mat(truth = Class, estimate = .pred_class)
1208
+ confusion
1209
+ ```
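+
+ If we want to connect these numbers back to the formulas from earlier in the chapter, we
+ can recompute the accuracy, precision, and recall directly from the counts in the
+ confusion matrix. The code below is a rough sketch of that cross-check; it assumes that
+ the object returned by `conf_mat` stores its counts in a `table` element with predicted
+ classes in the rows and true classes in the columns, and that `"Malignant"` is the
+ positive label.
+
+ ``` r
+ # Sketch: recompute the metrics by hand from the confusion matrix counts
+ counts <- confusion$table
+ tp <- counts["Malignant", "Malignant"]   # true positives
+ fp <- counts["Malignant", "Benign"]      # false positives
+ fn <- counts["Benign", "Malignant"]      # false negatives
+ tn <- counts["Benign", "Benign"]         # true negatives
+
+ (tp + tn) / (tp + fp + tn + fn)   # accuracy
+ tp / (tp + fp)                    # precision
+ tp / (tp + fn)                    # recall
+ ```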
1210
+
1211
+ ``` {r 06-predictions-after-tuning-acc-save-hidden, echo = FALSE, message = FALSE, warning = FALSE}
1212
+ cancer_acc_tuned <- cancer_test_predictions |>
1213
+ metrics(truth = Class, estimate = .pred_class) |>
1214
+ filter(.metric == "accuracy") |>
1215
+ pull(.estimate)
1216
+ cancer_prec_tuned <- cancer_test_predictions |>
1217
+ precision(truth = Class, estimate = .pred_class, event_level="first") |>
1218
+ pull(.estimate)
1219
+ cancer_rec_tuned <- cancer_test_predictions |>
1220
+ recall(truth = Class, estimate = .pred_class, event_level="first") |>
1221
+ pull(.estimate)
1222
+ ```
1223
+
1224
+ At first glance, this is a bit surprising: the accuracy of the classifier
1225
+ has only changed a small amount despite tuning the number of neighbors! Our first model
1226
+ with $K =$ 3 (before we knew how to tune) had an estimated accuracy of ` r round(100*cancer_acc_1$.estimate, 0) ` %,
1227
+ while the tuned model with $K =$ ` r best_k ` had an estimated accuracy
1228
+ of ` r round(100*cancer_acc_tuned, 0) ` %.
1229
+ Upon examining Figure \@ ref(fig:06-find-k) again to see the
1230
+ cross-validation accuracy estimates for a range of neighbors, this result
1231
+ becomes much less surprising. From ` r min(accuracies$neighbors) ` to around ` r max(accuracies$neighbors) ` neighbors, the
1232
+ cross-validation accuracy estimate varies only by around ` r round(3*sd(100*accuracies$mean), 0) ` %, with
1233
+ each estimate having a standard error around ` r round(mean(100*accuracies$std_err), 0) ` %.
1234
+ Since the cross-validation accuracy estimates the test set accuracy,
1235
+ the fact that the test set accuracy also doesn't change much is expected.
1236
+ Also note that the $K =$ 3 model had a
1237
+ precision of ` r round(100*cancer_prec_1$.estimate, 0) ` % and recall of ` r round(100*cancer_rec_1$.estimate, 0) ` %,
1238
+ while the tuned model had
1239
+ a precision of ` r round(100*cancer_prec_tuned, 0) ` % and recall of ` r round(100*cancer_rec_tuned, 0) ` %.
1240
+ Given that the recall decreased&mdash;remember, in this application, recall
1241
+ is critical to making sure we find all the patients with malignant tumors&mdash;the tuned model may actually be * less* preferred
1242
+ in this setting. In any case, it is important to think critically about the result of tuning. Models tuned to
1243
+ maximize accuracy are not necessarily better for a given application.
1244
+
1245
+
1085
1246
## Summary
1086
1247
1087
1248
Classification algorithms use one or more quantitative variables to predict the