|
20 | 20 | from pyspark import since, keyword_only
|
21 | 21 | from pyspark.ml.wrapper import JavaParams
|
22 | 22 | from pyspark.ml.param import Param, Params, TypeConverters
|
23 |
| -from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol |
| 23 | +from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol, \ |
| 24 | + HasFeaturesCol |
24 | 25 | from pyspark.ml.common import inherit_doc
|
25 | 26 | from pyspark.ml.util import JavaMLReadable, JavaMLWritable
|
26 | 27 |
|
27 | 28 | __all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator',
|
28 |
| - 'MulticlassClassificationEvaluator'] |
| 29 | + 'MulticlassClassificationEvaluator', 'ClusteringEvaluator'] |
29 | 30 |
|
30 | 31 |
|
31 | 32 | @inherit_doc
|
@@ -325,6 +326,77 @@ def setParams(self, predictionCol="prediction", labelCol="label",
|
325 | 326 | kwargs = self._input_kwargs
|
326 | 327 | return self._set(**kwargs)
|
327 | 328 |
|
| 329 | + |
@inherit_doc
class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol,
                          JavaMLReadable, JavaMLWritable):
    """
    .. note:: Experimental

    Evaluator for Clustering results, which expects two input
    columns: prediction and features.

    >>> from pyspark.ml.linalg import Vectors
    >>> featureAndPredictions = map(lambda x: (Vectors.dense(x[0]), x[1]),
    ...     [([0.0, 0.5], 0.0), ([0.5, 0.0], 0.0), ([10.0, 11.0], 1.0),
    ...     ([10.5, 11.5], 1.0), ([1.0, 1.0], 0.0), ([8.0, 6.0], 1.0)])
    >>> dataset = spark.createDataFrame(featureAndPredictions, ["features", "prediction"])
    ...
    >>> evaluator = ClusteringEvaluator(predictionCol="prediction")
    >>> evaluator.evaluate(dataset)
    0.9079...
    >>> ce_path = temp_path + "/ce"
    >>> evaluator.save(ce_path)
    >>> evaluator2 = ClusteringEvaluator.load(ce_path)
    >>> str(evaluator2.getPredictionCol())
    'prediction'

    .. versionadded:: 2.3.0
    """
    # Only one metric is supported so far; kept as a Param so the Python API
    # mirrors the Scala side and stays forward-compatible with new metrics.
    metricName = Param(
        Params._dummy(), "metricName",
        "metric name in evaluation (silhouette)",
        typeConverter=TypeConverters.toString)

    @keyword_only
    def __init__(self, predictionCol="prediction", featuresCol="features",
                 metricName="silhouette"):
        """
        __init__(self, predictionCol="prediction", featuresCol="features", \
                 metricName="silhouette")
        """
        super(ClusteringEvaluator, self).__init__()
        # Wrap the JVM-side evaluator; all evaluation work is delegated to it.
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.evaluation.ClusteringEvaluator", self.uid)
        # Default must be registered before applying the user-supplied kwargs
        # captured by @keyword_only.
        self._setDefault(metricName="silhouette")
        self._set(**self._input_kwargs)

    @keyword_only
    @since("2.3.0")
    def setParams(self, predictionCol="prediction", featuresCol="features",
                  metricName="silhouette"):
        """
        setParams(self, predictionCol="prediction", featuresCol="features", \
                  metricName="silhouette")
        Sets params for clustering evaluator.
        """
        return self._set(**self._input_kwargs)

    @since("2.3.0")
    def setMetricName(self, value):
        """
        Sets the value of :py:attr:`metricName`.
        """
        return self._set(metricName=value)

    @since("2.3.0")
    def getMetricName(self):
        """
        Gets the value of metricName or its default value.
        """
        return self.getOrDefault(self.metricName)
| 399 | + |
328 | 400 | if __name__ == "__main__":
|
329 | 401 | import doctest
|
330 | 402 | import tempfile
|
|
0 commit comments