Skip to content

Commit e0ef4bc

Browse files
committed
Added GaussianMixtureHelper + 2 Bug Fixes
1 parent 77b063e commit e0ef4bc

File tree

20 files changed

+846
-30
lines changed

20 files changed

+846
-30
lines changed

README.md

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ This is a helper package for a variety of functions as described in the Overview
44

55
# Installation
66

7-
pip install more==0.0.1b6
7+
pip install more==0.0.1b7
88

99
# Overview
1010

@@ -15,3 +15,12 @@ This is a helper package for a variety of functions
1515

1616
# Examples
1717
Check out the [examples](https://github.com/ngupta23/more/tree/master/examples) folder for details on usage
18+
19+
# Version History
20+
21+
## 0.0.1b7
22+
23+
* Added Cluster Helper for Gaussian Clusters
24+
* Fixed bug in plot_parallel_coordinates, which did not work correctly for multi-level categorical labels
25+
* Fixed bug in the pandas helper when describing categorical and numeric fields — it now warns if the dataframe has no categorical or numeric columns when the respective describe function is called.
26+

build/lib/more/pandas_helper/__init__.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import pandas as pd
2-
# import numpy as np
32
import warnings
43

54
@pd.api.extensions.register_dataframe_accessor("helper")
@@ -42,23 +41,30 @@ def describe_categorical(self, verbose=False):
4241
the dataframe contains both numeric and categorical variables.
4342
This extension provides more flexibility
4443
"""
45-
self.__print_dashes(45)
46-
print("Summary Statictics for Categorical Variables")
47-
self.__print_dashes(45)
48-
print(self._obj[self.cat_features].describe())
49-
50-
if (verbose):
51-
self.level_counts()
44+
if (self._cat_exists()):
45+
self.__print_dashes(45)
46+
print("Summary Statictics for Categorical Variables")
47+
self.__print_dashes(45)
48+
print(self._obj[self.cat_features].describe())
49+
50+
if (verbose):
51+
self.level_counts()
52+
else:
53+
warnings.warn("Data does not have any categorical columns")
54+
5255

5356
def describe_numeric(self):
    """
    Print summary statistics for the numeric columns of the wrapped frame.

    Same output as pd.DataFrame.describe() restricted to numeric features,
    but emits a warning instead of failing when the frame has no numeric
    columns at all.
    """
    if (self._num_exists()):
        self.__print_dashes(40)
        # Typo fixed in the printed header: "Statictics" -> "Statistics".
        print("Summary Statistics for Numeric Variables")
        self.__print_dashes(40)
        print(self._obj[self.num_features].describe())
    else:
        warnings.warn("Data does not have any numeric columns")
6268

6369
def describe(self,verbose=False):
6470
"""
@@ -116,6 +122,18 @@ def __set_num_features(self):
116122
def __set_all_feature_types(self):
117123
self.__set_cat_features()
118124
self.__set_num_features()
125+
126+
def _cat_exists(self):
    """Return True when the frame has at least one categorical column."""
    return len(self.cat_features) > 0
131+
132+
def _num_exists(self):
    """Return True when the frame has at least one numeric column."""
    return len(self.num_features) > 0
119137

120138
def __print_dashes(self, num=20):
    """Print a separator line made of `num` dash characters."""
    print(num * "-")
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
import numpy as np
2+
import pandas as pd
3+
import matplotlib.pyplot as plt
4+
5+
# this code has been manipulated from the source available on sklearn's website documentation
6+
import itertools
7+
from sklearn import metrics as mt
8+
from scipy import linalg
9+
import matplotlib as mpl
10+
from sklearn import mixture
11+
12+
class GaussianMixtureHelper:
    """
    Train and evaluate Gaussian Mixture clustering models.

    Fits one sklearn GaussianMixture per (covariance type, number of
    components) combination, records the BIC and AIC of each fit, and keeps
    the best model under each criterion. The model exposed through
    get_best_model()/predict() is selected by `metric` ('bic' or 'aic').

    Plotting code adapted from the example in sklearn's documentation.
    """

    def __init__(self, X, y=None
                 , n_components_range=range(2, 3), cov_types=None
                 , metric='bic', random_state=101):
        """
        Parameters
        ----------
        X : feature data to cluster (indexed by column name in
            plot_best_model, so presumably a pandas DataFrame — confirm).
        y : optional true labels; only used by clusters_vs_true_labels().
        n_components_range : iterable of component counts to try.
        cov_types : list of covariance types accepted by
            sklearn.mixture.GaussianMixture; defaults to ['spherical'].
        metric : 'bic' or 'aic' — criterion used to pick the best model.
        random_state : seed passed to every GaussianMixture fit.

        Raises
        ------
        ValueError if `metric` is not 'bic' or 'aic'.
        """
        self.X = X
        self.y = y
        self.n_components_range = n_components_range
        # None sentinel avoids the shared mutable-default-argument pitfall;
        # the effective default is unchanged.
        self.cov_types = ['spherical'] if cov_types is None else cov_types
        # Validation was a TODO in the original implementation.
        self._check_metric(metric)
        self.metric = metric
        self.random_state = random_state
        self.y_pred = None
        self.best_gmm = None
        self.best_gmm_bic = None
        self.best_gmm_aic = None
        self.bic = []
        self.aic = []
        # np.inf: the np.infty alias was removed in NumPy 2.0.
        self.lowest_bic = np.inf
        self.lowest_aic = np.inf

    @staticmethod
    def _check_metric(metric):
        """Raise ValueError for any metric other than 'bic'/'aic'."""
        if metric not in ('bic', 'aic'):
            raise ValueError("metric must be 'bic' or 'aic', got %r" % (metric,))

    def train(self):
        """
        Train a Gaussian Mixture Model across the configured range of
        cluster counts and covariance types.

        Tracks the lowest-BIC and lowest-AIC models as training proceeds,
        then selects the best model per `self.metric` and caches the
        predictions for the training data in `self.y_pred`.
        Returns self so calls can be chained.
        """
        for cov_type in self.cov_types:
            for n_components in self.n_components_range:
                # Fit a mixture of Gaussians with EM
                gmm = mixture.GaussianMixture(n_components=n_components
                                              , covariance_type=cov_type
                                              , random_state=self.random_state)
                gmm.fit(self.X)
                self.bic.append(gmm.bic(self.X))
                self.aic.append(gmm.aic(self.X))

                if self.bic[-1] < self.lowest_bic:
                    self.lowest_bic = self.bic[-1]
                    self.best_gmm_bic = gmm

                if self.aic[-1] < self.lowest_aic:
                    self.lowest_aic = self.aic[-1]
                    self.best_gmm_aic = gmm

        self.set_best_model()
        self.y_pred = self.predict(self.X)
        return(self)

    def set_metric(self, metric):
        """Change the selection metric ('bic' or 'aic'); raises ValueError otherwise."""
        self._check_metric(metric)
        self.metric = metric

    def set_best_model(self):
        """
        Use to set the best model to the one based on a specific metric
        Default Metric = 'bic'; Other Option(s): 'aic'
        """
        if (self.metric == 'bic'):
            self.best_gmm = self.best_gmm_bic
        elif (self.metric == 'aic'):
            self.best_gmm = self.best_gmm_aic
        else:
            # Defensive: metric is validated on entry, so this is unreachable
            # unless the attribute was set directly.
            raise ValueError("metric must be 'bic' or 'aic', got %r" % (self.metric,))

    def get_best_model(self):
        """Return the currently selected best GaussianMixture (or None before train())."""
        return(self.best_gmm)

    def plot_metrics(self, figsize=(12, 4)):
        """
        Plot BIC (left) and AIC (right) bar charts, one bar group per
        covariance type, with '*' marking the best (lowest) score.
        Returns the matplotlib.pyplot module so the caller can show/save.

        Adapted from the sklearn documentation example.
        """
        plt.figure(figsize=figsize)
        # Work on local ndarray copies: the original reassigned
        # self.bic/self.aic to ndarrays, which broke subsequent .append()
        # calls if train() was run again.
        self._plot_scores(np.array(self.bic), 'BIC score per model', 1)
        self._plot_scores(np.array(self.aic), 'AIC score per model', 2)
        plt.tight_layout()
        return(plt)

    def _plot_scores(self, scores, title, position):
        """Bar chart of `scores` per (cov_type, n_components); '*' marks the minimum."""
        spl = plt.subplot(1, 2, position)
        color_iter = itertools.cycle(['k', 'r', 'b', 'g', 'c', 'm', 'y'])
        n_range = len(self.n_components_range)
        bars = []
        # NOTE: the original bound the loop variable to self.cov_type,
        # silently clobbering instance state; a local name is used instead.
        for i, (cov_type, color) in enumerate(zip(self.cov_types, color_iter)):
            xpos = np.array(self.n_components_range) + .2 * (i - 2)
            bars.append(plt.bar(xpos,
                                scores[i * n_range:(i + 1) * n_range],
                                width=.2, color=color))
        plt.xticks(self.n_components_range)
        plt.ylim([scores.min() * 1.01 - .01 * scores.max(), scores.max()])
        plt.title(title)
        # Position the '*' above the bar holding the best score.
        xpos = (np.min(self.n_components_range) - 0.4
                + np.mod(scores.argmin(), n_range)
                + .2 * np.floor(scores.argmin() / n_range))
        plt.text(xpos, scores.min() * 0.97 + .03 * scores.max(), '*', fontsize=14)
        spl.set_xlabel('Number of components')
        spl.legend([b[0] for b in bars], self.cov_types)

    def predict(self, X):
        """Predict cluster labels for X with the currently selected best model."""
        clf = self.get_best_model()
        if clf is None:
            raise ValueError("No trained model available - call train() first.")
        y_pred = clf.predict(X)
        return(y_pred)

    def plot_best_model(self, feat_x, feat_y):
        """
        Scatter the data for two named features colored by predicted
        cluster, overlaying one ellipse per Gaussian component.

        Assumes self.X supports boolean-mask indexing followed by column
        selection (i.e. a pandas DataFrame) — confirm against callers.
        """
        plt.figure(figsize=(12, 6))
        splot = plt.subplot(1, 1, 1)

        color_iter = itertools.cycle(['k', 'r', 'b', 'g', 'c', 'm', 'y'])
        clf = self.get_best_model()

        for i, (mean, covar, color) in enumerate(zip(clf.means_,
                                                     clf.covariances_,
                                                     color_iter)):
            # Normalize covar to a full 2x2 matrix regardless of cov_type
            # ('spherical' yields a scalar, 'diag'/'tied' a vector).
            if len(covar.shape) < 2:
                tmp = np.zeros((2, 2))
                np.fill_diagonal(tmp, covar)
                covar = tmp
            elif covar.shape[0] != covar.shape[1]:
                covar = np.diag(covar)

            v, w = linalg.eigh(covar)
            if not np.any(self.y_pred == i):
                # Skip components that own no points.
                continue

            plt.scatter(self.X[self.y_pred == i][feat_x],
                        self.X[self.y_pred == i][feat_y], 5, color=color)

            # Plot an ellipse to show the Gaussian component
            angle = np.arctan2(w[0][1], w[0][0])
            angle = 180 * angle / np.pi  # convert to degrees
            v *= 4
            # `angle` passed by keyword: positional use is deprecated/removed
            # in recent matplotlib releases.
            ell = mpl.patches.Ellipse(mean, v[0], v[1], angle=180 + angle,
                                      color=color)
            ell.set_clip_box(splot.bbox)
            ell.set_alpha(.5)
            splot.add_artist(ell)

        plt.title('Selected GMM')
        plt.show()

    def clusters_vs_true_labels(self):
        """Print the confusion matrix of true labels vs predicted cluster labels."""
        if self.y is None:
            raise ValueError("True labels (y) were not provided at construction.")
        self.y_pred = self.predict(self.X)
        num_true_classes = len(set(self.y))
        print(mt.confusion_matrix(self.y, self.y_pred)[0:num_true_classes, :])
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .GaussianMixture import GaussianMixtureHelper

build/lib/more/viz_helper/plot_parallel_coordinates.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ def plot_parallel_coordinates(data, by, sample=True, frac=1.0, normalize=True
2121
else:
2222
df_sub = data.copy(deep=False)
2323

24-
df_sub[by] = df_sub[by] == 1 # converting categorical variable into number for plotting
24+
# Commenting out since the by variable could have more than 2 levels
25+
#df_sub[by] = df_sub[by] == 1 # converting categorical variable into number for plotting
2526

2627

2728
# This plot is more meaningful when values are normalized

dist/more-0.0.1b7-py3-none-any.whl

22.2 KB
Binary file not shown.

dist/more-0.0.1b7.tar.gz

14.1 KB
Binary file not shown.

0 commit comments

Comments
 (0)