29
29
BisectingKMeans ,
30
30
BisectingKMeansModel ,
31
31
BisectingKMeansSummary ,
32
+ GaussianMixture ,
33
+ GaussianMixtureModel ,
34
+ GaussianMixtureSummary ,
32
35
)
33
36
34
37
35
38
class ClusteringTestsMixin :
36
- @property
37
- def df (self ):
38
- return (
39
+ def test_kmeans (self ):
40
+ df = (
39
41
self .spark .createDataFrame (
40
42
[
41
43
(1 , 1.0 , Vectors .dense ([- 0.1 , - 0.05 ])),
@@ -49,11 +51,9 @@ def df(self):
49
51
)
50
52
.coalesce (1 )
51
53
.sortWithinPartitions ("index" )
54
+ .select ("weight" , "features" )
52
55
)
53
56
54
- def test_kmeans (self ):
55
- df = self .df .select ("weight" , "features" )
56
-
57
57
km = KMeans (
58
58
k = 2 ,
59
59
maxIter = 2 ,
@@ -68,11 +68,7 @@ def test_kmeans(self):
68
68
# self.assertEqual(model.numFeatures, 2)
69
69
70
70
output = model .transform (df )
71
- expected_cols = [
72
- "weight" ,
73
- "features" ,
74
- "prediction" ,
75
- ]
71
+ expected_cols = ["weight" , "features" , "prediction" ]
76
72
self .assertEqual (output .columns , expected_cols )
77
73
self .assertEqual (output .count (), 6 )
78
74
@@ -107,7 +103,22 @@ def test_kmeans(self):
107
103
self .assertEqual (str (model ), str (model2 ))
108
104
109
105
def test_bisecting_kmeans (self ):
110
- df = self .df .select ("weight" , "features" )
106
+ df = (
107
+ self .spark .createDataFrame (
108
+ [
109
+ (1 , 1.0 , Vectors .dense ([- 0.1 , - 0.05 ])),
110
+ (2 , 2.0 , Vectors .dense ([- 0.01 , - 0.1 ])),
111
+ (3 , 3.0 , Vectors .dense ([0.9 , 0.8 ])),
112
+ (4 , 1.0 , Vectors .dense ([0.75 , 0.935 ])),
113
+ (5 , 1.0 , Vectors .dense ([- 0.83 , - 0.68 ])),
114
+ (6 , 1.0 , Vectors .dense ([- 0.91 , - 0.76 ])),
115
+ ],
116
+ ["index" , "weight" , "features" ],
117
+ )
118
+ .coalesce (1 )
119
+ .sortWithinPartitions ("index" )
120
+ .select ("weight" , "features" )
121
+ )
111
122
112
123
bkm = BisectingKMeans (
113
124
k = 2 ,
@@ -125,11 +136,7 @@ def test_bisecting_kmeans(self):
125
136
# self.assertEqual(model.numFeatures, 2)
126
137
127
138
output = model .transform (df )
128
- expected_cols = [
129
- "weight" ,
130
- "features" ,
131
- "prediction" ,
132
- ]
139
+ expected_cols = ["weight" , "features" , "prediction" ]
133
140
self .assertEqual (output .columns , expected_cols )
134
141
self .assertEqual (output .count (), 6 )
135
142
@@ -166,6 +173,94 @@ def test_bisecting_kmeans(self):
166
173
model2 = BisectingKMeansModel .load (d )
167
174
self .assertEqual (str (model ), str (model2 ))
168
175
176
+ def test_gaussian_mixture (self ):
177
+ df = (
178
+ self .spark .createDataFrame (
179
+ [
180
+ (1 , 1.0 , Vectors .dense ([- 0.1 , - 0.05 ])),
181
+ (2 , 2.0 , Vectors .dense ([- 0.01 , - 0.1 ])),
182
+ (3 , 3.0 , Vectors .dense ([0.9 , 0.8 ])),
183
+ (4 , 1.0 , Vectors .dense ([0.75 , 0.935 ])),
184
+ (5 , 1.0 , Vectors .dense ([- 0.83 , - 0.68 ])),
185
+ (6 , 1.0 , Vectors .dense ([- 0.91 , - 0.76 ])),
186
+ ],
187
+ ["index" , "weight" , "features" ],
188
+ )
189
+ .coalesce (1 )
190
+ .sortWithinPartitions ("index" )
191
+ .select ("weight" , "features" )
192
+ )
193
+
194
+ gmm = GaussianMixture (
195
+ k = 2 ,
196
+ maxIter = 2 ,
197
+ weightCol = "weight" ,
198
+ seed = 1 ,
199
+ )
200
+ self .assertEqual (gmm .getK (), 2 )
201
+ self .assertEqual (gmm .getMaxIter (), 2 )
202
+ self .assertEqual (gmm .getWeightCol (), "weight" )
203
+ self .assertEqual (gmm .getSeed (), 1 )
204
+
205
+ model = gmm .fit (df )
206
+ # TODO: support GMM.numFeatures in Python
207
+ # self.assertEqual(model.numFeatures, 2)
208
+ self .assertEqual (len (model .weights ), 2 )
209
+ self .assertTrue (
210
+ np .allclose (model .weights , [0.541014115744985 , 0.4589858842550149 ], atol = 1e-4 ),
211
+ model .weights ,
212
+ )
213
+ # TODO: support GMM.gaussians on connect
214
+ # self.assertEqual(model.gaussians, xxx)
215
+ self .assertEqual (model .gaussiansDF .columns , ["mean" , "cov" ])
216
+ self .assertEqual (model .gaussiansDF .count (), 2 )
217
+
218
+ vec = Vectors .dense (0.0 , 5.0 )
219
+ pred = model .predict (vec )
220
+ self .assertTrue (np .allclose (pred , 0 , atol = 1e-4 ), pred )
221
+ pred = model .predictProbability (vec )
222
+ self .assertTrue (np .allclose (pred .toArray (), [0.5 , 0.5 ], atol = 1e-4 ), pred )
223
+
224
+ output = model .transform (df )
225
+ expected_cols = ["weight" , "features" , "probability" , "prediction" ]
226
+ self .assertEqual (output .columns , expected_cols )
227
+ self .assertEqual (output .count (), 6 )
228
+
229
+ # Model summary
230
+ self .assertTrue (model .hasSummary )
231
+ summary = model .summary
232
+ self .assertTrue (isinstance (summary , GaussianMixtureSummary ))
233
+ self .assertEqual (summary .k , 2 )
234
+ self .assertEqual (summary .numIter , 2 )
235
+ self .assertEqual (len (summary .clusterSizes ), 2 )
236
+ self .assertEqual (summary .clusterSizes , [3 , 3 ])
237
+ ll = summary .logLikelihood
238
+ self .assertTrue (ll < 0 , ll )
239
+ self .assertTrue (np .allclose (ll , - 1.311264553744033 , atol = 1e-4 ), ll )
240
+
241
+ self .assertEqual (summary .featuresCol , "features" )
242
+ self .assertEqual (summary .predictionCol , "prediction" )
243
+ self .assertEqual (summary .probabilityCol , "probability" )
244
+
245
+ self .assertEqual (summary .cluster .columns , ["prediction" ])
246
+ self .assertEqual (summary .cluster .count (), 6 )
247
+
248
+ self .assertEqual (summary .predictions .columns , expected_cols )
249
+ self .assertEqual (summary .predictions .count (), 6 )
250
+
251
+ self .assertEqual (summary .probability .columns , ["probability" ])
252
+ self .assertEqual (summary .predictions .count (), 6 )
253
+
254
+ # save & load
255
+ with tempfile .TemporaryDirectory (prefix = "gaussian_mixture" ) as d :
256
+ gmm .write ().overwrite ().save (d )
257
+ gmm2 = GaussianMixture .load (d )
258
+ self .assertEqual (str (gmm ), str (gmm2 ))
259
+
260
+ model .write ().overwrite ().save (d )
261
+ model2 = GaussianMixtureModel .load (d )
262
+ self .assertEqual (str (model ), str (model2 ))
263
+
169
264
170
265
class ClusteringTests (ClusteringTestsMixin , unittest .TestCase ):
171
266
def setUp (self ) -> None :
0 commit comments