Skip to content

Commit 560dd5e

Browse files
committed
[SPARK-50928][ML][PYTHON][CONNECT] Support GaussianMixture on Connect
### What changes were proposed in this pull request?
Support GaussianMixture on Connect

### Why are the changes needed?
For feature parity

### Does this PR introduce _any_ user-facing change?
Yes

### How was this patch tested?
Added test

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #49633 from zhengruifeng/ml_connect_gmm.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent 311a4e0 commit 560dd5e

File tree

6 files changed

+123
-17
lines changed

6 files changed

+123
-17
lines changed

mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ org.apache.spark.ml.regression.GBTRegressor
3636
# clustering
3737
org.apache.spark.ml.clustering.KMeans
3838
org.apache.spark.ml.clustering.BisectingKMeans
39+
org.apache.spark.ml.clustering.GaussianMixture
3940

4041

4142
# recommendation

mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ org.apache.spark.ml.regression.GBTRegressionModel
4242
# clustering
4343
org.apache.spark.ml.clustering.KMeansModel
4444
org.apache.spark.ml.clustering.BisectingKMeansModel
45+
org.apache.spark.ml.clustering.GaussianMixtureModel
4546

4647
# recommendation
4748
org.apache.spark.ml.recommendation.ALSModel

mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,9 @@ class GaussianMixtureModel private[ml] (
9393
extends Model[GaussianMixtureModel] with GaussianMixtureParams with MLWritable
9494
with HasTrainingSummary[GaussianMixtureSummary] {
9595

96+
// Zero-argument auxiliary constructor: delegates to the primary constructor with a
// uid from Identifiable.randomUID("gmm"), an empty Double array and an empty array.
// NOTE(review): presumably added so the model can be instantiated without arguments
// (it is registered in META-INF/services/org.apache.spark.ml.Transformer by this
// change) — confirm against the Connect loader.
// NOTE(review): `numFeatures` below evaluates `gaussians.head`; if the empty array
// passed here is `gaussians`, that would throw on an instance built this way —
// confirm such placeholder instances never read it.
private[ml] def this() = this(Identifiable.randomUID("gmm"),
97+
Array.emptyDoubleArray, Array.empty)
98+
9699
@Since("3.0.0")
97100
lazy val numFeatures: Int = gaussians.head.mean.size
98101

python/pyspark/ml/clustering.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,7 @@ def gaussians(self) -> List[MultivariateGaussian]:
241241

242242
@property
243243
@since("2.0.0")
244+
@try_remote_attribute_relation
244245
def gaussiansDF(self) -> DataFrame:
245246
"""
246247
Retrieve Gaussian distributions as a DataFrame.
@@ -542,6 +543,7 @@ def probabilityCol(self) -> str:
542543

543544
@property
544545
@since("2.1.0")
546+
@try_remote_attribute_relation
545547
def probability(self) -> DataFrame:
546548
"""
547549
DataFrame of probabilities of each cluster for each training data point.

python/pyspark/ml/tests/test_clustering.py

Lines changed: 112 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,15 @@
2929
BisectingKMeans,
3030
BisectingKMeansModel,
3131
BisectingKMeansSummary,
32+
GaussianMixture,
33+
GaussianMixtureModel,
34+
GaussianMixtureSummary,
3235
)
3336

3437

3538
class ClusteringTestsMixin:
36-
@property
37-
def df(self):
38-
return (
39+
def test_kmeans(self):
40+
df = (
3941
self.spark.createDataFrame(
4042
[
4143
(1, 1.0, Vectors.dense([-0.1, -0.05])),
@@ -49,11 +51,9 @@ def df(self):
4951
)
5052
.coalesce(1)
5153
.sortWithinPartitions("index")
54+
.select("weight", "features")
5255
)
5356

54-
def test_kmeans(self):
55-
df = self.df.select("weight", "features")
56-
5757
km = KMeans(
5858
k=2,
5959
maxIter=2,
@@ -68,11 +68,7 @@ def test_kmeans(self):
6868
# self.assertEqual(model.numFeatures, 2)
6969

7070
output = model.transform(df)
71-
expected_cols = [
72-
"weight",
73-
"features",
74-
"prediction",
75-
]
71+
expected_cols = ["weight", "features", "prediction"]
7672
self.assertEqual(output.columns, expected_cols)
7773
self.assertEqual(output.count(), 6)
7874

@@ -107,7 +103,22 @@ def test_kmeans(self):
107103
self.assertEqual(str(model), str(model2))
108104

109105
def test_bisecting_kmeans(self):
110-
df = self.df.select("weight", "features")
106+
df = (
107+
self.spark.createDataFrame(
108+
[
109+
(1, 1.0, Vectors.dense([-0.1, -0.05])),
110+
(2, 2.0, Vectors.dense([-0.01, -0.1])),
111+
(3, 3.0, Vectors.dense([0.9, 0.8])),
112+
(4, 1.0, Vectors.dense([0.75, 0.935])),
113+
(5, 1.0, Vectors.dense([-0.83, -0.68])),
114+
(6, 1.0, Vectors.dense([-0.91, -0.76])),
115+
],
116+
["index", "weight", "features"],
117+
)
118+
.coalesce(1)
119+
.sortWithinPartitions("index")
120+
.select("weight", "features")
121+
)
111122

112123
bkm = BisectingKMeans(
113124
k=2,
@@ -125,11 +136,7 @@ def test_bisecting_kmeans(self):
125136
# self.assertEqual(model.numFeatures, 2)
126137

127138
output = model.transform(df)
128-
expected_cols = [
129-
"weight",
130-
"features",
131-
"prediction",
132-
]
139+
expected_cols = ["weight", "features", "prediction"]
133140
self.assertEqual(output.columns, expected_cols)
134141
self.assertEqual(output.count(), 6)
135142

@@ -166,6 +173,94 @@ def test_bisecting_kmeans(self):
166173
model2 = BisectingKMeansModel.load(d)
167174
self.assertEqual(str(model), str(model2))
168175

176+
def test_gaussian_mixture(self):
    """Fit a weighted two-component GaussianMixture and check the estimator
    params, the fitted model, its training summary and save/load round-trips.
    """
    # Single, index-ordered partition so the fit sees the rows in a fixed order.
    df = (
        self.spark.createDataFrame(
            [
                (1, 1.0, Vectors.dense([-0.1, -0.05])),
                (2, 2.0, Vectors.dense([-0.01, -0.1])),
                (3, 3.0, Vectors.dense([0.9, 0.8])),
                (4, 1.0, Vectors.dense([0.75, 0.935])),
                (5, 1.0, Vectors.dense([-0.83, -0.68])),
                (6, 1.0, Vectors.dense([-0.91, -0.76])),
            ],
            ["index", "weight", "features"],
        )
        .coalesce(1)
        .sortWithinPartitions("index")
        .select("weight", "features")
    )

    gmm = GaussianMixture(
        k=2,
        maxIter=2,
        weightCol="weight",
        seed=1,
    )
    self.assertEqual(gmm.getK(), 2)
    self.assertEqual(gmm.getMaxIter(), 2)
    self.assertEqual(gmm.getWeightCol(), "weight")
    self.assertEqual(gmm.getSeed(), 1)

    model = gmm.fit(df)
    # TODO: support GMM.numFeatures in Python
    # self.assertEqual(model.numFeatures, 2)
    self.assertEqual(len(model.weights), 2)
    self.assertTrue(
        np.allclose(model.weights, [0.541014115744985, 0.4589858842550149], atol=1e-4),
        model.weights,
    )
    # TODO: support GMM.gaussians on connect
    # self.assertEqual(model.gaussians, xxx)
    self.assertEqual(model.gaussiansDF.columns, ["mean", "cov"])
    self.assertEqual(model.gaussiansDF.count(), 2)

    vec = Vectors.dense(0.0, 5.0)
    pred = model.predict(vec)
    self.assertTrue(np.allclose(pred, 0, atol=1e-4), pred)
    pred = model.predictProbability(vec)
    self.assertTrue(np.allclose(pred.toArray(), [0.5, 0.5], atol=1e-4), pred)

    output = model.transform(df)
    expected_cols = ["weight", "features", "probability", "prediction"]
    self.assertEqual(output.columns, expected_cols)
    self.assertEqual(output.count(), 6)

    # Model summary
    self.assertTrue(model.hasSummary)
    summary = model.summary
    self.assertIsInstance(summary, GaussianMixtureSummary)
    self.assertEqual(summary.k, 2)
    self.assertEqual(summary.numIter, 2)
    self.assertEqual(len(summary.clusterSizes), 2)
    self.assertEqual(summary.clusterSizes, [3, 3])
    ll = summary.logLikelihood
    self.assertLess(ll, 0, ll)
    self.assertTrue(np.allclose(ll, -1.311264553744033, atol=1e-4), ll)

    self.assertEqual(summary.featuresCol, "features")
    self.assertEqual(summary.predictionCol, "prediction")
    self.assertEqual(summary.probabilityCol, "probability")

    self.assertEqual(summary.cluster.columns, ["prediction"])
    self.assertEqual(summary.cluster.count(), 6)

    self.assertEqual(summary.predictions.columns, expected_cols)
    self.assertEqual(summary.predictions.count(), 6)

    # Check the probability DataFrame itself; the previous version re-counted
    # summary.predictions here, leaving summary.probability's rows unverified.
    self.assertEqual(summary.probability.columns, ["probability"])
    self.assertEqual(summary.probability.count(), 6)

    # save & load
    # Both round-trips stay inside the `with` block: the directory is removed
    # on exit, and the model deliberately overwrites the estimator's files.
    with tempfile.TemporaryDirectory(prefix="gaussian_mixture") as d:
        gmm.write().overwrite().save(d)
        gmm2 = GaussianMixture.load(d)
        self.assertEqual(str(gmm), str(gmm2))

        model.write().overwrite().save(d)
        model2 = GaussianMixtureModel.load(d)
        self.assertEqual(str(model), str(model2))
263+
169264

170265
class ClusteringTests(ClusteringTestsMixin, unittest.TestCase):
171266
def setUp(self) -> None:

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,10 @@ private[ml] object MLUtils {
564564
classOf[BisectingKMeansModel],
565565
Set("predict", "numFeatures", "clusterCenters", "computeCost")),
566566
(classOf[BisectingKMeansSummary], Set("trainingCost")),
567+
(
568+
classOf[GaussianMixtureModel],
569+
Set("predict", "numFeatures", "weights", "gaussians", "predictProbability", "gaussiansDF")),
570+
(classOf[GaussianMixtureSummary], Set("probability", "probabilityCol", "logLikelihood")),
567571

568572
// Recommendation Models
569573
(

0 commit comments

Comments
 (0)