Skip to content

Commit 41d092c

Browse files
committed
Fixed formatting of labels for cluster updating plots
1 parent b9e8a1e commit 41d092c

File tree

1 file changed

+55
-25
lines changed

1 file changed

+55
-25
lines changed

clustering.Rmd

Lines changed: 55 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -431,8 +431,6 @@ There each row corresponds to an iteration,
431431
where the left column depicts the center update,
432432
and the right column depicts the reassignment of data to clusters.
433433

434-
**Center Update**                            **Label Update**
435-
436434
```{r 10-toy-kmeans-iter, echo = FALSE, warning = FALSE, fig.height = 16, fig.width = 8, fig.cap = "First four iterations of K-means clustering on the `penguin_data` example data set. Each row corresponds to an iteration, where the left column depicts the center update, and the right column depicts the reassignment of data to clusters. Cluster centers are indicated by larger points that are outlined in black."}
437435
list_plot_cntrs <- vector(mode = "list", length = 4)
438436
list_plot_lbls <- vector(mode = "list", length = 4)
@@ -444,30 +442,62 @@ for (i in 1:4) {
444442
summarize_all(funs(mean))
445443
nclus <- nrow(centers)
446444
# replot with centers
447-
plt_ctr <- ggplot(penguin_data, aes(y = bill_length_standardized, x = flipper_length_standardized, color = label)) +
445+
plt_ctr <- ggplot(penguin_data, aes(y = bill_length_standardized,
446+
x = flipper_length_standardized,
447+
color = label)) +
448448
geom_point(size = 2) +
449449
xlab("Flipper Length (standardized)") +
450450
ylab("Bill Length (standardized)") +
451451
theme(legend.position = "none") +
452452
scale_color_manual(values= cbpalette) +
453-
geom_point(data = centers, aes(y = bill_length_standardized, x = flipper_length_standardized, fill = label), size = 5, shape = 21, stroke = 2, color = "black", fill = cbpalette)
453+
geom_point(data = centers,
454+
aes(y = bill_length_standardized,
455+
x = flipper_length_standardized,
456+
fill = label),
457+
size = 5,
458+
shape = 21,
459+
stroke = 2,
460+
color = "black",
461+
fill = cbpalette) +
462+
annotate("text", x = -1.15, y = 1.8, label = paste0("Iteration ", i))
463+
464+
if (i == 1) {
465+
plt_ctr <- plt_ctr +
466+
ggtitle("Center Update")
467+
}
454468

455469
# reassign labels
456470
dists <- rbind(centers, penguin_data) |>
457471
select("flipper_length_standardized", "bill_length_standardized") |>
458472
dist() |>
459473
as.matrix()
460474
dists <- as_tibble(dists[-(1:nclus), 1:nclus])
461-
penguin_data <- penguin_data |> mutate(label = apply(dists, 1, function(x) names(x)[which.min(x)]))
475+
penguin_data <- penguin_data |>
476+
mutate(label = apply(dists, 1, function(x) names(x)[which.min(x)]))
462477

463-
plt_lbl <- ggplot(penguin_data, aes(y = bill_length_standardized, x = flipper_length_standardized, color = label)) +
478+
plt_lbl <- ggplot(penguin_data,
479+
aes(y = bill_length_standardized,
480+
x = flipper_length_standardized,
481+
color = label)) +
464482
geom_point(size = 2) +
465483
xlab("Flipper Length (standardized)") +
466484
ylab("Bill Length (standardized)") +
467485
theme(legend.position = "none") +
468486
scale_color_manual(values= cbpalette) +
469-
geom_point(data = centers, aes(y = bill_length_standardized, x = flipper_length_standardized, fill = label), size = 5, shape = 21, stroke = 2, color = "black", fill = cbpalette)
487+
geom_point(data = centers,
488+
aes(y = bill_length_standardized,
489+
x = flipper_length_standardized, fill = label),
490+
size = 5,
491+
shape = 21,
492+
stroke = 2,
493+
color = "black",
494+
fill = cbpalette)
470495

496+
if (i == 1) {
497+
plt_lbl <- plt_lbl +
498+
ggtitle("Label Update")
499+
}
500+
471501
list_plot_cntrs[[i]] <- plt_ctr
472502
list_plot_lbls[[i]] <- plt_lbl
473503
}
@@ -477,13 +507,9 @@ iter_plot_list <- c(list_plot_cntrs[1], list_plot_lbls[1],
477507
list_plot_cntrs[3], list_plot_lbls[3],
478508
list_plot_cntrs[4], list_plot_lbls[4])
479509

480-
plot_grid(plotlist = iter_plot_list, ncol = 2,
481-
label_x = 0.005,
482-
label_y = 0.980,
483-
labels = c("Iteration 1", "",
484-
"Iteration 2", "",
485-
"Iteration 3", "",
486-
"Iteration 4", ""))
510+
plot_grid(plotlist = iter_plot_list,
511+
ncol = 2,
512+
rel_heights = c(1.065, 1, 1, 1))
487513
```
488514
489515
Note that at this point, we can terminate the algorithm since none of the assignments changed
@@ -531,8 +557,6 @@ plt_lbl
531557

532558
Figure \@ref(fig:10-toy-kmeans-bad-iter) shows what the iterations of K-means would look like with the unlucky random initialization shown in Figure \@ref(fig:10-toy-kmeans-bad-init).
533559

534-
**Center Update** &emsp; &emsp; &emsp; &emsp; &emsp; &emsp; &emsp; &emsp; &emsp; &emsp; &emsp; &emsp; &emsp; &emsp;**Label Update**
535-
536560
```{r 10-toy-kmeans-bad-iter, echo = FALSE, warning = FALSE, fig.height = 20, fig.width = 8, fig.cap = "First five iterations of K-means clustering on the `penguin_data` example data set with a poor random initialization. Each row corresponds to an iteration, where the left column depicts the center update, and the right column depicts the reassignment of data to clusters. Cluster centers are indicated by larger points that are outlined in black."}
537561
list_plot_cntrs <- vector(mode = "list", length = 5)
538562
list_plot_lbls <- vector(mode = "list", length = 5)
@@ -559,8 +583,14 @@ for (i in 1:5) {
559583
shape = 21,
560584
stroke = 2,
561585
color = "black",
562-
fill = cbpalette)
586+
fill = cbpalette) +
587+
annotate("text", x = -1.15, y = 1.8, label = paste0("Iteration ", i))
563588

589+
if (i == 1) {
590+
plt_ctr <- plt_ctr +
591+
ggtitle("Center Update")
592+
}
593+
564594
# reassign labels
565595
dists <- rbind(centers, penguin_data) |>
566596
select("flipper_length_standardized", "bill_length_standardized") |>
@@ -587,6 +617,11 @@ for (i in 1:5) {
587617
color = "black",
588618
fill = cbpalette)
589619

620+
if (i == 1) {
621+
plt_lbl <- plt_lbl +
622+
ggtitle("Label Update")
623+
}
624+
590625
list_plot_cntrs[[i]] <- plt_ctr
591626
list_plot_lbls[[i]] <- plt_lbl
592627
}
@@ -597,14 +632,9 @@ iter_plot_list <- c(list_plot_cntrs[1], list_plot_lbls[1],
597632
list_plot_cntrs[4], list_plot_lbls[4],
598633
list_plot_cntrs[5], list_plot_lbls[5])
599634

600-
plot_grid(plotlist = iter_plot_list, ncol = 2,
601-
label_x = 0.005,
602-
label_y = 0.980,
603-
labels = c("Iteration 1", "",
604-
"Iteration 2", "",
605-
"Iteration 3", "",
606-
"Iteration 4", "",
607-
"Iteration 5", ""))
635+
plot_grid(plotlist = iter_plot_list,
636+
ncol = 2,
637+
rel_heights = c(1.065, 1, 1, 1, 1))
608638
```
609639
610640
This looks like a relatively bad clustering of the data, but K-means cannot improve it.

0 commit comments

Comments
 (0)