Skip to content

Commit c397b06

Browse files
zhengruifeng authored and srowen committed
[SPARK-28045][ML][PYTHON] add missing RankingEvaluator
## What changes were proposed in this pull request? add missing RankingEvaluator ## How was this patch tested? added testsuites Closes apache#24869 from zhengruifeng/ranking_eval. Authored-by: zhengruifeng <[email protected]> Signed-off-by: Sean Owen <[email protected]>
1 parent 731a60c commit c397b06

File tree

3 files changed

+274
-1
lines changed

3 files changed

+274
-1
lines changed
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
19+
package org.apache.spark.ml.evaluation
20+
21+
import org.apache.spark.annotation.{Experimental, Since}
22+
import org.apache.spark.ml.param._
23+
import org.apache.spark.ml.param.shared._
24+
import org.apache.spark.ml.util._
25+
import org.apache.spark.mllib.evaluation.RankingMetrics
26+
import org.apache.spark.sql.Dataset
27+
import org.apache.spark.sql.functions._
28+
import org.apache.spark.sql.types._
29+
30+
/**
31+
* :: Experimental ::
32+
* Evaluator for ranking, which expects two input columns: prediction and label.
33+
*/
34+
/**
 * :: Experimental ::
 * Evaluator for ranking, which expects two input columns: prediction and label.
 * Each row holds a predicted ranking (array of doubles, ordered by relevance)
 * and the ground-truth set of relevant items; metrics are delegated to
 * [[org.apache.spark.mllib.evaluation.RankingMetrics]].
 */
@Experimental
@Since("3.0.0")
class RankingEvaluator (override val uid: String)
  extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable {

  import RankingEvaluator.supportedMetricNames

  def this() = this(Identifiable.randomUID("rankEval"))

  /**
   * param for metric name in evaluation (supports `"meanAveragePrecision"` (default),
   * `"meanAveragePrecisionAtK"`, `"precisionAtK"`, `"ndcgAtK"`, `"recallAtK"`)
   * @group param
   */
  final val metricName: Param[String] = {
    // Restrict values to the supported names; invalid names fail at set() time.
    val validator = ParamValidators.inArray(supportedMetricNames)
    new Param(this, "metricName", "metric name in evaluation " +
      s"${supportedMetricNames.mkString("(", "|", ")")}", validator)
  }

  /** @group getParam */
  def getMetricName: String = $(metricName)

  /** @group setParam */
  def setMetricName(value: String): this.type = set(metricName, value)

  setDefault(metricName -> "meanAveragePrecision")

  /** Ranking position cutoff, only consulted by the "...AtK" metrics. */
  final val k = new IntParam(this, "k",
    "The ranking position value used in " +
    s"${supportedMetricNames.filter(_.endsWith("AtK")).mkString("(", "|", ")")} " +
    "Must be > 0. The default value is 10.",
    ParamValidators.gt(0))

  /** @group getParam */
  def getK: Int = $(k)

  /** @group setParam */
  def setK(value: Int): this.type = set(k, value)

  setDefault(k -> 10)

  /** @group setParam */
  def setPredictionCol(value: String): this.type = set(predictionCol, value)

  /** @group setParam */
  def setLabelCol(value: String): this.type = set(labelCol, value)

  /**
   * Evaluates the dataset and returns the value of the configured metric.
   * Both columns must be `ArrayType(DoubleType)` (nullable element or not).
   */
  override def evaluate(dataset: Dataset[_]): Double = {
    val allowedTypes = Seq(ArrayType(DoubleType, false), ArrayType(DoubleType, true))
    SchemaUtils.checkColumnTypes(dataset.schema, $(predictionCol), allowedTypes)
    SchemaUtils.checkColumnTypes(dataset.schema, $(labelCol), allowedTypes)

    // Pair up (predicted ranking, ground-truth labels) per row for RankingMetrics.
    val pairs = dataset
      .select(col($(predictionCol)), col($(labelCol)))
      .rdd
      .map { row => (row.getSeq[Double](0).toArray, row.getSeq[Double](1).toArray) }

    val metrics = new RankingMetrics[Double](pairs)
    // metricName is validated on set(), so this match is exhaustive in practice.
    $(metricName) match {
      case "meanAveragePrecision" => metrics.meanAveragePrecision
      case "meanAveragePrecisionAtK" => metrics.meanAveragePrecisionAt($(k))
      case "precisionAtK" => metrics.precisionAt($(k))
      case "ndcgAtK" => metrics.ndcgAt($(k))
      case "recallAtK" => metrics.recallAt($(k))
    }
  }

  // All supported ranking metrics improve as they increase.
  override def isLargerBetter: Boolean = true

  override def copy(extra: ParamMap): RankingEvaluator = defaultCopy(extra)
}
109+
110+
111+
@Since("3.0.0")
object RankingEvaluator extends DefaultParamsReadable[RankingEvaluator] {

  // Metric names accepted by the `metricName` param. Note: this array is also
  // interpolated into the param doc strings, so its order is user-visible.
  private val supportedMetricNames = Array("meanAveragePrecision",
    "meanAveragePrecisionAtK", "precisionAtK", "ndcgAtK", "recallAtK")

  // Overridden only to narrow the return type for Java/Scala callers.
  override def load(path: String): RankingEvaluator = super.load(path)
}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
19+
package org.apache.spark.ml.evaluation
20+
21+
import org.apache.spark.SparkFunSuite
22+
import org.apache.spark.ml.param.ParamsSuite
23+
import org.apache.spark.ml.util.DefaultReadWriteTest
24+
import org.apache.spark.mllib.util.MLlibTestSparkContext
25+
import org.apache.spark.mllib.util.TestingUtils._
26+
27+
class RankingEvaluatorSuite
  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

  import testImplicits._

  test("params") {
    ParamsSuite.checkParams(new RankingEvaluator)
  }

  test("read/write") {
    // Round-trip an evaluator with every param set away from its default.
    val evaluator = new RankingEvaluator()
      .setPredictionCol("myPrediction")
      .setLabelCol("myLabel")
      .setMetricName("precisionAtK")
      .setK(10)
    testDefaultReadWrite(evaluator)
  }

  test("evaluation metrics") {
    // Three queries: two with relevant results, one whose ground truth is empty.
    val df = Seq(
      (Array(1.0, 6.0, 2.0, 7.0, 8.0, 3.0, 9.0, 10.0, 4.0, 5.0),
        Array(1.0, 2.0, 3.0, 4.0, 5.0)),
      (Array(4.0, 1.0, 5.0, 6.0, 2.0, 7.0, 3.0, 8.0, 9.0, 10.0),
        Array(1.0, 2.0, 3.0)),
      (Array(1.0, 2.0, 3.0, 4.0, 5.0), Array.empty[Double])
    ).toDF("prediction", "label")

    val evaluator = new RankingEvaluator()
      .setMetricName("meanAveragePrecision")
    assert(evaluator.evaluate(df) ~== 0.355026 absTol 1e-5)

    // Same evaluator, reconfigured for a position-cutoff metric.
    evaluator
      .setMetricName("precisionAtK")
      .setK(2)
    assert(evaluator.evaluate(df) ~== 1.0 / 3 absTol 1e-5)
  }
}

python/pyspark/ml/evaluation.py

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929
__all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator',
3030
'MulticlassClassificationEvaluator', 'MultilabelClassificationEvaluator',
31-
'ClusteringEvaluator']
31+
'ClusteringEvaluator', 'RankingEvaluator']
3232

3333

3434
@inherit_doc
@@ -587,6 +587,99 @@ def getDistanceMeasure(self):
587587
return self.getOrDefault(self.distanceMeasure)
588588

589589

590+
@inherit_doc
class RankingEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol,
                       JavaMLReadable, JavaMLWritable):
    """
    .. note:: Experimental

    Evaluator for Ranking, which expects two input
    columns: prediction and label.

    >>> scoreAndLabels = [([1.0, 6.0, 2.0, 7.0, 8.0, 3.0, 9.0, 10.0, 4.0, 5.0],
    ...     [1.0, 2.0, 3.0, 4.0, 5.0]),
    ...     ([4.0, 1.0, 5.0, 6.0, 2.0, 7.0, 3.0, 8.0, 9.0, 10.0], [1.0, 2.0, 3.0]),
    ...     ([1.0, 2.0, 3.0, 4.0, 5.0], [])]
    >>> dataset = spark.createDataFrame(scoreAndLabels, ["prediction", "label"])
    ...
    >>> evaluator = RankingEvaluator(predictionCol="prediction")
    >>> evaluator.evaluate(dataset)
    0.35...
    >>> evaluator.evaluate(dataset, {evaluator.metricName: "precisionAtK", evaluator.k: 2})
    0.33...
    >>> ranke_path = temp_path + "/ranke"
    >>> evaluator.save(ranke_path)
    >>> evaluator2 = RankingEvaluator.load(ranke_path)
    >>> str(evaluator2.getPredictionCol())
    'prediction'

    .. versionadded:: 3.0.0
    """
    # Param declarations mirror the Scala-side RankingEvaluator; validation and
    # metric computation happen in the wrapped JVM object.
    metricName = Param(Params._dummy(), "metricName",
                       "metric name in evaluation "
                       "(meanAveragePrecision|meanAveragePrecisionAtK|"
                       "precisionAtK|ndcgAtK|recallAtK)",
                       typeConverter=TypeConverters.toString)
    # k is only consulted by the "...AtK" metrics.
    k = Param(Params._dummy(), "k",
              "The ranking position value used in meanAveragePrecisionAtK|precisionAtK|"
              "ndcgAtK|recallAtK. Must be > 0. The default value is 10.",
              typeConverter=TypeConverters.toInt)

    @keyword_only
    def __init__(self, predictionCol="prediction", labelCol="label",
                 metricName="meanAveragePrecision", k=10):
        """
        __init__(self, predictionCol="prediction", labelCol="label", \
                 metricName="meanAveragePrecision", k=10)
        """
        super(RankingEvaluator, self).__init__()
        # Instantiate the JVM counterpart that performs the actual evaluation.
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.evaluation.RankingEvaluator", self.uid)
        self._setDefault(metricName="meanAveragePrecision", k=10)
        # @keyword_only stashes the caller's kwargs in self._input_kwargs.
        kwargs = self._input_kwargs
        self._set(**kwargs)

    @since("3.0.0")
    def setMetricName(self, value):
        """
        Sets the value of :py:attr:`metricName`.
        """
        return self._set(metricName=value)

    @since("3.0.0")
    def getMetricName(self):
        """
        Gets the value of metricName or its default value.
        """
        return self.getOrDefault(self.metricName)

    @since("3.0.0")
    def setK(self, value):
        """
        Sets the value of :py:attr:`k`.
        """
        return self._set(k=value)

    @since("3.0.0")
    def getK(self):
        """
        Gets the value of k or its default value.
        """
        return self.getOrDefault(self.k)

    @keyword_only
    @since("3.0.0")
    def setParams(self, predictionCol="prediction", labelCol="label",
                  metricName="meanAveragePrecision", k=10):
        """
        setParams(self, predictionCol="prediction", labelCol="label", \
                  metricName="meanAveragePrecision", k=10)
        Sets params for ranking evaluator.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)
681+
682+
590683
if __name__ == "__main__":
591684
import doctest
592685
import tempfile

0 commit comments

Comments
 (0)