-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgmmPlotEllipses
More file actions
31 lines (25 loc) · 962 Bytes
/
gmmPlotEllipses
File metadata and controls
31 lines (25 loc) · 962 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import os
import subprocess

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Ellipse
from sklearn.datasets import make_blobs
from sklearn.mixture import GaussianMixture
def covariance_ellipse_params(cov):
    """Return ``(width, height, angle_deg)`` of the ellipse for a 2x2 covariance.

    The covariance matrix is eigen-decomposed; the eigenvector belonging to
    the largest eigenvalue gives the direction of the major axis.

    Args:
        cov: Symmetric 2x2 covariance matrix (array-like).

    Returns:
        A tuple ``(width, height, angle_deg)`` where ``width`` and ``height``
        are the full lengths of the major and minor axes
        (``2 * sqrt(eigenvalue)``, i.e. one standard deviation on either side
        of the centre) and ``angle_deg`` is the angle, in degrees, of the
        major axis measured counter-clockwise from the x axis.
    """
    vals, vecs = np.linalg.eigh(cov)
    # eigh returns eigenvalues in ascending order; put the largest first so
    # that column 0 of ``vecs`` is the major axis.
    order = vals.argsort()[::-1]
    vals, vecs = vals[order], vecs[:, order]
    major_x, major_y = vecs[:, 0]
    theta = np.degrees(np.arctan2(major_y, major_x))
    # Clamp at zero so rounding noise in a near-singular covariance cannot
    # produce a tiny negative eigenvalue and hence a NaN axis length.
    width, height = 2 * np.sqrt(np.maximum(vals, 0.0))
    return float(width), float(height), float(theta)


def main():
    """Fit a 3-component GMM to synthetic blobs and plot one ellipse per component.

    The figure is written to ``/tmp/tvtoGmmPlotEllipses.png`` and then handed
    to ``xdg-open`` so the desktop's default image viewer displays it.
    """
    X, _ = make_blobs(n_samples=300, centers=3, random_state=42)
    gmm = GaussianMixture(n_components=3, covariance_type='full', random_state=42)
    gmm.fit(X)
    labels = gmm.predict(X)

    fig, ax = plt.subplots()
    ax.scatter(X[:, 0], X[:, 1], c=labels, s=10)
    for mean, cov in zip(gmm.means_, gmm.covariances_):
        width, height, theta = covariance_ellipse_params(cov)
        ell = Ellipse(xy=mean, width=width, height=height, angle=theta,
                      edgecolor='black', facecolor='none')
        ax.add_patch(ell)

    path = '/tmp/tvtoGmmPlotEllipses.png'
    fig.savefig(path, dpi=150, bbox_inches='tight')
    plt.close(fig)

    # Pass an argument list instead of a shell string: no shell is spawned and
    # the path is never subject to shell word-splitting or expansion.
    try:
        subprocess.run(['xdg-open', path], check=False)
    except OSError as exc:
        # Opening the viewer is best-effort (xdg-open may not be installed);
        # the image has already been saved, so just report where it is.
        print(f'Saved plot to {path}; could not launch xdg-open: {exc}')


if __name__ == '__main__':
    main()