@@ -431,8 +431,6 @@ There each row corresponds to an iteration,
431
431
where the left column depicts the center update,
432
432
and the right column depicts the reassignment of data to clusters.
433
433
434
- **Center Update** &emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp; **Label Update**
435
-
436
434
```{r 10-toy-kmeans-iter, echo = FALSE, warning = FALSE, fig.height = 16, fig.width = 8, fig.cap = "First four iterations of K-means clustering on the ` penguin_data ` example data set. Each row corresponds to an iteration, where the left column depicts the center update, and the right column depicts the reassignment of data to clusters. Cluster centers are indicated by larger points that are outlined in black."}
437
435
list_plot_cntrs <- vector(mode = "list", length = 4)
438
436
list_plot_lbls <- vector(mode = "list", length = 4)
@@ -444,30 +442,62 @@ for (i in 1:4) {
444
442
summarize_all(funs(mean))
445
443
nclus <- nrow(centers)
446
444
# replot with centers
447
- plt_ctr <- ggplot(penguin_data, aes(y = bill_length_standardized, x = flipper_length_standardized, color = label)) +
445
+ plt_ctr <- ggplot(penguin_data, aes(y = bill_length_standardized,
446
+ x = flipper_length_standardized,
447
+ color = label)) +
448
448
geom_point(size = 2) +
449
449
xlab("Flipper Length (standardized)") +
450
450
ylab("Bill Length (standardized)") +
451
451
theme(legend.position = "none") +
452
452
scale_color_manual(values= cbpalette) +
453
- geom_point(data = centers, aes(y = bill_length_standardized, x = flipper_length_standardized, fill = label), size = 5, shape = 21, stroke = 2, color = "black", fill = cbpalette)
453
+ geom_point(data = centers,
454
+ aes(y = bill_length_standardized,
455
+ x = flipper_length_standardized,
456
+ fill = label),
457
+ size = 5,
458
+ shape = 21,
459
+ stroke = 2,
460
+ color = "black",
461
+ fill = cbpalette) +
462
+ annotate("text", x = -1.15, y = 1.8, label = paste0("Iteration ", i))
463
+
464
+ if (i == 1) {
465
+ plt_ctr <- plt_ctr +
466
+ ggtitle("Center Update")
467
+ }
454
468
455
469
# reassign labels
456
470
dists <- rbind(centers, penguin_data) |>
457
471
select("flipper_length_standardized", "bill_length_standardized") |>
458
472
dist() |>
459
473
as.matrix()
460
474
dists <- as_tibble(dists[ -(1: nclus ), 1: nclus ] )
461
- penguin_data <- penguin_data |> mutate(label = apply(dists, 1, function(x) names(x)[ which.min(x)] ))
475
+ penguin_data <- penguin_data |>
476
+ mutate(label = apply(dists, 1, function(x) names(x)[ which.min(x)] ))
462
477
463
- plt_lbl <- ggplot(penguin_data, aes(y = bill_length_standardized, x = flipper_length_standardized, color = label)) +
478
+ plt_lbl <- ggplot(penguin_data,
479
+ aes(y = bill_length_standardized,
480
+ x = flipper_length_standardized,
481
+ color = label)) +
464
482
geom_point(size = 2) +
465
483
xlab("Flipper Length (standardized)") +
466
484
ylab("Bill Length (standardized)") +
467
485
theme(legend.position = "none") +
468
486
scale_color_manual(values= cbpalette) +
469
- geom_point(data = centers, aes(y = bill_length_standardized, x = flipper_length_standardized, fill = label), size = 5, shape = 21, stroke = 2, color = "black", fill = cbpalette)
487
+ geom_point(data = centers,
488
+ aes(y = bill_length_standardized,
489
+ x = flipper_length_standardized, fill = label),
490
+ size = 5,
491
+ shape = 21,
492
+ stroke = 2,
493
+ color = "black",
494
+ fill = cbpalette)
470
495
496
+ if (i == 1) {
497
+ plt_lbl <- plt_lbl +
498
+ ggtitle("Label Update")
499
+ }
500
+
471
501
list_plot_cntrs[[ i]] <- plt_ctr
472
502
list_plot_lbls[[ i]] <- plt_lbl
473
503
}
@@ -477,13 +507,9 @@ iter_plot_list <- c(list_plot_cntrs[1], list_plot_lbls[1],
477
507
list_plot_cntrs[ 3] , list_plot_lbls[ 3] ,
478
508
list_plot_cntrs[ 4] , list_plot_lbls[ 4] )
479
509
480
- plot_grid(plotlist = iter_plot_list, ncol = 2,
481
- label_x = 0.005,
482
- label_y = 0.980,
483
- labels = c("Iteration 1", "",
484
- "Iteration 2", "",
485
- "Iteration 3", "",
486
- "Iteration 4", ""))
510
+ plot_grid(plotlist = iter_plot_list,
511
+ ncol = 2,
512
+ rel_heights = c(1.065, 1, 1, 1))
487
513
```
488
514
489
515
Note that at this point, we can terminate the algorithm since none of the assignments changed
@@ -531,8 +557,6 @@ plt_lbl
531
557
532
558
Figure \@ref(fig:10-toy-kmeans-bad-iter) shows what the iterations of K-means would look like with the unlucky random initialization shown in Figure \@ref(fig:10-toy-kmeans-bad-init).
533
559
534
- **Center Update** &emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp; **Label Update**
535
-
536
560
```{r 10-toy-kmeans-bad-iter, echo = FALSE, warning = FALSE, fig.height = 20, fig.width = 8, fig.cap = "First five iterations of K-means clustering on the ` penguin_data ` example data set with a poor random initialization. Each row corresponds to an iteration, where the left column depicts the center update, and the right column depicts the reassignment of data to clusters. Cluster centers are indicated by larger points that are outlined in black."}
537
561
list_plot_cntrs <- vector(mode = "list", length = 5)
538
562
list_plot_lbls <- vector(mode = "list", length = 5)
@@ -559,8 +583,14 @@ for (i in 1:5) {
559
583
shape = 21,
560
584
stroke = 2,
561
585
color = "black",
562
- fill = cbpalette)
586
+ fill = cbpalette) +
587
+ annotate("text", x = -1.15, y = 1.8, label = paste0("Iteration ", i))
563
588
589
+ if (i == 1) {
590
+ plt_ctr <- plt_ctr +
591
+ ggtitle("Center Update")
592
+ }
593
+
564
594
# reassign labels
565
595
dists <- rbind(centers, penguin_data) |>
566
596
select("flipper_length_standardized", "bill_length_standardized") |>
@@ -587,6 +617,11 @@ for (i in 1:5) {
587
617
color = "black",
588
618
fill = cbpalette)
589
619
620
+ if (i == 1) {
621
+ plt_lbl <- plt_lbl +
622
+ ggtitle("Label Update")
623
+ }
624
+
590
625
list_plot_cntrs[[ i]] <- plt_ctr
591
626
list_plot_lbls[[ i]] <- plt_lbl
592
627
}
@@ -597,14 +632,9 @@ iter_plot_list <- c(list_plot_cntrs[1], list_plot_lbls[1],
597
632
list_plot_cntrs[ 4] , list_plot_lbls[ 4] ,
598
633
list_plot_cntrs[ 5] , list_plot_lbls[ 5] )
599
634
600
- plot_grid(plotlist = iter_plot_list, ncol = 2,
601
- label_x = 0.005,
602
- label_y = 0.980,
603
- labels = c("Iteration 1", "",
604
- "Iteration 2", "",
605
- "Iteration 3", "",
606
- "Iteration 4", "",
607
- "Iteration 5", ""))
635
+ plot_grid(plotlist = iter_plot_list,
636
+ ncol = 2,
637
+ rel_heights = c(1.065, 1, 1, 1, 1))
608
638
```
609
639
610
640
This looks like a relatively bad clustering of the data, but K-means cannot improve it.
0 commit comments