-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathk_means.py
More file actions
94 lines (81 loc) · 3.76 KB
/
k_means.py
File metadata and controls
94 lines (81 loc) · 3.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import copy
import os
import random

import numpy as np
from matplotlib import pyplot as plt
class KMeans:
    """K-means clustering with matplotlib plots of the result.

    Clustering works for any number of feature columns; the saved plots draw
    only the first two columns, so plotting requires at least two features.
    """

    def __init__(self, k, data):
        """Seed ``k`` centroids with distinct rows of *data*.

        :param k: number of clusters; must not exceed ``len(data)``
        :param data: array-like of shape (n_samples, n_features)
        :raises ValueError: if ``k`` is larger than the number of rows
        """
        self.k = k
        self.centroids = self._random_centroids(data)

    def _random_centroids(self, data):
        """Return ``self.k`` distinct rows of *data* as a float array."""
        data = np.asarray(data, dtype=float)
        # Sample without replacement over the full index range: randint(0, n - 1)
        # never picked the last row and could pick the same row twice, which
        # yields duplicate centroids and therefore an empty cluster.
        idx = np.random.choice(len(data), size=self.k, replace=False)
        return data[idx]  # fancy indexing copies, so centroids never alias data

    def _assign(self, data):
        """Return the index of the nearest centroid for every row of *data*."""
        # (n, 1, d) - (1, k, d) -> (n, k) matrix of Euclidean distances
        distances = np.linalg.norm(
            data[:, np.newaxis, :] - self.centroids[np.newaxis, :, :], axis=2
        )
        return np.argmin(distances, axis=1)

    def _update_centroids(self, data, labels):
        """Move every centroid to the mean of the rows assigned to it.

        A centroid with no assigned rows keeps its position; averaging an empty
        selection would make it NaN and poison every later distance.
        """
        for label in range(self.k):
            members = data[labels == label]
            if len(members):
                self.centroids[label] = members.mean(axis=0)

    def _save_plot(self, data, labels, error, path):
        """Scatter each cluster and the centroids, titled with *error*, to *path*."""
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        fig, ax = plt.subplots()
        try:
            for label in range(self.k):
                members = data[labels == label]  # shape (0, d) if empty: still plottable
                # No explicit colour: matplotlib takes the next colour of its
                # property cycle for each cluster.
                ax.scatter(members[:, 0], members[:, 1], s=7)
            ax.scatter(self.centroids[:, 0], self.centroids[:, 1], marker="*", s=20, color='#050505')
            ax.set_title("Error = {}".format(error))
            fig.savefig(path)
        finally:
            plt.close(fig)  # avoid accumulating open figures across runs

    def best_run(self, data, max_iter=300, plot_dir="k_means_plots"):
        """Run Lloyd's algorithm from the current centroids until they stop moving.

        Each round assigns every row to its nearest centroid, then moves each
        centroid to the mean of its rows. The loop stops when the centroid
        matrix no longer changes (error 0) or after ``max_iter`` rounds.

        :param data: array-like of shape (n_samples, n_features)
        :param max_iter: upper bound on assign/update rounds
        :param plot_dir: directory for ``best_k_mean.png`` (created if missing);
            ``None`` skips plotting
        :return: ndarray of cluster indices, one per row of *data*
        """
        data = np.asarray(data, dtype=float)
        labels = np.zeros(len(data), dtype=int)
        error = np.inf
        for _ in range(max_iter):
            labels = self._assign(data)
            old_centroids = self.centroids.copy()
            self._update_centroids(data, labels)
            error = np.linalg.norm(self.centroids - old_centroids)
            if error == 0.0:
                break
        if plot_dir is not None:
            self._save_plot(data, labels, error, os.path.join(plot_dir, "best_k_mean.png"))
        return labels

    def iteration_run(self, data, iterations=1, plot_dir="k_means_plots"):
        """Run a single sequential k-means pass ``iterations`` times from fresh seeds.

        Each run re-seeds the centroids with random distinct rows, then visits
        the rows once in order. A row is assigned to its nearest centroid, and
        that centroid is immediately moved to the mean of the rows assigned to
        it so far in this run. The error printed and plotted for a run is the
        distance moved by the centroid in that run's final update.

        :param data: array-like of shape (n_samples, n_features)
        :param iterations: number of independent runs
        :param plot_dir: directory for ``km_<run>.png`` (created if missing);
            ``None`` skips plotting
        :return: cluster indices from the last run, or ``None`` if no run happened
        """
        data = np.asarray(data, dtype=float)
        labels = None
        for it in range(iterations):
            self.centroids = self._random_centroids(data)
            # -1 marks "not visited yet"; zeros would silently count every
            # unvisited row as a member of cluster 0 in the update step.
            labels = np.full(len(data), -1, dtype=int)
            sums = np.zeros_like(self.centroids)
            counts = np.zeros(self.k, dtype=int)
            error = 0.0
            for idx, point in enumerate(data):
                # E-step: nearest centroid for this row only
                cluster = int(np.argmin(np.linalg.norm(point - self.centroids, axis=1)))
                labels[idx] = cluster
                # M-step: running mean of the rows assigned to this cluster so far
                sums[cluster] += point
                counts[cluster] += 1
                new_centroid = sums[cluster] / counts[cluster]
                error = np.linalg.norm(new_centroid - self.centroids[cluster])
                self.centroids[cluster] = new_centroid
            print(error)
            if plot_dir is not None:
                self._save_plot(data, labels, error, os.path.join(plot_dir, "km_{}.png".format(it)))
        return labels