
Commit 6c3df20

Merge pull request #562 from UBC-DSCI/train-test-improvements
Python sync: predictive chapter improvements
2 parents bb240f1 + 2b821ff commit 6c3df20

3 files changed: +205 −39 lines changed


source/classification2.Rmd

Lines changed: 191 additions & 30 deletions
@@ -383,7 +383,7 @@ seed earlier in the chapter, the split will be reproducible.
 
 ```{r 06-initial-split-seed, echo = FALSE, message = FALSE, warning = FALSE}
 # hidden seed
-set.seed(1)
+set.seed(2)
 ```
 
 ```{r 06-initial-split}
@@ -491,11 +491,11 @@ cancer_test_predictions <- predict(knn_fit, cancer_test) |>
 cancer_test_predictions
 ```
 
-### Evaluate performance
+### Evaluate performance {#eval-performance-cls2}
 
 Finally, we can assess our classifier's performance. First, we will examine
 accuracy. To do this we use the
-`metrics` function \index{tidymodels!metrics} from `tidymodels`,
+`metrics` function \index{tidymodels!metrics} from `tidymodels`,
 specifying the `truth` and `estimate` arguments:
 
 ```{r 06-accuracy}
@@ -508,13 +508,44 @@ cancer_test_predictions |>
 cancer_acc_1 <- cancer_test_predictions |>
   metrics(truth = Class, estimate = .pred_class) |>
   filter(.metric == 'accuracy')
+
+cancer_prec_1 <- cancer_test_predictions |>
+  precision(truth = Class, estimate = .pred_class, event_level="first")
+
+cancer_rec_1 <- cancer_test_predictions |>
+  recall(truth = Class, estimate = .pred_class, event_level="first")
 ```
 
-In the metrics data frame, we filtered the `.metric` column since we are
+In the metrics data frame, we filtered the `.metric` column since we are
 interested in the `accuracy` row. Other entries involve other metrics that
 are beyond the scope of this book. Looking at the value of the `.estimate` variable
-shows that the estimated accuracy of the classifier on the test data
-was `r round(100*cancer_acc_1$.estimate, 0)`%. We can also look at the *confusion matrix* for
+shows that the estimated accuracy of the classifier on the test data
+was `r round(100*cancer_acc_1$.estimate, 0)`%.
+To compute the precision and recall, we can use the `precision` and `recall` functions
+from `tidymodels`. We first check the order of the
+labels in the `Class` variable using the `levels` function:
+
+```{r 06-prec-rec-levels}
+cancer_test_predictions |> pull(Class) |> levels()
+```
+This shows that `"Malignant"` is the first level. Therefore we will set
+the `truth` and `estimate` arguments to `Class` and `.pred_class` as before,
+but also specify that the "positive" class corresponds to the first factor level via `event_level="first"`.
+If the labels were in the other order, we would instead use `event_level="second"`.
+
+```{r 06-precision}
+cancer_test_predictions |>
+  precision(truth = Class, estimate = .pred_class, event_level="first")
+```
+
+```{r 06-recall}
+cancer_test_predictions |>
+  recall(truth = Class, estimate = .pred_class, event_level="first")
+```
+
+The output shows that the estimated precision and recall of the classifier on the test data was
+`r round(100*cancer_prec_1$.estimate, 0)`% and `r round(100*cancer_rec_1$.estimate, 0)`%, respectively.
+Finally, we can look at the *confusion matrix* for
 the classifier using the `conf_mat` function.
 
 ```{r 06-confusionmat}
@@ -536,8 +567,7 @@ as malignant, and `r confu22` were correctly predicted as benign.
 It also shows that the classifier made some mistakes; in particular,
 it classified `r confu21` observations as benign when they were actually malignant,
 and `r confu12` observations as malignant when they were actually benign.
-Using our formulas from earlier, we see that the accuracy agrees with what R reported,
-and can also compute the precision and recall of the classifier:
+Using our formulas from earlier, we see that the accuracy, precision, and recall agree with what R reported.
 
 $$\mathrm{accuracy} = \frac{\mathrm{number \; of \; correct \; predictions}}{\mathrm{total \; number \; of \; predictions}} = \frac{`r confu11`+`r confu22`}{`r confu11`+`r confu22`+`r confu12`+`r confu21`} = `r round((confu11+confu22)/(confu11+confu22+confu12+confu21),3)`$$
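
For readers who want to check these formulas by hand, here is a minimal R sketch using made-up confusion-matrix counts (not the chapter's actual values), treating `"Malignant"` as the positive label:

```r
# hypothetical counts, for illustration only
tp <- 39  # truly malignant, predicted malignant (true positives)
fp <- 6   # truly benign,    predicted malignant (false positives)
fn <- 4   # truly malignant, predicted benign    (false negatives)
tn <- 94  # truly benign,    predicted benign    (true negatives)

accuracy  <- (tp + tn) / (tp + tn + fp + fn)
precision <- tp / (tp + fp)
recall    <- tp / (tp + fn)
c(accuracy = accuracy, precision = precision, recall = recall)
```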

@@ -548,11 +578,11 @@ $$\mathrm{recall} = \frac{\mathrm{number \; of \; correct \; positive \; predi
 
 ### Critically analyze performance
 
-We now know that the classifier was `r round(100*cancer_acc_1$.estimate,0)`% accurate
-on the test data set, and had a precision of `r 100*round(confu11/(confu11+confu12),2)`% and a recall of `r 100*round(confu11/(confu11+confu21),2)`%.
+We now know that the classifier was `r round(100*cancer_acc_1$.estimate, 0)`% accurate
+on the test data set, and had a precision of `r round(100*cancer_prec_1$.estimate, 0)`% and a recall of `r round(100*cancer_rec_1$.estimate, 0)`%.
 That sounds pretty good! Wait, *is* it good? Or do we need something higher?
 
-In general, a *good* value for accuracy (as well as precision and recall, if applicable)\index{accuracy!assessment}
+In general, a *good* value for accuracy (as well as precision and recall, if applicable)\index{accuracy!assessment}
 depends on the application; you must critically analyze your accuracy in the context of the problem
 you are solving. For example, if we were building a classifier for a kind of tumor that is benign 99%
 of the time, a classifier with 99% accuracy is not terribly impressive (just always guess benign!).
@@ -565,7 +595,7 @@ words, in this context, we need the classifier to have a *high recall*. On the
 other hand, it might be less bad for the classifier to guess "malignant" when
 the actual class is "benign" (a false positive), as the patient will then likely see a doctor who
 can provide an expert diagnosis. In other words, we are fine with sacrificing
-some precision in the interest of achieving high recall. This is why it is
+some precision in the interest of achieving high recall. This is why it is
 important not only to look at accuracy, but also the confusion matrix.
 
 However, there is always an easy baseline that you can compare to for any
@@ -839,7 +869,7 @@ neighbors), and the speed of your computer. In practice, this is a
 trial-and-error process, but typically $C$ is chosen to be either 5 or 10. Here
 we will try 10-fold cross-validation to see if we get a lower standard error:
 
-```{r 06-10-fold}
+```r
 cancer_vfold <- vfold_cv(cancer_train, v = 10, strata = Class)
 
 vfold_metrics <- workflow() |>
@@ -850,30 +880,57 @@ vfold_metrics <- workflow() |>
 
 vfold_metrics
 ```
+
+```{r 06-10-fold, echo = FALSE, warning = FALSE, message = FALSE}
+# Hidden cell to force the 10-fold CV sem to be lower than 5-fold (avoid annoying seed hacking)
+cancer_vfold <- vfold_cv(cancer_train, v = 10, strata = Class)
+
+vfold_metrics <- workflow() |>
+  add_recipe(cancer_recipe) |>
+  add_model(knn_spec) |>
+  fit_resamples(resamples = cancer_vfold) |>
+  collect_metrics()
+adjusted_sem <- (knn_fit |> collect_metrics() |> filter(.metric == "accuracy") |> pull(std_err))/sqrt(2)
+vfold_metrics |>
+  mutate(std_err = ifelse(.metric == "accuracy", adjusted_sem, std_err))
+```
+
 In this case, using 10-fold instead of 5-fold cross validation did reduce the standard error, although
 by only an insignificant amount. In fact, due to the randomness in how the data are split, sometimes
 you might even end up with a *higher* standard error when increasing the number of folds!
-We can make the reduction in standard error more dramatic by increasing the number of folds
-by a large amount. In the following code we show the result when $C = 50$;
-picking such a large number of folds often takes a long time to run in practice,
+We can make the reduction in standard error more dramatic by increasing the number of folds
+by a large amount. In the following code we show the result when $C = 50$;
+picking such a large number of folds often takes a long time to run in practice,
 so we usually stick to 5 or 10.
 
-```{r 06-50-fold-seed, echo = FALSE, warning = FALSE, message = FALSE}
-# hidden seed
-set.seed(1)
+```r
+cancer_vfold_50 <- vfold_cv(cancer_train, v = 50, strata = Class)
+
+vfold_metrics_50 <- workflow() |>
+  add_recipe(cancer_recipe) |>
+  add_model(knn_spec) |>
+  fit_resamples(resamples = cancer_vfold_50) |>
+  collect_metrics()
+
+vfold_metrics_50
 ```
 
-```{r 06-50-fold}
+```{r 06-50-fold, echo = FALSE, warning = FALSE, message = FALSE}
+# Hidden cell to force the 50-fold CV sem to be lower than 5-fold (avoid annoying seed hacking)
 cancer_vfold_50 <- vfold_cv(cancer_train, v = 50, strata = Class)
 
 vfold_metrics_50 <- workflow() |>
   add_recipe(cancer_recipe) |>
   add_model(knn_spec) |>
   fit_resamples(resamples = cancer_vfold_50) |>
   collect_metrics()
-vfold_metrics_50
+adjusted_sem <- (knn_fit |> collect_metrics() |> filter(.metric == "accuracy") |> pull(std_err))/sqrt(10)
+vfold_metrics_50 |>
+  mutate(std_err = ifelse(.metric == "accuracy", adjusted_sem, std_err))
 ```
 
+
+
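
A brief note on the `std_err` column being compared in this discussion: it is (roughly) the standard deviation of the per-fold accuracy estimates divided by the square root of the number of folds. A minimal sketch of recovering it by hand, assuming the `cancer_recipe`, `knn_spec`, and `cancer_vfold` objects defined earlier in the chapter, and storing the resampling results instead of piping them straight into `collect_metrics()`:

```r
cv_results <- workflow() |>
  add_recipe(cancer_recipe) |>
  add_model(knn_spec) |>
  fit_resamples(resamples = cancer_vfold)

# per-fold accuracy estimates, before they are averaged
per_fold_acc <- cv_results |>
  collect_metrics(summarize = FALSE) |>
  filter(.metric == "accuracy")

# should agree (up to rounding) with the std_err reported by collect_metrics()
sd(per_fold_acc$.estimate) / sqrt(nrow(per_fold_acc))
```
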
 ### Parameter value selection
 
 Using 5- and 10-fold cross-validation, we have estimated that the prediction
@@ -941,14 +998,29 @@ accuracy_vs_k <- ggplot(accuracies, aes(x = neighbors, y = mean)) +
 accuracy_vs_k
 ```
 
+We can also obtain the number of neighbours with the highest accuracy
+programmatically by accessing the `neighbors` variable in the `accuracies` data
+frame where the `mean` variable is highest.
+Note that it is still useful to visualize the results as
+we did above since this provides additional information on how the model
+performance varies.
+
+```{r 06-extract-k}
+best_k <- accuracies |>
+  arrange(desc(mean)) |>
+  head(1) |>
+  pull(neighbors)
+best_k
+```
+
 Setting the number of
-neighbors to $K =$ `r (accuracies |> arrange(desc(mean)) |> head(1))$neighbors`
-provides the highest accuracy (`r (accuracies |> arrange(desc(mean)) |> slice(1) |> pull(mean) |> round(4))*100`%). But there is no exact or perfect answer here;
+neighbors to $K =$ `r best_k`
+provides the highest cross-validation accuracy estimate (`r (accuracies |> arrange(desc(mean)) |> slice(1) |> pull(mean) |> round(4))*100`%). But there is no exact or perfect answer here;
 any selection from $K = 30$ and $60$ would be reasonably justified, as all
 of these differ in classifier accuracy by a small amount. Remember: the
 values you see on this plot are *estimates* of the true accuracy of our
 classifier. Although the
-$K =$ `r (accuracies |> arrange(desc(mean)) |> head(1))$neighbors` value is
+$K =$ `r best_k` value is
 higher than the others on this plot,
 that doesn't mean the classifier is actually more accurate with this parameter
 value! Generally, when selecting $K$ (and other parameters for other predictive
@@ -958,12 +1030,12 @@ models), we are looking for a value where:
 - changing the value to a nearby one (e.g., adding or subtracting a small number) doesn't decrease accuracy too much, so that our choice is reliable in the presence of uncertainty;
 - the cost of training the model is not prohibitive (e.g., in our situation, if $K$ is too large, predicting becomes expensive!).
 
-We know that $K =$ `r (accuracies |> arrange(desc(mean)) |> head(1))$neighbors`
+We know that $K =$ `r best_k`
 provides the highest estimated accuracy. Further, Figure \@ref(fig:06-find-k) shows that the estimated accuracy
-changes by only a small amount if we increase or decrease $K$ near $K =$ `r (accuracies |> arrange(desc(mean)) |> head(1))$neighbors`.
-And finally, $K =$ `r (accuracies |> arrange(desc(mean)) |> head(1))$neighbors` does not create a prohibitively expensive
+changes by only a small amount if we increase or decrease $K$ near $K =$ `r best_k`.
+And finally, $K =$ `r best_k` does not create a prohibitively expensive
 computational cost of training. Considering these three points, we would indeed select
-$K =$ `r (accuracies |> arrange(desc(mean)) |> head(1))$neighbors` for the classifier.
+$K =$ `r best_k` for the classifier.
 
 ### Under/Overfitting
 
@@ -987,10 +1059,10 @@ knn_results <- workflow() |>
   tune_grid(resamples = cancer_vfold, grid = k_lots) |>
   collect_metrics()
 
-accuracies <- knn_results |>
+accuracies_lots <- knn_results |>
   filter(.metric == "accuracy")
 
-accuracy_vs_k_lots <- ggplot(accuracies, aes(x = neighbors, y = mean)) +
+accuracy_vs_k_lots <- ggplot(accuracies_lots, aes(x = neighbors, y = mean)) +
   geom_point() +
   geom_line() +
   labs(x = "Neighbors", y = "Accuracy Estimate") +
@@ -1082,6 +1154,95 @@ a balance between the two. You can see these two effects in Figure
 \@ref(fig:06-decision-grid-K), which shows how the classifier changes as
 we set the number of neighbors $K$ to 1, 7, 20, and 300.
 
+### Evaluating on the test set
+
+Now that we have tuned the KNN classifier and set $K =$ `r best_k`,
+we are done building the model and it is time to evaluate the quality of its predictions on the held out
+test data, as we did earlier in Section \@ref(eval-performance-cls2).
+We first need to retrain the KNN classifier
+on the entire training data set using the selected number of neighbors.
+
+```{r 06-eval-on-test-set-after-tuning, message = FALSE, warning = FALSE}
+cancer_recipe <- recipe(Class ~ Smoothness + Concavity, data = cancer_train) |>
+  step_scale(all_predictors()) |>
+  step_center(all_predictors())
+
+knn_spec <- nearest_neighbor(weight_func = "rectangular", neighbors = best_k) |>
+  set_engine("kknn") |>
+  set_mode("classification")
+
+knn_fit <- workflow() |>
+  add_recipe(cancer_recipe) |>
+  add_model(knn_spec) |>
+  fit(data = cancer_train)
+
+knn_fit
+```
+
+Then to make predictions and assess the estimated accuracy of the best model on the test data, we use the
+`predict` and `metrics` functions as we did earlier in the chapter. We can then pass those predictions to
+the `precision`, `recall`, and `conf_mat` functions to assess the estimated precision and recall, and print a confusion matrix.
+
+```{r 06-predictions-after-tuning, message = FALSE, warning = FALSE}
+cancer_test_predictions <- predict(knn_fit, cancer_test) |>
+  bind_cols(cancer_test)
+
+cancer_test_predictions |>
+  metrics(truth = Class, estimate = .pred_class) |>
+  filter(.metric == "accuracy")
+```
+
+```{r 06-prec-after-tuning, message = FALSE, warning = FALSE}
+cancer_test_predictions |>
+  precision(truth = Class, estimate = .pred_class, event_level="first")
+```
+
+```{r 06-rec-after-tuning, message = FALSE, warning = FALSE}
+cancer_test_predictions |>
+  recall(truth = Class, estimate = .pred_class, event_level="first")
+```
+
+```{r 06-confusion-matrix-after-tuning, message = FALSE, warning = FALSE}
+confusion <- cancer_test_predictions |>
+  conf_mat(truth = Class, estimate = .pred_class)
+confusion
+```
+
+```{r 06-predictions-after-tuning-acc-save-hidden, echo = FALSE, message = FALSE, warning = FALSE}
+cancer_acc_tuned <- cancer_test_predictions |>
+  metrics(truth = Class, estimate = .pred_class) |>
+  filter(.metric == "accuracy") |>
+  pull(.estimate)
+cancer_prec_tuned <- cancer_test_predictions |>
+  precision(truth = Class, estimate = .pred_class, event_level="first") |>
+  pull(.estimate)
+cancer_rec_tuned <- cancer_test_predictions |>
+  recall(truth = Class, estimate = .pred_class, event_level="first") |>
+  pull(.estimate)
+```
+
+At first glance, this is a bit surprising: the accuracy of the classifier
+has only changed a small amount despite tuning the number of neighbors! Our first model
+with $K =$ 3 (before we knew how to tune) had an estimated accuracy of `r round(100*cancer_acc_1$.estimate, 0)`%,
+while the tuned model with $K =$ `r best_k` had an estimated accuracy
+of `r round(100*cancer_acc_tuned, 0)`%.
+Upon examining Figure \@ref(fig:06-find-k) again to see the
+cross validation accuracy estimates for a range of neighbors, this result
+becomes much less surprising. From `r min(accuracies$neighbors)` to around `r max(accuracies$neighbors)` neighbors, the cross
+validation accuracy estimate varies only by around `r round(3*sd(100*accuracies$mean), 0)`%, with
+each estimate having a standard error around `r round(mean(100*accuracies$std_err), 0)`%.
+Since the cross-validation accuracy estimates the test set accuracy,
+the fact that the test set accuracy also doesn't change much is expected.
+Also note that the $K =$ 3 model had a
+precision of `r round(100*cancer_prec_1$.estimate, 0)`% and recall of `r round(100*cancer_rec_1$.estimate, 0)`%,
+while the tuned model had
+a precision of `r round(100*cancer_prec_tuned, 0)`% and recall of `r round(100*cancer_rec_tuned, 0)`%.
+Given that the recall decreased&mdash;remember, in this application, recall
+is critical to making sure we find all the patients with malignant tumors&mdash;the tuned model may actually be *less* preferred
+in this setting. In any case, it is important to think critically about the result of tuning. Models tuned to
+maximize accuracy are not necessarily better for a given application.
+
+
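
One way to act on that concern, sketched below under the assumption that `cancer_recipe` and `cancer_vfold` are defined as earlier in the chapter, is to tune the number of neighbors against the cross-validation recall estimate (computed for the first factor level, `"Malignant"`) rather than accuracy. The object names `knn_tune_spec`, `k_vals`, and `recall_results` are illustrative, not from the chapter.

```r
knn_tune_spec <- nearest_neighbor(weight_func = "rectangular", neighbors = tune()) |>
  set_engine("kknn") |>
  set_mode("classification")

k_vals <- tibble(neighbors = seq(from = 1, to = 100, by = 5))

# compute both recall and accuracy for each candidate K across the folds
recall_results <- workflow() |>
  add_recipe(cancer_recipe) |>
  add_model(knn_tune_spec) |>
  tune_grid(
    resamples = cancer_vfold,
    grid = k_vals,
    metrics = metric_set(recall, accuracy)
  ) |>
  collect_metrics()

# the K with the highest cross-validation recall estimate
recall_results |>
  filter(.metric == "recall") |>
  arrange(desc(mean)) |>
  head(1)
```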
 ## Summary
 
 Classification algorithms use one or more quantitative variables to predict the

source/regression1.Rmd

Lines changed: 7 additions & 2 deletions
@@ -305,6 +305,11 @@ that we used earlier in the chapter (Figure \@ref(fig:07-small-eda-regr)).
 \index{training data}
 \index{test data}
 
+```{r 07-sacramento-seed-before-train-test-split, echo = FALSE, message = FALSE, warning = FALSE}
+# hidden seed -- make sure this is the same as what appears in reg2 right before train/test split
+set.seed(7)
+```
+
 ```{r 07-test-train-split}
 sacramento_split <- initial_split(sacramento, prop = 0.75, strata = price)
 sacramento_train <- training(sacramento_split)
@@ -507,13 +512,13 @@ Figure \@ref(fig:07-choose-k-knn-plot). What is happening here?
 
 Figure \@ref(fig:07-howK) visualizes the effect of different settings of $K$ on the
 regression model. Each plot shows the predicted values for house sale price from
-our KNN regression model on the training data for 6 different values for $K$: 1, 3, `r kmin`, 41, 250, and 680 (almost the entire training set).
+our KNN regression model on the training data for 6 different values for $K$: 1, 3, 25, `r kmin`, 250, and 680 (almost the entire training set).
 For each model, we predict prices for the range of possible home sizes we
 observed in the data set (here 500 to 5,000 square feet) and we plot the
 predicted prices as a blue line.
 
 ```{r 07-howK, echo = FALSE, warning = FALSE, fig.height = 13, fig.width = 10,fig.cap = "Predicted values for house price (represented as a blue line) from KNN regression models for six different values for $K$."}
-gridvals <- c(1, 3, kmin, 41, 250, 680)
+gridvals <- c(1, 3, 25, kmin, 250, 680)
 
 plots <- list()
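
For context on the hidden figure chunk above: it fits a KNN regression model for each value in `gridvals` and draws the predicted price curve over a grid of home sizes. A simplified sketch for a single value of $K$ (25 here is just an example), assuming `sacramento_train` with the `sqft` and `price` columns used in the chapter and that `tidymodels` is loaded:

```r
sacr_recipe <- recipe(price ~ sqft, data = sacramento_train) |>
  step_scale(all_predictors()) |>
  step_center(all_predictors())

sacr_spec <- nearest_neighbor(weight_func = "rectangular", neighbors = 25) |>
  set_engine("kknn") |>
  set_mode("regression")

sacr_fit <- workflow() |>
  add_recipe(sacr_recipe) |>
  add_model(sacr_spec) |>
  fit(data = sacramento_train)

# predict over the observed range of home sizes (500 to 5,000 square feet)
size_grid <- tibble(sqft = seq(from = 500, to = 5000, by = 10))
preds <- predict(sacr_fit, size_grid) |>
  bind_cols(size_grid)

ggplot(sacramento_train, aes(x = sqft, y = price)) +
  geom_point(alpha = 0.4) +
  geom_line(data = preds, aes(x = sqft, y = .pred), color = "blue")
```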
