Skip to content

Commit f97d5d4

Browse files
author
Robert Kruszewski
committed
Revert "[SPARK-25867][ML] Remove KMeans computeCost"
This reverts commit dd8c179.
1 parent f160948 commit f97d5d4

File tree

4 files changed

+39
-8
lines changed

4 files changed

+39
-8
lines changed

mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,22 @@ class KMeansModel private[ml] (
143143
@Since("2.0.0")
144144
def clusterCenters: Array[Vector] = parentModel.clusterCenters.map(_.asML)
145145

146+
/**
147+
* Return the K-means cost (sum of squared distances of points to their nearest center) for this
148+
* model on the given data.
149+
*
150+
* @deprecated This method is deprecated and will be removed in 3.0.0. Use ClusteringEvaluator
151+
* instead. You can also get the cost on the training dataset in the summary.
152+
*/
153+
@deprecated("This method is deprecated and will be removed in 3.0.0. Use ClusteringEvaluator " +
154+
"instead. You can also get the cost on the training dataset in the summary.", "2.4.0")
155+
@Since("2.0.0")
156+
def computeCost(dataset: Dataset[_]): Double = {
157+
SchemaUtils.validateVectorCompatibleColumn(dataset.schema, getFeaturesCol)
158+
val data = DatasetUtils.columnToOldVector(dataset, getFeaturesCol)
159+
parentModel.computeCost(data)
160+
}
161+
146162
/**
147163
* Returns a [[org.apache.spark.ml.util.GeneralMLWriter]] instance for this ML instance.
148164
*

mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTes
117117
assert(clusters === Set(0, 1, 2, 3, 4))
118118
}
119119

120+
assert(model.computeCost(dataset) < 0.1)
120121
assert(model.hasParent)
121122

122123
// Check validity of model summary
@@ -131,6 +132,7 @@ class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTes
131132
}
132133
assert(summary.cluster.columns === Array(predictionColName))
133134
assert(summary.trainingCost < 0.1)
135+
assert(model.computeCost(dataset) == summary.trainingCost)
134136
val clusterSizes = summary.clusterSizes
135137
assert(clusterSizes.length === k)
136138
assert(clusterSizes.sum === numRows)
@@ -199,15 +201,15 @@ class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTes
199201
}
200202

201203
test("KMean with Array input") {
202-
def trainAndGetCost(dataset: Dataset[_]): Double = {
204+
def trainAndComputeCost(dataset: Dataset[_]): Double = {
203205
val model = new KMeans().setK(k).setMaxIter(1).setSeed(1).fit(dataset)
204-
model.summary.trainingCost
206+
model.computeCost(dataset)
205207
}
206208

207209
val (newDataset, newDatasetD, newDatasetF) = MLTestingUtils.generateArrayFeatureDataset(dataset)
208-
val trueCost = trainAndGetCost(newDataset)
209-
val doubleArrayCost = trainAndGetCost(newDatasetD)
210-
val floatArrayCost = trainAndGetCost(newDatasetF)
210+
val trueCost = trainAndComputeCost(newDataset)
211+
val doubleArrayCost = trainAndComputeCost(newDatasetD)
212+
val floatArrayCost = trainAndComputeCost(newDatasetF)
211213

212214
// checking the cost is fine enough as a sanity check
213215
assert(trueCost ~== doubleArrayCost absTol 1e-6)

project/MimaExcludes.scala

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,6 @@ object MimaExcludes {
3636

3737
// Exclude rules for 3.0.x
3838
lazy val v30excludes = v24excludes ++ Seq(
39-
// [SPARK-25867] Remove KMeans computeCost
40-
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.KMeansModel.computeCost"),
41-
4239
// [SPARK-26127] Remove deprecated setters from tree regression and classification models
4340
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setSeed"),
4441
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMinInfoGain"),

python/pyspark/ml/clustering.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,20 @@ def clusterCenters(self):
335335
"""Get the cluster centers, represented as a list of NumPy arrays."""
336336
return [c.toArray() for c in self._call_java("clusterCenters")]
337337

338+
@since("2.0.0")
339+
def computeCost(self, dataset):
340+
"""
341+
Return the K-means cost (sum of squared distances of points to their nearest center)
342+
for this model on the given data.
343+
344+
.. note:: Deprecated in 2.4.0. It will be removed in 3.0.0. Use ClusteringEvaluator instead.
345+
You can also get the cost on the training dataset in the summary.
346+
"""
347+
warnings.warn("Deprecated in 2.4.0. It will be removed in 3.0.0. Use ClusteringEvaluator "
348+
"instead. You can also get the cost on the training dataset in the summary.",
349+
DeprecationWarning)
350+
return self._call_java("computeCost", dataset)
351+
338352
@property
339353
@since("2.1.0")
340354
def hasSummary(self):
@@ -373,6 +387,8 @@ class KMeans(JavaEstimator, HasDistanceMeasure, HasFeaturesCol, HasPredictionCol
373387
>>> centers = model.clusterCenters()
374388
>>> len(centers)
375389
2
390+
>>> model.computeCost(df)
391+
2.0
376392
>>> transformed = model.transform(df).select("features", "prediction")
377393
>>> rows = transformed.collect()
378394
>>> rows[0].prediction == rows[1].prediction

0 commit comments

Comments
 (0)