Skip to content

Commit ec873a4

Browse files
mgaido91srowen
authored andcommitted
[SPARK-14516][FOLLOWUP] Adding ClusteringEvaluator to examples
## What changes were proposed in this pull request? In SPARK-14516 we have introduced ClusteringEvaluator, but we didn't put any reference in the documentation and the examples were still relying on the sum of squared errors to show a way to evaluate the clustering model. The PR adds the ClusteringEvaluator in the examples. ## How was this patch tested? Manual runs of the examples. Author: Marco Gaido <[email protected]> Closes #19676 from mgaido91/SPARK-14516_examples.
1 parent 4289ac9 commit ec873a4

File tree

3 files changed

+27
-9
lines changed

3 files changed

+27
-9
lines changed

examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
// $example on$
2121
import org.apache.spark.ml.clustering.KMeansModel;
2222
import org.apache.spark.ml.clustering.KMeans;
23+
import org.apache.spark.ml.evaluation.ClusteringEvaluator;
2324
import org.apache.spark.ml.linalg.Vector;
2425
import org.apache.spark.sql.Dataset;
2526
import org.apache.spark.sql.Row;
@@ -51,9 +52,14 @@ public static void main(String[] args) {
5152
KMeans kmeans = new KMeans().setK(2).setSeed(1L);
5253
KMeansModel model = kmeans.fit(dataset);
5354

54-
// Evaluate clustering by computing Within Set Sum of Squared Errors.
55-
double WSSSE = model.computeCost(dataset);
56-
System.out.println("Within Set Sum of Squared Errors = " + WSSSE);
55+
// Make predictions
56+
Dataset<Row> predictions = model.transform(dataset);
57+
58+
// Evaluate clustering by computing Silhouette score
59+
ClusteringEvaluator evaluator = new ClusteringEvaluator();
60+
61+
double silhouette = evaluator.evaluate(predictions);
62+
System.out.println("Silhouette with squared euclidean distance = " + silhouette);
5763

5864
// Shows the result.
5965
Vector[] centers = model.clusterCenters();

examples/src/main/python/ml/kmeans_example.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
# $example on$
2121
from pyspark.ml.clustering import KMeans
22+
from pyspark.ml.evaluation import ClusteringEvaluator
2223
# $example off$
2324

2425
from pyspark.sql import SparkSession
@@ -45,9 +46,14 @@
4546
kmeans = KMeans().setK(2).setSeed(1)
4647
model = kmeans.fit(dataset)
4748

48-
# Evaluate clustering by computing Within Set Sum of Squared Errors.
49-
wssse = model.computeCost(dataset)
50-
print("Within Set Sum of Squared Errors = " + str(wssse))
49+
# Make predictions
50+
predictions = model.transform(dataset)
51+
52+
# Evaluate clustering by computing Silhouette score
53+
evaluator = ClusteringEvaluator()
54+
55+
silhouette = evaluator.evaluate(predictions)
56+
print("Silhouette with squared euclidean distance = " + str(silhouette))
5157

5258
# Shows the result.
5359
centers = model.clusterCenters()

examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ package org.apache.spark.examples.ml
2121

2222
// $example on$
2323
import org.apache.spark.ml.clustering.KMeans
24+
import org.apache.spark.ml.evaluation.ClusteringEvaluator
2425
// $example off$
2526
import org.apache.spark.sql.SparkSession
2627

@@ -47,9 +48,14 @@ object KMeansExample {
4748
val kmeans = new KMeans().setK(2).setSeed(1L)
4849
val model = kmeans.fit(dataset)
4950

50-
// Evaluate clustering by computing Within Set Sum of Squared Errors.
51-
val WSSSE = model.computeCost(dataset)
52-
println(s"Within Set Sum of Squared Errors = $WSSSE")
51+
// Make predictions
52+
val predictions = model.transform(dataset)
53+
54+
// Evaluate clustering by computing Silhouette score
55+
val evaluator = new ClusteringEvaluator()
56+
57+
val silhouette = evaluator.evaluate(predictions)
58+
println(s"Silhouette with squared euclidean distance = $silhouette")
5359

5460
// Shows the result.
5561
println("Cluster Centers: ")

0 commit comments

Comments
 (0)