
Commit 9ec0496

zhengruifeng authored and srowen committed
[SPARK-28044][ML][PYTHON] MulticlassClassificationEvaluator support more metrics
## What changes were proposed in this pull request?

Expose more metrics in the evaluator:
weightedTruePositiveRate / weightedFalsePositiveRate / weightedFMeasure / truePositiveRateByLabel / falsePositiveRateByLabel / precisionByLabel / recallByLabel / fMeasureByLabel

## How was this patch tested?

Existing cases and newly added cases.

Closes apache#24868 from zhengruifeng/multi_class_support_bylabel.

Authored-by: zhengruifeng <[email protected]>
Signed-off-by: Sean Owen <[email protected]>
1 parent 7b7f16f commit 9ec0496

File tree

6 files changed: +157 / -45 lines changed


mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala

Lines changed: 63 additions & 12 deletions
@@ -18,7 +18,7 @@
 package org.apache.spark.ml.evaluation
 
 import org.apache.spark.annotation.{Experimental, Since}
-import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
+import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol, HasWeightCol}
 import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
 import org.apache.spark.mllib.evaluation.MulticlassMetrics
@@ -36,6 +36,8 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
   extends Evaluator with HasPredictionCol with HasLabelCol
     with HasWeightCol with DefaultParamsWritable {
 
+  import MulticlassClassificationEvaluator.supportedMetricNames
+
   @Since("1.5.0")
   def this() = this(Identifiable.randomUID("mcEval"))
 
@@ -45,12 +47,9 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
    * @group param
    */
   @Since("1.5.0")
-  val metricName: Param[String] = {
-    val allowedParams = ParamValidators.inArray(Array("f1", "weightedPrecision",
-      "weightedRecall", "accuracy"))
-    new Param(this, "metricName", "metric name in evaluation " +
-      "(f1|weightedPrecision|weightedRecall|accuracy)", allowedParams)
-  }
+  val metricName: Param[String] = new Param(this, "metricName",
+    s"metric name in evaluation ${supportedMetricNames.mkString("(", "|", ")")}",
+    ParamValidators.inArray(supportedMetricNames))
 
   /** @group getParam */
   @Since("1.5.0")
@@ -60,6 +59,8 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
   @Since("1.5.0")
   def setMetricName(value: String): this.type = set(metricName, value)
 
+  setDefault(metricName -> "f1")
+
   /** @group setParam */
   @Since("1.5.0")
   def setPredictionCol(value: String): this.type = set(predictionCol, value)
@@ -72,7 +73,39 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
   @Since("3.0.0")
   def setWeightCol(value: String): this.type = set(weightCol, value)
 
-  setDefault(metricName -> "f1")
+  @Since("3.0.0")
+  final val metricLabel: DoubleParam = new DoubleParam(this, "metricLabel",
+    "The class whose metric will be computed in " +
+      s"${supportedMetricNames.filter(_.endsWith("ByLabel")).mkString("(", "|", ")")}. " +
+      "Must be >= 0. The default value is 0.",
+    ParamValidators.gtEq(0.0))
+
+  /** @group getParam */
+  @Since("3.0.0")
+  def getMetricLabel: Double = $(metricLabel)
+
+  /** @group setParam */
+  @Since("3.0.0")
+  def setMetricLabel(value: Double): this.type = set(metricLabel, value)
+
+  setDefault(metricLabel -> 0.0)
+
+  @Since("3.0.0")
+  final val beta: DoubleParam = new DoubleParam(this, "beta",
+    "The beta value, which controls precision vs recall weighting, " +
+      "used in (weightedFMeasure|fMeasureByLabel). Must be > 0. The default value is 1.",
+    ParamValidators.gt(0.0))
+
+  /** @group getParam */
+  @Since("3.0.0")
+  def getBeta: Double = $(beta)
+
+  /** @group setParam */
+  @Since("3.0.0")
+  def setBeta(value: Double): this.type = set(beta, value)
+
+  setDefault(beta -> 1.0)
 
   @Since("2.0.0")
   override def evaluate(dataset: Dataset[_]): Double = {
@@ -87,17 +120,30 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
       case Row(prediction: Double, label: Double, weight: Double) => (prediction, label, weight)
     }
     val metrics = new MulticlassMetrics(predictionAndLabelsWithWeights)
-    val metric = $(metricName) match {
+    $(metricName) match {
       case "f1" => metrics.weightedFMeasure
+      case "accuracy" => metrics.accuracy
       case "weightedPrecision" => metrics.weightedPrecision
      case "weightedRecall" => metrics.weightedRecall
-      case "accuracy" => metrics.accuracy
+      case "weightedTruePositiveRate" => metrics.weightedTruePositiveRate
+      case "weightedFalsePositiveRate" => metrics.weightedFalsePositiveRate
+      case "weightedFMeasure" => metrics.weightedFMeasure($(beta))
+      case "truePositiveRateByLabel" => metrics.truePositiveRate($(metricLabel))
+      case "falsePositiveRateByLabel" => metrics.falsePositiveRate($(metricLabel))
+      case "precisionByLabel" => metrics.precision($(metricLabel))
+      case "recallByLabel" => metrics.recall($(metricLabel))
+      case "fMeasureByLabel" => metrics.fMeasure($(metricLabel), $(beta))
     }
-    metric
   }
 
   @Since("1.5.0")
-  override def isLargerBetter: Boolean = true
+  override def isLargerBetter: Boolean = {
+    $(metricName) match {
+      case "weightedFalsePositiveRate" => false
+      case "falsePositiveRateByLabel" => false
+      case _ => true
+    }
+  }
 
   @Since("1.5.0")
   override def copy(extra: ParamMap): MulticlassClassificationEvaluator = defaultCopy(extra)
@@ -107,6 +153,11 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
 object MulticlassClassificationEvaluator
   extends DefaultParamsReadable[MulticlassClassificationEvaluator] {
 
+  private val supportedMetricNames = Array("f1", "accuracy", "weightedPrecision", "weightedRecall",
+    "weightedTruePositiveRate", "weightedFalsePositiveRate", "weightedFMeasure",
+    "truePositiveRateByLabel", "falsePositiveRateByLabel", "precisionByLabel", "recallByLabel",
+    "fMeasureByLabel")
+
   @Since("1.6.0")
   override def load(path: String): MulticlassClassificationEvaluator = super.load(path)
 }
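For orientation, a minimal sketch of how the expanded metric set can be driven from the new params. The toy DataFrame, the SparkSession value `spark`, and the chosen metric settings are illustrative only, not part of the patch:

import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

// Assumes an active SparkSession bound to `spark`; its implicits provide toDF on a Seq.
import spark.implicits._

// Hypothetical (prediction, label) pairs; "prediction" and "label" are the evaluator's default column names.
val df = Seq((0.0, 0.0), (0.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0)).toDF("prediction", "label")

val evaluator = new MulticlassClassificationEvaluator()
  .setMetricName("fMeasureByLabel")  // one of the newly exposed per-label metrics
  .setMetricLabel(1.0)               // the class whose metric is computed
  .setBeta(0.5)                      // precision-vs-recall weighting for the F-measure metrics

val score = evaluator.evaluate(df)

Note that when metricName is weightedFalsePositiveRate or falsePositiveRateByLabel, isLargerBetter now reports false, so model-selection code that consults it will treat lower values as better.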

mllib/src/main/scala/org/apache/spark/ml/evaluation/MultilabelClassificationEvaluator.scala

Lines changed: 10 additions & 9 deletions
@@ -63,18 +63,19 @@ class MultilabelClassificationEvaluator (override val uid: String)
 
   setDefault(metricName -> "f1Measure")
 
-  final val metricClass: DoubleParam = new DoubleParam(this, "metricClass",
-    "The class whose metric will be computed in precisionByLabel|recallByLabel|" +
-      "f1MeasureByLabel. Must be >= 0. The default value is 0.",
+  final val metricLabel: DoubleParam = new DoubleParam(this, "metricLabel",
+    "The class whose metric will be computed in " +
+      s"${supportedMetricNames.filter(_.endsWith("ByLabel")).mkString("(", "|", ")")}. " +
+      "Must be >= 0. The default value is 0.",
     ParamValidators.gtEq(0.0))
 
   /** @group getParam */
-  def getMetricClass: Double = $(metricClass)
+  def getMetricLabel: Double = $(metricLabel)
 
   /** @group setParam */
-  def setMetricClass(value: Double): this.type = set(metricClass, value)
+  def setMetricLabel(value: Double): this.type = set(metricLabel, value)
 
-  setDefault(metricClass -> 0.0)
+  setDefault(metricLabel -> 0.0)
 
   /** @group setParam */
   def setPredictionCol(value: String): this.type = set(predictionCol, value)
@@ -103,9 +104,9 @@ class MultilabelClassificationEvaluator (override val uid: String)
       case "precision" => metrics.precision
       case "recall" => metrics.recall
      case "f1Measure" => metrics.f1Measure
-      case "precisionByLabel" => metrics.precision($(metricClass))
-      case "recallByLabel" => metrics.recall($(metricClass))
-      case "f1MeasureByLabel" => metrics.f1Measure($(metricClass))
+      case "precisionByLabel" => metrics.precision($(metricLabel))
+      case "recallByLabel" => metrics.recall($(metricLabel))
+      case "f1MeasureByLabel" => metrics.f1Measure($(metricLabel))
      case "microPrecision" => metrics.microPrecision
      case "microRecall" => metrics.microRecall
      case "microF1Measure" => metrics.microF1Measure

mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala

Lines changed: 1 addition & 3 deletions
@@ -230,9 +230,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[_ <: Product])
    * Returns weighted averaged f1-measure
    */
   @Since("1.1.0")
-  lazy val weightedFMeasure: Double = labelCountByClass.map { case (category, count) =>
-    fMeasure(category, 1.0) * count.toDouble / labelCount
-  }.sum
+  lazy val weightedFMeasure: Double = weightedFMeasure(1.0)
 
   /**
    * Returns the sequence of labels in ascending order
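For reference, the quantity the removed lines computed, and that weightedFMeasure(1.0) still computes, is the label-frequency-weighted average of the per-label F-beta scores. Writing P(l) and R(l) for per-label precision and recall, count(l) for the label's count and N for the total count, and taking F_beta in its standard form:

$$\mathrm{weightedFMeasure}(\beta) = \sum_{\ell} \frac{\mathrm{count}(\ell)}{N}\, F_\beta(\ell), \qquad F_\beta(\ell) = \frac{(1+\beta^2)\, P(\ell)\, R(\ell)}{\beta^2\, P(\ell) + R(\ell)}$$

So the previously hard-coded beta = 1.0 is simply the default case of the general form.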

mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala

Lines changed: 20 additions & 0 deletions
@@ -21,10 +21,13 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
 
 class MulticlassClassificationEvaluatorSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
+  import testImplicits._
+
   test("params") {
     ParamsSuite.checkParams(new MulticlassClassificationEvaluator)
   }
@@ -34,10 +37,27 @@ class MulticlassClassificationEvaluatorSuite
       .setPredictionCol("myPrediction")
       .setLabelCol("myLabel")
       .setMetricName("accuracy")
+      .setMetricLabel(1.0)
+      .setBeta(2.0)
     testDefaultReadWrite(evaluator)
   }
 
   test("should support all NumericType labels and not support other types") {
     MLTestingUtils.checkNumericTypes(new MulticlassClassificationEvaluator, spark)
   }
+
+  test("evaluation metrics") {
+    val predictionAndLabels = Seq((0.0, 0.0), (0.0, 1.0),
+      (0.0, 0.0), (1.0, 0.0), (1.0, 1.0),
+      (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)).toDF("prediction", "label")
+
+    val evaluator = new MulticlassClassificationEvaluator()
+      .setMetricName("precisionByLabel")
+      .setMetricLabel(0.0)
+    assert(evaluator.evaluate(predictionAndLabels) ~== 2.0 / 3 absTol 1e-5)
+
+    evaluator.setMetricName("truePositiveRateByLabel")
+      .setMetricLabel(1.0)
+    assert(evaluator.evaluate(predictionAndLabels) ~== 3.0 / 4 absTol 1e-5)
+  }
 }
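As a sanity check on the asserted values: label 0.0 is predicted three times and two of those predictions are correct, so precisionByLabel at label 0.0 is 2/3; label 1.0 occurs four times and is predicted as 1.0 three times, so truePositiveRateByLabel at label 1.0 is 3/4.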

mllib/src/test/scala/org/apache/spark/ml/evaluation/MultilabelClassificationEvaluatorSuite.scala

Lines changed: 2 additions & 2 deletions
@@ -47,15 +47,15 @@ class MultilabelClassificationEvaluatorSuite
     assert(evaluator.evaluate(scoreAndLabels) ~== 2.0 / 7 absTol 1e-5)
 
     evaluator.setMetricName("recallByLabel")
-      .setMetricClass(0.0)
+      .setMetricLabel(0.0)
     assert(evaluator.evaluate(scoreAndLabels) ~== 0.8 absTol 1e-5)
   }
 
   test("read/write") {
     val evaluator = new MultilabelClassificationEvaluator()
       .setPredictionCol("myPrediction")
       .setLabelCol("myLabel")
-      .setMetricClass(1.0)
+      .setMetricLabel(1.0)
       .setMetricName("precisionByLabel")
     testDefaultReadWrite(evaluator)
   }

python/pyspark/ml/evaluation.py

Lines changed: 61 additions & 19 deletions
@@ -292,6 +292,9 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
     0.66...
     >>> evaluator.evaluate(dataset, {evaluator.metricName: "accuracy"})
     0.66...
+    >>> evaluator.evaluate(dataset, {evaluator.metricName: "truePositiveRateByLabel",
+    ...     evaluator.metricLabel: 1.0})
+    0.75...
     >>> mce_path = temp_path + "/mce"
     >>> evaluator.save(mce_path)
     >>> evaluator2 = MulticlassClassificationEvaluator.load(mce_path)
@@ -313,20 +316,31 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
     """
     metricName = Param(Params._dummy(), "metricName",
                        "metric name in evaluation "
-                       "(f1|weightedPrecision|weightedRecall|accuracy)",
+                       "(f1|accuracy|weightedPrecision|weightedRecall|weightedTruePositiveRate|"
+                       "weightedFalsePositiveRate|weightedFMeasure|truePositiveRateByLabel|"
+                       "falsePositiveRateByLabel|precisionByLabel|recallByLabel|fMeasureByLabel)",
                        typeConverter=TypeConverters.toString)
+    metricLabel = Param(Params._dummy(), "metricLabel",
+                        "The class whose metric will be computed in truePositiveRateByLabel|"
+                        "falsePositiveRateByLabel|precisionByLabel|recallByLabel|fMeasureByLabel."
+                        " Must be >= 0. The default value is 0.",
+                        typeConverter=TypeConverters.toFloat)
+    beta = Param(Params._dummy(), "beta",
+                 "The beta value used in weightedFMeasure|fMeasureByLabel."
+                 " Must be > 0. The default value is 1.",
+                 typeConverter=TypeConverters.toFloat)
 
     @keyword_only
     def __init__(self, predictionCol="prediction", labelCol="label",
-                 metricName="f1", weightCol=None):
+                 metricName="f1", weightCol=None, metricLabel=0.0, beta=1.0):
         """
         __init__(self, predictionCol="prediction", labelCol="label", \
-                 metricName="f1", weightCol=None)
+                 metricName="f1", weightCol=None, metricLabel=0.0, beta=1.0)
         """
         super(MulticlassClassificationEvaluator, self).__init__()
         self._java_obj = self._new_java_obj(
             "org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator", self.uid)
-        self._setDefault(metricName="f1")
+        self._setDefault(metricName="f1", metricLabel=0.0, beta=1.0)
         kwargs = self._input_kwargs
         self._set(**kwargs)
 
@@ -344,13 +358,41 @@ def getMetricName(self):
         """
         return self.getOrDefault(self.metricName)
 
+    @since("3.0.0")
+    def setMetricLabel(self, value):
+        """
+        Sets the value of :py:attr:`metricLabel`.
+        """
+        return self._set(metricLabel=value)
+
+    @since("3.0.0")
+    def getMetricLabel(self):
+        """
+        Gets the value of metricLabel or its default value.
+        """
+        return self.getOrDefault(self.metricLabel)
+
+    @since("3.0.0")
+    def setBeta(self, value):
+        """
+        Sets the value of :py:attr:`beta`.
+        """
+        return self._set(beta=value)
+
+    @since("3.0.0")
+    def getBeta(self):
+        """
+        Gets the value of beta or its default value.
+        """
+        return self.getOrDefault(self.beta)
+
     @keyword_only
     @since("1.5.0")
     def setParams(self, predictionCol="prediction", labelCol="label",
-                  metricName="f1", weightCol=None):
+                  metricName="f1", weightCol=None, metricLabel=0.0, beta=1.0):
         """
         setParams(self, predictionCol="prediction", labelCol="label", \
-                  metricName="f1", weightCol=None)
+                  metricName="f1", weightCol=None, metricLabel=0.0, beta=1.0)
         Sets params for multiclass classification evaluator.
         """
         kwargs = self._input_kwargs
@@ -390,23 +432,23 @@ class MultilabelClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
                        "precisionByLabel|recallByLabel|f1MeasureByLabel|microPrecision|"
                        "microRecall|microF1Measure)",
                        typeConverter=TypeConverters.toString)
-    metricClass = Param(Params._dummy(), "metricClass",
-                        "The label whose metric will be computed in precisionByLabel|"
+    metricLabel = Param(Params._dummy(), "metricLabel",
+                        "The class whose metric will be computed in precisionByLabel|"
                         "recallByLabel|f1MeasureByLabel. "
                         "Must be >= 0. The default value is 0.",
                        typeConverter=TypeConverters.toFloat)
 
     @keyword_only
     def __init__(self, predictionCol="prediction", labelCol="label",
-                 metricName="f1Measure", metricClass=0.0):
+                 metricName="f1Measure", metricLabel=0.0):
         """
         __init__(self, predictionCol="prediction", labelCol="label", \
-                 metricName="f1Measure", metricClass=0.0)
+                 metricName="f1Measure", metricLabel=0.0)
         """
         super(MultilabelClassificationEvaluator, self).__init__()
         self._java_obj = self._new_java_obj(
             "org.apache.spark.ml.evaluation.MultilabelClassificationEvaluator", self.uid)
-        self._setDefault(metricName="f1Measure", metricClass=0.0)
+        self._setDefault(metricName="f1Measure", metricLabel=0.0)
         kwargs = self._input_kwargs
         self._set(**kwargs)
 
@@ -425,26 +467,26 @@ def getMetricName(self):
         return self.getOrDefault(self.metricName)
 
     @since("3.0.0")
-    def setMetricClass(self, value):
+    def setMetricLabel(self, value):
         """
-        Sets the value of :py:attr:`metricClass`.
+        Sets the value of :py:attr:`metricLabel`.
         """
-        return self._set(metricClass=value)
+        return self._set(metricLabel=value)
 
     @since("3.0.0")
-    def getMetricClass(self):
+    def getMetricLabel(self):
         """
-        Gets the value of metricClass or its default value.
+        Gets the value of metricLabel or its default value.
         """
-        return self.getOrDefault(self.metricClass)
+        return self.getOrDefault(self.metricLabel)
 
     @keyword_only
     @since("3.0.0")
     def setParams(self, predictionCol="prediction", labelCol="label",
-                  metricName="f1Measure", metricClass=0.0):
+                  metricName="f1Measure", metricLabel=0.0):
         """
         setParams(self, predictionCol="prediction", labelCol="label", \
-                  metricName="f1Measure", metricClass=0.0)
+                  metricName="f1Measure", metricLabel=0.0)
         Sets params for multilabel classification evaluator.
         """
         kwargs = self._input_kwargs
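One note on the added doctest: passing {evaluator.metricName: "truePositiveRateByLabel", evaluator.metricLabel: 1.0} to evaluate() overrides those params for that single call only, since the evaluator is copied with the extra params applied; its own values and the defaults set here (metricName="f1", metricLabel=0.0, beta=1.0) are left untouched.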
