Skip to content

Commit bbd4b96

Browse files
committed
[SPARK-51329][ML][PYTHON] Add numFeatures for clustering models
### What changes were proposed in this pull request?
Add `numFeatures` for clustering models.

### Why are the changes needed?
For feature parity between Python and Scala; these methods were missing on the Python side.

### Does this PR introduce _any_ user-facing change?
Yes, new methods.

### How was this patch tested?
Updated tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #50095 from zhengruifeng/ml_km_nf.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
1 parent 2fca41c commit bbd4b96

File tree

2 files changed

+27
-8
lines changed

2 files changed

+27
-8
lines changed

python/pyspark/ml/clustering.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,14 @@ def setProbabilityCol(self, value: str) -> "GaussianMixtureModel":
214214
"""
215215
return self._set(probabilityCol=value)
216216

217+
@property
218+
@since("4.1.0")
219+
def numFeatures(self) -> int:
220+
"""
221+
Number of features, i.e., length of Vectors which this transforms.
222+
"""
223+
return self._call_java("numFeatures")
224+
217225
@property
218226
@since("2.0.0")
219227
def weights(self) -> List[float]:
@@ -686,6 +694,14 @@ def clusterCenters(self) -> List[np.ndarray]:
686694
matrix = self._call_java("clusterCenterMatrix")
687695
return [vec for vec in matrix.toArray()]
688696

697+
@property
698+
@since("4.1.0")
699+
def numFeatures(self) -> int:
700+
"""
701+
Number of features, i.e., length of Vectors which this transforms.
702+
"""
703+
return self._call_java("numFeatures")
704+
689705
@property
690706
@since("2.1.0")
691707
def summary(self) -> KMeansSummary:
@@ -1025,6 +1041,14 @@ def computeCost(self, dataset: DataFrame) -> float:
10251041
)
10261042
return self._call_java("computeCost", dataset)
10271043

1044+
@property
1045+
@since("4.1.0")
1046+
def numFeatures(self) -> int:
1047+
"""
1048+
Number of features, i.e., length of Vectors which this transforms.
1049+
"""
1050+
return self._call_java("numFeatures")
1051+
10281052
@property
10291053
@since("2.1.0")
10301054
def summary(self) -> "BisectingKMeansSummary":

python/pyspark/ml/tests/test_clustering.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,10 @@ def test_kmeans(self):
7373

7474
centers = model.clusterCenters()
7575
self.assertEqual(len(centers), 2)
76+
self.assertEqual(model.numFeatures, 2)
7677
self.assertTrue(np.allclose(centers[0], [-0.372, -0.338], atol=1e-3), centers[0])
7778
self.assertTrue(np.allclose(centers[1], [0.8625, 0.83375], atol=1e-3), centers[1])
7879

79-
# TODO: support KMeansModel.numFeatures in Python
80-
# self.assertEqual(model.numFeatures, 2)
81-
8280
output = model.transform(df)
8381
expected_cols = ["weight", "features", "prediction"]
8482
self.assertEqual(output.columns, expected_cols)
@@ -148,12 +146,10 @@ def test_bisecting_kmeans(self):
148146

149147
centers = model.clusterCenters()
150148
self.assertEqual(len(centers), 2)
149+
self.assertEqual(model.numFeatures, 2)
151150
self.assertTrue(np.allclose(centers[0], [-0.372, -0.338], atol=1e-3), centers[0])
152151
self.assertTrue(np.allclose(centers[1], [0.8625, 0.83375], atol=1e-3), centers[1])
153152

154-
# TODO: support KMeansModel.numFeatures in Python
155-
# self.assertEqual(model.numFeatures, 2)
156-
157153
output = model.transform(df)
158154
expected_cols = ["weight", "features", "prediction"]
159155
self.assertEqual(output.columns, expected_cols)
@@ -224,8 +220,7 @@ def test_gaussian_mixture(self):
224220

225221
model = gmm.fit(df)
226222
self.assertEqual(gmm.uid, model.uid)
227-
# TODO: support GMM.numFeatures in Python
228-
# self.assertEqual(model.numFeatures, 2)
223+
self.assertEqual(model.numFeatures, 2)
229224
self.assertEqual(len(model.weights), 2)
230225
self.assertTrue(
231226
np.allclose(model.weights, [0.541014115744985, 0.4589858842550149], atol=1e-4),

0 commit comments

Comments
 (0)