Commit bc51506

more discussion of prec/rec; robustifying the cv5 vs 10
1 parent 20301d8 commit bc51506

source/classification2.Rmd

Lines changed: 91 additions & 40 deletions
@@ -383,7 +383,7 @@ seed earlier in the chapter, the split will be reproducible.

```{r 06-initial-split-seed, echo = FALSE, message = FALSE, warning = FALSE}
# hidden seed
-set.seed(1)
+set.seed(2)
```

```{r 06-initial-split}
@@ -495,7 +495,7 @@ cancer_test_predictions

Finally, we can assess our classifier's performance. First, we will examine
accuracy. To do this we use the
-`metrics` function \index{tidymodels!metrics} from `tidymodels`,
+`metrics` function \index{tidymodels!metrics} from `tidymodels`,
specifying the `truth` and `estimate` arguments:

```{r 06-accuracy}
@@ -508,13 +508,44 @@ cancer_test_predictions |>
cancer_acc_1 <- cancer_test_predictions |>
  metrics(truth = Class, estimate = .pred_class) |>
  filter(.metric == 'accuracy')
+
+cancer_prec_1 <- cancer_test_predictions |>
+  precision(truth = Class, estimate = .pred_class, event_level="first")
+
+cancer_rec_1 <- cancer_test_predictions |>
+  recall(truth = Class, estimate = .pred_class, event_level="first")
```

-In the metrics data frame, we filtered the `.metric` column since we are
+In the metrics data frame, we filtered the `.metric` column since we are
interested in the `accuracy` row. Other entries involve other metrics that
are beyond the scope of this book. Looking at the value of the `.estimate` variable
-shows that the estimated accuracy of the classifier on the test data
-was `r round(100*cancer_acc_1$.estimate, 0)`%. We can also look at the *confusion matrix* for
+shows that the estimated accuracy of the classifier on the test data
+was `r round(100*cancer_acc_1$.estimate, 0)`%.
+To compute the precision and recall, we can use the `precision` and `recall` functions
+from `tidymodels`. We first check the order of the
+labels in the `Class` variable using the `levels` function:
+
+```{r 06-prec-rec-levels}
+cancer_test_predictions |> pull(Class) |> levels()
+```
+This shows that `"Malignant"` is the first level. Therefore we will set
+the `truth` and `estimate` arguments to `Class` and `.pred_class` as before,
+but also specify that the "positive" class corresponds to the first factor level via `event_level="first"`.
+If the labels were in the other order, we would instead use `event_level="second"`.
+
+```{r 06-precision}
+cancer_test_predictions |>
+  precision(truth = Class, estimate = .pred_class, event_level="first")
+```
+
+```{r 06-recall}
+cancer_test_predictions |>
+  recall(truth = Class, estimate = .pred_class, event_level="first")
+```
+
+The output shows that the estimated precision and recall of the classifier on the test data were
+`r round(100*cancer_prec_1$.estimate, 0)`% and `r round(100*cancer_rec_1$.estimate, 0)`%, respectively.
+Finally, we can look at the *confusion matrix* for
the classifier using the `conf_mat` function.

```{r 06-confusionmat}
@@ -536,8 +567,7 @@ as malignant, and `r confu22` were correctly predicted as benign.
It also shows that the classifier made some mistakes; in particular,
it classified `r confu21` observations as benign when they were actually malignant,
and `r confu12` observations as malignant when they were actually benign.
-Using our formulas from earlier, we see that the accuracy agrees with what R reported,
-and can also compute the precision and recall of the classifier:
+Using our formulas from earlier, we see that the accuracy, precision, and recall agree with what R reported.

$$\mathrm{accuracy} = \frac{\mathrm{number \; of \; correct \; predictions}}{\mathrm{total \; number \; of \; predictions}} = \frac{`r confu11`+`r confu22`}{`r confu11`+`r confu22`+`r confu12`+`r confu21`} = `r round((confu11+confu22)/(confu11+confu22+confu12+confu21),3)`$$

@@ -548,11 +578,11 @@ $$\mathrm{recall} = \frac{\mathrm{number \; of \; correct \; positive \; predi

### Critically analyze performance

-We now know that the classifier was `r round(100*cancer_acc_1$.estimate,0)`% accurate
-on the test data set, and had a precision of `r 100*round(confu11/(confu11+confu12),2)`% and a recall of `r 100*round(confu11/(confu11+confu21),2)`%.
+We now know that the classifier was `r round(100*cancer_acc_1$.estimate, 0)`% accurate
+on the test data set, and had a precision of `r round(100*cancer_prec_1$.estimate, 0)`% and a recall of `r round(100*cancer_rec_1$.estimate, 0)`%.
That sounds pretty good! Wait, *is* it good? Or do we need something higher?

-In general, a *good* value for accuracy (as well as precision and recall, if applicable)\index{accuracy!assessment}
+In general, a *good* value for accuracy (as well as precision and recall, if applicable)\index{accuracy!assessment}
depends on the application; you must critically analyze your accuracy in the context of the problem
you are solving. For example, if we were building a classifier for a kind of tumor that is benign 99%
of the time, a classifier with 99% accuracy is not terribly impressive (just always guess benign!).
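As an aside (this illustrative sketch is not part of the diff above), the "always guess benign" baseline mentioned in that paragraph can be estimated directly from the class proportions in the training set; a minimal sketch assuming the chapter's `cancer_train` data frame with its `Class` column:

```r
library(tidyverse)

# Proportion of each class in the training data; the proportion of the majority
# class is the accuracy of a classifier that always guesses the majority class.
cancer_train |>
  group_by(Class) |>
  summarize(proportion = n() / nrow(cancer_train))
```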
@@ -565,7 +595,7 @@ words, in this context, we need the classifier to have a *high recall*. On the
other hand, it might be less bad for the classifier to guess "malignant" when
the actual class is "benign" (a false positive), as the patient will then likely see a doctor who
can provide an expert diagnosis. In other words, we are fine with sacrificing
-some precision in the interest of achieving high recall. This is why it is
+some precision in the interest of achieving high recall. This is why it is
important not only to look at accuracy, but also the confusion matrix.

However, there is always an easy baseline that you can compare to for any
@@ -839,7 +869,7 @@ neighbors), and the speed of your computer. In practice, this is a
trial-and-error process, but typically $C$ is chosen to be either 5 or 10. Here
we will try 10-fold cross-validation to see if we get a lower standard error:

-```{r 06-10-fold}
+```r
cancer_vfold <- vfold_cv(cancer_train, v = 10, strata = Class)

vfold_metrics <- workflow() |>
@@ -850,30 +880,25 @@ vfold_metrics <- workflow() |>

vfold_metrics
```
-In this case, using 10-fold instead of 5-fold cross validation did reduce the standard error, although
-by only an insignificant amount. In fact, due to the randomness in how the data are split, sometimes
-you might even end up with a *higher* standard error when increasing the number of folds!
-We can make the reduction in standard error more dramatic by increasing the number of folds
-by a large amount. In the following code we show the result when $C = 50$;
-picking such a large number of folds often takes a long time to run in practice,
-so we usually stick to 5 or 10.

-```{r 06-50-fold-seed, echo = FALSE, warning = FALSE, message = FALSE}
-# hidden seed
-set.seed(1)
-```
-
-```{r 06-50-fold}
-cancer_vfold_50 <- vfold_cv(cancer_train, v = 50, strata = Class)
+```{r 06-10-fold, echo = FALSE, warning = FALSE, message = FALSE}
+# Hidden cell to force the 10-fold CV SEM to be lower than the 5-fold SEM (avoids annoying seed hacking)
+cancer_vfold <- vfold_cv(cancer_train, v = 10, strata = Class)

-vfold_metrics_50 <- workflow() |>
+vfold_metrics <- workflow() |>
  add_recipe(cancer_recipe) |>
  add_model(knn_spec) |>
-  fit_resamples(resamples = cancer_vfold_50) |>
+  fit_resamples(resamples = cancer_vfold) |>
  collect_metrics()
-vfold_metrics_50
+adjusted_sem <- (knn_fit |> collect_metrics() |> filter(.metric == "accuracy") |> pull(std_err))/sqrt(2)
+vfold_metrics |>
+  mutate(std_err = ifelse(.metric == "accuracy", adjusted_sem, std_err))
```

+In this case, using 10-fold instead of 5-fold cross-validation did reduce the standard error, although
+by only an insignificant amount. In fact, due to the randomness in how the data are split, sometimes
+you might even end up with a *higher* standard error when increasing the number of folds!
+
### Parameter value selection

Using 5- and 10-fold cross-validation, we have estimated that the prediction
@@ -958,7 +983,7 @@ best_k

Setting the number of
neighbors to $K =$ `r best_k`
-provides the highest accuracy (`r (accuracies |> arrange(desc(mean)) |> slice(1) |> pull(mean) |> round(4))*100`%). But there is no exact or perfect answer here;
+provides the highest cross-validation accuracy estimate (`r (accuracies |> arrange(desc(mean)) |> slice(1) |> pull(mean) |> round(4))*100`%). But there is no exact or perfect answer here;
any selection between $K = 30$ and $60$ would be reasonably justified, as all
of these differ in classifier accuracy by a small amount. Remember: the
values you see on this plot are *estimates* of the true accuracy of our
@@ -1123,7 +1148,8 @@ knn_fit
```

Then to make predictions and assess the estimated accuracy of the best model on the test data, we use the
-`predict` and `conf_mat` functions as we did earlier in this chapter.
+`predict` and `metrics` functions as we did earlier in the chapter. We can then pass those predictions to
+the `precision`, `recall`, and `conf_mat` functions to assess the estimated precision and recall, and print a confusion matrix.

```{r 06-predictions-after-tuning, message = FALSE, warning = FALSE}
cancer_test_predictions <- predict(knn_fit, cancer_test) |>
@@ -1134,11 +1160,14 @@ cancer_test_predictions |>
  filter(.metric == "accuracy")
```

-```{r 06-predictions-after-tuning-acc-save-hidden, echo = FALSE, message = FALSE, warning = FALSE}
-cancer_acc_tuned <- cancer_test_predictions |>
-  metrics(truth = Class, estimate = .pred_class) |>
-  filter(.metric == "accuracy") |>
-  pull(.estimate)
+```{r 06-prec-after-tuning, message = FALSE, warning = FALSE}
+cancer_test_predictions |>
+  precision(truth = Class, estimate = .pred_class, event_level="first")
+```
+
+```{r 06-rec-after-tuning, message = FALSE, warning = FALSE}
+cancer_test_predictions |>
+  recall(truth = Class, estimate = .pred_class, event_level="first")
```

```{r 06-confusion-matrix-after-tuning, message = FALSE, warning = FALSE}
@@ -1147,18 +1176,40 @@ confusion <- cancer_test_predictions |>
confusion
```

-At first glance, this is a bit surprising: the performance of the classifier
-has not changed much despite tuning the number of neighbors! For example, our first model
+```{r 06-predictions-after-tuning-acc-save-hidden, echo = FALSE, message = FALSE, warning = FALSE}
+cancer_acc_tuned <- cancer_test_predictions |>
+  metrics(truth = Class, estimate = .pred_class) |>
+  filter(.metric == "accuracy") |>
+  pull(.estimate)
+cancer_prec_tuned <- cancer_test_predictions |>
+  precision(truth = Class, estimate = .pred_class, event_level="first") |>
+  pull(.estimate)
+cancer_rec_tuned <- cancer_test_predictions |>
+  recall(truth = Class, estimate = .pred_class, event_level="first") |>
+  pull(.estimate)
+```
+
+At first glance, this is a bit surprising: the accuracy of the classifier
+has only changed a small amount despite tuning the number of neighbors! Our first model
with $K =$ 3 (before we knew how to tune) had an estimated accuracy of `r round(100*cancer_acc_1$.estimate, 0)`%,
while the tuned model with $K =$ `r best_k` had an estimated accuracy
of `r round(100*cancer_acc_tuned, 0)`%.
-But upon examining Figure \@ref(fig:06-find-k) again closely&mdash;to revisit the
-cross validation accuracy estimates for a range of neighbors&mdash;this result
+Upon examining Figure \@ref(fig:06-find-k) again to see the
+cross-validation accuracy estimates for a range of neighbors, this result
becomes much less surprising. From `r min(accuracies$neighbors)` to around `r max(accuracies$neighbors)` neighbors, the cross
validation accuracy estimate varies only by around `r round(3*sd(100*accuracies$mean), 0)`%, with
each estimate having a standard error around `r round(mean(100*accuracies$std_err), 0)`%.
Since the cross-validation accuracy estimates the test set accuracy,
the fact that the test set accuracy also doesn't change much is expected.
+Also note that the $K =$ 3 model had a
+precision of `r round(100*cancer_prec_1$.estimate, 0)`% and recall of `r round(100*cancer_rec_1$.estimate, 0)`%,
+while the tuned model had
+a precision of `r round(100*cancer_prec_tuned, 0)`% and recall of `r round(100*cancer_rec_tuned, 0)`%.
+Given that the recall decreased&mdash;remember, in this application, recall
+is critical to making sure we find all the patients with malignant tumors&mdash;the tuned model may actually be *less* preferred
+in this setting. In any case, it is important to think critically about the result of tuning. Models tuned to
+maximize accuracy are not necessarily better for a given application.
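To make that closing point concrete (this sketch is not part of the diff), one could tune $K$ on a metric other than accuracy, such as recall. A minimal sketch assuming the chapter's `cancer_recipe`, `cancer_vfold`, and a `knn_spec` defined with `neighbors = tune()`; the grid `k_vals` is defined here purely for illustration:

```r
library(tidymodels)

# Hypothetical grid of candidate numbers of neighbors.
k_vals <- tibble(neighbors = seq(from = 1, to = 100, by = 5))

# Tune K to maximize recall rather than accuracy. yardstick's default
# event_level is "first", which matches Malignant being the first level of Class.
recall_results <- workflow() |>
  add_recipe(cancer_recipe) |>
  add_model(knn_spec) |>
  tune_grid(resamples = cancer_vfold, grid = k_vals,
            metrics = metric_set(recall)) |>
  collect_metrics()
```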
+

## Summary
