
Commit bc41b08

clustering simpler data manipulation; bugfixes
1 parent ff2e3b3 commit bc41b08

File tree

2 files changed (+77, -111 lines)

source/clustering.md

Lines changed: 77 additions & 92 deletions
@@ -121,38 +121,31 @@ but one is willing to provide a few informative example labels as a "seed"
 to guess the labels for all the data.
 ```
 
-**An illustrative example**
+## An illustrative example
 
 ```{index} Palmer penguins
 ```
 
-Here we will present an illustrative example using a data set from
+In this chapter we will focus on a data set from
 [the `palmerpenguins` R package](https://allisonhorst.github.io/palmerpenguins/) {cite:p}`palmerpenguins`. This
 data set was collected by Dr. Kristen Gorman and
 the Palmer Station, Antarctica Long Term Ecological Research Site, and includes
-measurements for adult penguins found near there {cite:p}`penguinpaper`. We have
-modified the data set for use in this chapter. Here we will focus on using two
+measurements for adult penguins ({numref}`09-penguins`) found near there {cite:p}`penguinpaper`.
+Our goal will be to use two
 variables—penguin bill and flipper length, both in millimeters—to determine whether
 there are distinct types of penguins in our data.
 Understanding this might help us with species discovery and classification in a data-driven
-way.
+way. Note that we have reduced the data set to 18 observations and 2 variables;
+this will help us make clear visualizations that illustrate how clustering works.
 
 ```{figure} img/clustering/gentoo.jpg
 ---
 height: 400px
 name: 09-penguins
 ---
-Gentoo penguin.
+A Gentoo penguin.
 ```
 
-To learn about K-means clustering
-we will work with `penguin_data` in this chapter.
-`penguin_data` is a subset of 18 observations of the original data,
-which has already been standardized
-(remember from {numref}`Chapter %s <classification1>`
-that scaling is part of the standardization process).
-We will discuss scaling for K-means in more detail later in this chapter.
-
 Before we get started, we will set a random seed.
 This will ensure that our analysis will be reproducible.
 As we will learn in more detail later in the chapter,
@@ -166,32 +159,38 @@ when choosing a starting position for each cluster.
 ```{code-cell} ipython3
 import numpy as np
 
-np.random.seed(149)
+np.random.seed(6)
 ```
 
 ```{index} read function; read_csv
 ```
 
-Now we can load and preview the data.
+Now we can load and preview the `penguins` data.
 
 ```{code-cell} ipython3
-:tags: [remove-cell]
-
 import pandas as pd
 
-data = pd.read_csv(
-    "data/penguins_toy.csv"
-).replace(
-    [2, 3],
-    [0, 2]
-)
+penguins = pd.read_csv("data/penguins.csv")
+penguins
 ```
 
+We will begin by using a version of the data that we have standardized, `penguins_standardized`,
+to illustrate how K-means clustering works (recall standardization from {numref}`Chapter %s <classification1>`).
+Later in this chapter, we will return to the original `penguins` data to see how to include standardization automatically
+in the clustering pipeline.
+
 ```{code-cell} ipython3
-import pandas as pd
+:tags: [remove-cell]
+penguins_standardized = penguins.assign(
+    flipper_length_standardized = (penguins["flipper_length_mm"] - penguins["flipper_length_mm"].mean())/penguins["flipper_length_mm"].std(),
+    bill_length_standardized = (penguins["bill_length_mm"] - penguins["bill_length_mm"].mean())/penguins["bill_length_mm"].std()
+).drop(
+    columns = ["bill_length_mm", "flipper_length_mm"]
+)
+```
 
-penguin_data = pd.read_csv("data/penguins_standardized.csv")
-penguin_data
+```{code-cell} ipython3
+penguins_standardized
 ```
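For reference, the manual standardization in the cell above can also be done with scikit-learn's `StandardScaler`; a minimal sketch (assuming the same `data/penguins.csv` file) is below. Note that `StandardScaler` divides by the population standard deviation (`ddof=0`), while pandas' `.std()` uses the sample standard deviation (`ddof=1`), so the two results differ by a small constant factor.

```python
import pandas as pd
from sklearn import set_config
from sklearn.preprocessing import StandardScaler

# Return pandas data frames (with column names) instead of numpy arrays.
set_config(transform_output="pandas")

penguins = pd.read_csv("data/penguins.csv")

# Standardize both columns: subtract the mean, divide by the standard deviation.
penguins_scaled = StandardScaler().fit_transform(
    penguins[["bill_length_mm", "flipper_length_mm"]]
)
```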
 
 Next, we can create a scatter plot using this data set
@@ -200,7 +199,7 @@ to see if we can detect subtypes or groups in our data set.
 ```{code-cell} ipython3
 import altair as alt
 
-scatter_plot = alt.Chart(penguin_data).mark_circle().encode(
+scatter_plot = alt.Chart(penguins_standardized).mark_circle().encode(
     x=alt.X("flipper_length_standardized").title("Flipper Length (standardized)"),
     y=alt.Y("bill_length_standardized").title("Bill Length (standardized)")
 )
@@ -222,8 +221,7 @@ Scatter plot of standardized bill length versus standardized flipper length.
 ```{index} altair, altair; mark_circle
 ```
 
-Based on the visualization
-in {numref}`scatter_plot`,
+Based on the visualization in {numref}`scatter_plot`,
 we might suspect there are a few subtypes of penguins within our data set.
 We can see roughly 3 groups of observations in {numref}`scatter_plot`,
 including:
@@ -253,8 +251,19 @@ denoted by colored scatter points.
 
 ```{code-cell} ipython3
 :tags: [remove-cell]
+from sklearn import set_config
+from sklearn.cluster import KMeans
+
+# Output dataframes instead of arrays
+set_config(transform_output="pandas")
 
-colored_scatter_plot = alt.Chart(data).mark_circle().encode(
+kmeans = KMeans(n_clusters=3)
+
+penguin_clust = kmeans.fit(penguins_standardized)
+
+penguins_clustered = penguins_standardized.assign(cluster = penguin_clust.labels_)
+
+colored_scatter_plot = alt.Chart(penguins_clustered).mark_circle().encode(
     x=alt.X("flipper_length_standardized", title="Flipper Length (standardized)"),
     y=alt.Y("bill_length_standardized", title="Bill Length (standardized)"),
     color=alt.Color("cluster:N")
@@ -295,7 +304,7 @@ have.
 ```{code-cell} ipython3
 :tags: [remove-cell]
 
-clus = data[data["cluster"] == 0][["bill_length_standardized", "flipper_length_standardized"]]
+clus = penguins_clustered[penguins_clustered["cluster"] == 0][["bill_length_standardized", "flipper_length_standardized"]]
 ```
 
 ```{index} see: within-cluster sum-of-squared-distances; WSSD
@@ -317,8 +326,9 @@ cluster containing four observations, and we are using two variables, $x$ and $y
 Then we would compute the coordinates, $\mu_x$ and $\mu_y$, of the cluster center via
 
 
-
-$\mu_x = \frac{1}{4}(x_1+x_2+x_3+x_4) \quad \mu_y = \frac{1}{4}(y_1+y_2+y_3+y_4)$
+$$
+\mu_x = \frac{1}{4}(x_1+x_2+x_3+x_4) \quad \mu_y = \frac{1}{4}(y_1+y_2+y_3+y_4)
+$$
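As a quick check of this formula, a minimal sketch with pandas on a hypothetical four-observation cluster:

```python
import pandas as pd

# Hypothetical cluster of four observations with variables x and y.
cluster = pd.DataFrame({"x": [1.0, 2.0, 3.0, 4.0], "y": [2.0, 1.0, 4.0, 3.0]})

# The cluster center is the per-variable mean:
# mu_x = (x1 + x2 + x3 + x4) / 4, and likewise for mu_y.
center = cluster.mean()  # x = 2.5, y = 2.5
```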
 
 ```{code-cell} ipython3
 :tags: [remove-cell]
@@ -362,7 +372,7 @@ in {numref}`toy-example-clus1-center`
 :figwidth: 700px
 :name: toy-example-clus1-center
 
-Cluster 0 from the `penguin_data` data set example. Observations are in blue, with the cluster center highlighted in orange.
+Cluster 0 from the `penguins_standardized` data set example. Observations are in blue, with the cluster center highlighted in orange.
 :::
 
 ```{code-cell} ipython3
@@ -406,30 +416,30 @@ These distances are denoted by lines in {numref}`toy-example-clus1-dists` for th
 :figwidth: 700px
 :name: toy-example-clus1-dists
 
-Cluster 0 from the `penguin_data` data set example. Observations are in blue, with the cluster center highlighted in orange. The distances from the observations to the cluster center are represented as black lines.
+Cluster 0 from the `penguins_standardized` data set example. Observations are in blue, with the cluster center highlighted in orange. The distances from the observations to the cluster center are represented as black lines.
 :::
 
 ```{code-cell} ipython3
 :tags: [remove-cell]
 
 toy_example_all_clus_dists = alt.layer(
     alt.Chart(
-        data.assign(
-            mean_bill_length=data.groupby('cluster')['bill_length_standardized'].transform('mean'),
-            mean_flipper_length=data.groupby('cluster')['flipper_length_standardized'].transform('mean')
+        penguins_clustered.assign(
+            mean_bill_length=penguins_clustered.groupby('cluster')['bill_length_standardized'].transform('mean'),
+            mean_flipper_length=penguins_clustered.groupby('cluster')['flipper_length_standardized'].transform('mean')
         )
     ).mark_rule(size=1.25).encode(
         alt.Y('bill_length_standardized'),
         alt.Y2('mean_bill_length'),
         alt.X('flipper_length_standardized'),
         alt.X2('mean_flipper_length')
     ),
-    alt.Chart(data).mark_circle(size=40, opacity=1).encode(
+    alt.Chart(penguins_clustered).mark_circle(size=40, opacity=1).encode(
        alt.X("flipper_length_standardized"),
        alt.Y("bill_length_standardized"),
        alt.Color('cluster:N')
    ),
-    alt.Chart(data).mark_circle(color='coral', size=200, opacity=1).encode(
+    alt.Chart(penguins_clustered).mark_circle(color='coral', size=200, opacity=1).encode(
        alt.X("mean(flipper_length_standardized)")
        .scale(zero=False)
        .title("Flipper Length (standardized)"),
@@ -442,23 +452,32 @@ toy_example_all_clus_dists = alt.layer(
 glue('toy-example-all-clus-dists', toy_example_all_clus_dists, display=True)
 ```
 
-The larger the value of $S^2$, the more spread out the cluster is, since large $S^2$ means that points are far from the cluster center.
-Note, however, that "large" is relative to *both* the scale of the variables for clustering *and* the number of points in the cluster. A cluster where points are very close to the center might still have a large $S^2$ if there are many data points in the cluster.
+The larger the value of $S^2$, the more spread out the cluster is, since large $S^2$ means
+that points are far from the cluster center. Note, however, that "large" is relative to *both* the
+scale of the variables for clustering *and* the number of points in the cluster. A cluster where points
+are very close to the center might still have a large $S^2$ if there are many data points in the cluster.
 
 After we have calculated the WSSD for all the clusters,
-we sum them together to get the *total WSSD*.
-For our example,
+we sum them together to get the *total WSSD*. For our example,
 this means adding up all the squared distances for the 18 observations.
 These distances are denoted by black lines in
-{numref}`toy-example-all-clus-dists`
+{numref}`toy-example-all-clus-dists`.
 
 :::{glue:figure} toy-example-all-clus-dists
 :figwidth: 700px
 :name: toy-example-all-clus-dists
 
-All clusters from the `penguin_data` data set example. Observations are in blue, orange, and red with the cluster center highlighted in orange. The distances from the observations to each of the respective cluster centers are represented as black lines.
+All clusters from the `penguins_standardized` data set example. Observations are in blue, orange, and red with the cluster centers highlighted in orange. The distances from the observations to each of the respective cluster centers are represented as black lines.
 :::
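To make the computation concrete, here is a minimal sketch of the per-cluster and total WSSD using the `penguins_clustered` data frame from above; for a fitted scikit-learn `KMeans` object such as `penguin_clust`, the total WSSD should also match the `inertia_` attribute (up to numerical tolerance).

```python
# Center coordinates for each observation's own cluster.
centers = penguins_clustered.groupby("cluster")[
    ["bill_length_standardized", "flipper_length_standardized"]
].transform("mean")

# Squared straight-line distance from each observation to its cluster center.
sq_dists = (
    (penguins_clustered["bill_length_standardized"] - centers["bill_length_standardized"]) ** 2
    + (penguins_clustered["flipper_length_standardized"] - centers["flipper_length_standardized"]) ** 2
)

wssd_per_cluster = sq_dists.groupby(penguins_clustered["cluster"]).sum()
total_wssd = sq_dists.sum()  # should match penguin_clust.inertia_
```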
 
+Since K-means uses the straight-line distance to measure the quality of a clustering,
+it is limited to clustering based on quantitative variables.
+However, note that there are variants of the K-means algorithm,
+as well as other clustering algorithms entirely,
+that use other distance metrics
+to allow for non-quantitative data to be clustered.
+These are beyond the scope of this book.
+
 +++
 
 ### The clustering algorithm
@@ -574,17 +593,15 @@ sum of WSSDs over all the clusters, i.e., the *total WSSD*:
 
 These two steps are repeated until the cluster assignments no longer change.
 We show what the first three iterations of K-means would look like in
-{numref}`toy-kmeans-iter-1`
-There each row corresponds to an iteration,
+{numref}`toy-kmeans-iter-1`. Each row corresponds to an iteration,
 where the left column depicts the center update,
-and the right column depicts the reassignment of data to clusters.
-
+and the right column depicts the label update (i.e., the reassignment of data to clusters).
 
 :::{glue:figure} toy-kmeans-iter-1
 :figwidth: 700px
 :name: toy-kmeans-iter-1
 
-First three iterations of K-means clustering on the `penguin_data` example data set. Each pair of plots corresponds to an iteration. Within the pair, the first plot depicts the center update, and the second plot depicts the reassignment of data to clusters. Cluster centers are indicated by larger points that are outlined in black.
+First three iterations of K-means clustering on the `penguins_standardized` example data set. Each pair of plots corresponds to an iteration. Within the pair, the first plot depicts the center update, and the second plot depicts the reassignment of data to clusters. Cluster centers are indicated by larger points that are outlined in black.
 :::
 
 +++
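For readers who want to see the two steps in code, below is a minimal numpy sketch of the K-means loop under simplifying assumptions (a fixed iteration cap, synthetic stand-in data, and no empty clusters):

```python
import numpy as np

rng = np.random.default_rng(6)
X = rng.normal(size=(18, 2))  # stand-in for 18 standardized observations
K = 3
centers = X[rng.choice(len(X), size=K, replace=False)]  # random initial centers

for _ in range(10):  # in practice, loop until the labels stop changing
    # Label update: assign each observation to the nearest center
    # (straight-line distance).
    dists = np.linalg.norm(X[:, None, :] - centers[None, :, :], axis=2)
    labels = dists.argmin(axis=1)
    # Center update: move each center to the mean of its assigned points
    # (assumes every cluster keeps at least one point).
    centers = np.array([X[labels == k].mean(axis=0) for k in range(K)])
```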
@@ -604,17 +621,6 @@ ways to assign the data to clusters. So at some point, the total WSSD must stop
 are changing, and the algorithm terminates.
 ```
 
-What kind of data is suitable for K-means clustering?
-In the simplest version of K-means clustering that we have presented here,
-the straight-line distance is used to measure the
-distance between observations and cluster centers.
-This means that only quantitative data should be used with this algorithm.
-There are variants on the K-means algorithm,
-as well as other clustering algorithms entirely,
-that use other distance metrics
-to allow for non-quantitative data to be clustered.
-These, however, are beyond the scope of this book.
-
 ```{code-cell} ipython3
 :tags: [remove-cell]
 
@@ -663,7 +669,7 @@ glue('toy-kmeans-bad-iter-1', plot_kmean_iterations(4, penguin_data.copy(), cent
 :figwidth: 700px
 :name: toy-kmeans-bad-iter-1
 
-First five iterations of K-means clustering on the `penguin_data` example data set with a poor random initialization. Each pair of plots corresponds to an iteration. Within the pair, the first plot depicts the center update, and the second plot depicts the reassignment of data to clusters. Cluster centers are indicated by larger points that are outlined in black.
+First four iterations of K-means clustering on the `penguins_standardized` example data set with a poor random initialization. Each pair of plots corresponds to an iteration. Within the pair, the first plot depicts the center update, and the second plot depicts the reassignment of data to clusters. Cluster centers are indicated by larger points that are outlined in black.
 :::
 
 This looks like a relatively bad clustering of the data, but K-means cannot improve it.
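This sensitivity to initialization can be demonstrated directly: fitting `KMeans` with a single random initialization per run and different seeds can terminate at different total WSSDs, i.e., different local optima. A minimal sketch, assuming the `penguins_standardized` data frame from earlier:

```python
from sklearn.cluster import KMeans

# One random initialization per run; different seeds may reach
# different local optima with different total WSSDs.
for seed in range(5):
    km = KMeans(n_clusters=3, init="random", n_init=1, random_state=seed)
    km.fit(penguins_standardized)
    print(seed, km.inertia_)
```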
@@ -790,23 +796,9 @@ Total WSSD for K clusters ranging from 1 to 9.
 ```
 
 We can perform K-means in Python using a workflow similar to those
-in the earlier classification and regression chapters. We will begin
-by reading the original (i.e., unstandardized) subset of 18 observations
-from the penguins data set.
-
-```{code-cell} ipython3
-:tags: [remove-cell]
-
-unstandardized_data = pd.read_csv("data/penguins_toy.csv", usecols=["bill_length_mm", "flipper_length_mm"])
-unstandardized_data.to_csv("data/penguins.csv", index=False)
-```
-
-```{code-cell} ipython3
-penguins = pd.read_csv("data/penguins.csv")
-penguins
-```
-
-Recall that K-means clustering uses straight-line distance to decide which points are similar to
+in the earlier classification and regression chapters.
+Returning to the original (unstandardized) `penguins` data,
+recall that K-means clustering uses straight-line distance to decide which points are similar to
 each other. Therefore, the *scale* of each of the variables in the data
 will influence which cluster data points end up being assigned.
 Variables with a large scale will have a much larger
@@ -871,12 +863,6 @@ clustered_data = penguins.assign(cluster = penguin_clust[1].labels_)
 clustered_data
 ```
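The `penguin_clust[1]` indexing above suggests that `penguin_clust` is a fitted scikit-learn pipeline whose second step is the `KMeans` object. Its construction is not shown in this excerpt; a plausible sketch is:

```python
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans

# Standardize, then cluster; fitting the pipeline runs both steps in order.
penguin_clust = make_pipeline(StandardScaler(), KMeans(n_clusters=3)).fit(penguins)

# Step 0 is the scaler and step 1 is the KMeans object,
# hence penguin_clust[1].labels_ gives the cluster assignments.
clustered_data = penguins.assign(cluster = penguin_clust[1].labels_)
```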
 
-Let's start by visualizing the clustering
-as a colored scatter plot. To do that,
-we will add a new column and store assign the above predictions to that. The final
-data frame will contain the data and the cluster assignments for
-each point:
-
 Now that we have the cluster assignments included in the `clustered_data` data frame, we can
 visualize them as shown in {numref}`cluster_plot`.
 Note that we are plotting the *un-standardized* data here; if we for some reason wanted to
@@ -1018,17 +1004,16 @@ it is possible to have an elbow plot
 where the WSSD increases at one of the steps,
 causing a small bump in the line.
 This is because K-means can get "stuck" in a bad solution
-due to an unlucky initialization of the initial centroid positions
+due to an unlucky choice of initial center positions
 as we mentioned earlier in the chapter.
 
 ```{note}
 It is rare that the KMeans function from `scikit-learn`
-gets stuck in a bad solution,
-because the selection of the centroid starting points
-is optimized to prevent this from happening.
+gets stuck in a bad solution, because `scikit-learn` tries to choose
+the initial centers carefully to prevent this from happening.
 If you still find yourself in a situation where you have a bump in the elbow plot,
 you can increase the `n_init` parameter
-to try more different starting points for the centroids.
+when creating the `KMeans` object, e.g., `KMeans(n_clusters=k, n_init=10)`, to try additional random center initializations.
 The larger the value the better from an analysis perspective,
 but there is a trade-off that doing many clusterings could take a long time.
 ```
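Finally, as a concrete illustration of the elbow-plot computation this section describes, a minimal sketch that records the total WSSD (`inertia_`) for K from 1 to 9, assuming the `penguins_standardized` data frame:

```python
import pandas as pd
from sklearn.cluster import KMeans

# Fit K-means for each K and record the total WSSD (inertia_);
# a larger n_init makes an unlucky initialization (and a bump
# in the elbow plot) less likely.
elbow_stats = pd.DataFrame({
    "k": range(1, 10),
    "total_wssd": [
        KMeans(n_clusters=k, n_init=10).fit(penguins_standardized).inertia_
        for k in range(1, 10)
    ],
})
```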

source/data/penguins_standardized.csv

Lines changed: 0 additions & 19 deletions
This file was deleted.
