Skip to content

Commit 2c32f93

Browse files
committed
[SPARK-50995][ML][PYTHON][CONNECT] Support clusterCenters for KMeans and BisectingKMeans
### What changes were proposed in this pull request? Support `clusterCenters` for KMeans and BisectingKMeans. To simplify the serde of `Array[Vector]`, the cluster centers are combined into a single `Matrix`. ### Why are the changes needed? For parity. ### Does this PR introduce _any_ user-facing change? Yes, a new API is supported on Connect. ### How was this patch tested? Added tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #49680 from zhengruifeng/ml_connect_km_cluster. Authored-by: Ruifeng Zheng <ruifengz@apache.org> Signed-off-by: Ruifeng Zheng <ruifengz@apache.org> (cherry picked from commit 66c2920) Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
1 parent d18c899 commit 2c32f93

File tree

5 files changed

+25
-5
lines changed

5 files changed

+25
-5
lines changed

mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import org.apache.hadoop.fs.Path
2121

2222
import org.apache.spark.annotation.Since
2323
import org.apache.spark.ml.{Estimator, Model}
24-
import org.apache.spark.ml.linalg.Vector
24+
import org.apache.spark.ml.linalg._
2525
import org.apache.spark.ml.param._
2626
import org.apache.spark.ml.param.shared._
2727
import org.apache.spark.ml.util._
@@ -142,6 +142,9 @@ class BisectingKMeansModel private[ml] (
142142
@Since("2.0.0")
143143
def clusterCenters: Array[Vector] = parentModel.clusterCenters.map(_.asML)
144144

145+
private[ml] def clusterCenterMatrix: Matrix =
146+
Matrices.fromVectors(clusterCenters.toSeq)
147+
145148
/**
146149
* Computes the sum of squared distances between the input points and their corresponding cluster
147150
* centers.

mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,9 @@ class KMeansModel private[ml] (
187187
@Since("2.0.0")
188188
def clusterCenters: Array[Vector] = parentModel.clusterCenters.map(_.asML)
189189

190+
private[ml] def clusterCenterMatrix: Matrix =
191+
Matrices.fromVectors(clusterCenters.toSeq)
192+
190193
/**
191194
* Returns a [[org.apache.spark.ml.util.GeneralMLWriter]] instance for this ML instance.
192195
*

python/pyspark/ml/clustering.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -686,7 +686,8 @@ def setPredictionCol(self, value: str) -> "KMeansModel":
686686
@since("1.5.0")
687687
def clusterCenters(self) -> List[np.ndarray]:
688688
"""Get the cluster centers, represented as a list of NumPy arrays."""
689-
return [c.toArray() for c in self._call_java("clusterCenters")]
689+
matrix = self._call_java("clusterCenterMatrix")
690+
return [vec for vec in matrix.toArray()]
690691

691692
@property
692693
@since("2.1.0")
@@ -1006,7 +1007,8 @@ def setPredictionCol(self, value: str) -> "BisectingKMeansModel":
10061007
@since("2.0.0")
10071008
def clusterCenters(self) -> List[np.ndarray]:
10081009
"""Get the cluster centers, represented as a list of NumPy arrays."""
1009-
return [c.toArray() for c in self._call_java("clusterCenters")]
1010+
matrix = self._call_java("clusterCenterMatrix")
1011+
return [vec for vec in matrix.toArray()]
10101012

10111013
@since("2.0.0")
10121014
def computeCost(self, dataset: DataFrame) -> float:

python/pyspark/ml/tests/test_clustering.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,12 @@ def test_kmeans(self):
6969

7070
model = km.fit(df)
7171
self.assertEqual(km.uid, model.uid)
72+
73+
centers = model.clusterCenters()
74+
self.assertEqual(len(centers), 2)
75+
self.assertTrue(np.allclose(centers[0], [-0.372, -0.338], atol=1e-3), centers[0])
76+
self.assertTrue(np.allclose(centers[1], [0.8625, 0.83375], atol=1e-3), centers[1])
77+
7278
# TODO: support KMeansModel.numFeatures in Python
7379
# self.assertEqual(model.numFeatures, 2)
7480

@@ -138,6 +144,12 @@ def test_bisecting_kmeans(self):
138144

139145
model = bkm.fit(df)
140146
self.assertEqual(bkm.uid, model.uid)
147+
148+
centers = model.clusterCenters()
149+
self.assertEqual(len(centers), 2)
150+
self.assertTrue(np.allclose(centers[0], [-0.372, -0.338], atol=1e-3), centers[0])
151+
self.assertTrue(np.allclose(centers[1], [0.8625, 0.83375], atol=1e-3), centers[1])
152+
141153
# TODO: support BisectingKMeansModel.numFeatures in Python
142154
# self.assertEqual(model.numFeatures, 2)
143155

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -584,11 +584,11 @@ private[ml] object MLUtils {
584584
(classOf[LinearRegressionTrainingSummary], Set("objectiveHistory", "totalIterations")),
585585

586586
// Clustering Models
587-
(classOf[KMeansModel], Set("predict", "numFeatures", "clusterCenters")),
587+
(classOf[KMeansModel], Set("predict", "numFeatures", "clusterCenterMatrix")),
588588
(classOf[KMeansSummary], Set("trainingCost")),
589589
(
590590
classOf[BisectingKMeansModel],
591-
Set("predict", "numFeatures", "clusterCenters", "computeCost")),
591+
Set("predict", "numFeatures", "clusterCenterMatrix", "computeCost")),
592592
(classOf[BisectingKMeansSummary], Set("trainingCost")),
593593
(
594594
classOf[GaussianMixtureModel],

0 commit comments

Comments
 (0)