Skip to content

Commit 5003736

Browse files
lu-wang-dl authored and jkbradley committed
[SPARK-9312][ML] Add RawPrediction, numClasses, and numFeatures for OneVsRestModel
Add rawPrediction as an output column; add numClasses and numFeatures to OneVsRestModel. ## What changes were proposed in this pull request? - Add two vals, numClasses and numFeatures, to OneVsRestModel so that it can inherit from Classifier in the future - Add a rawPrediction output column in transform; the prediction label is calculated from the rawPrediction, as in raw2prediction ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Lu WANG <[email protected]> Closes apache#21044 from ludatabricks/SPARK-9312.
1 parent 083cf22 commit 5003736

File tree

2 files changed

+51
-12
lines changed

2 files changed

+51
-12
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala

Lines changed: 46 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ import org.apache.spark.SparkContext
3232
import org.apache.spark.annotation.Since
3333
import org.apache.spark.ml._
3434
import org.apache.spark.ml.attribute._
35-
import org.apache.spark.ml.linalg.Vector
35+
import org.apache.spark.ml.linalg.{Vector, Vectors}
3636
import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
3737
import org.apache.spark.ml.param.shared.{HasParallelism, HasWeightCol}
3838
import org.apache.spark.ml.util._
@@ -55,7 +55,7 @@ private[ml] trait ClassifierTypeTrait {
5555
/**
5656
* Params for [[OneVsRest]].
5757
*/
58-
private[ml] trait OneVsRestParams extends PredictorParams
58+
private[ml] trait OneVsRestParams extends ClassifierParams
5959
with ClassifierTypeTrait with HasWeightCol {
6060

6161
/**
@@ -138,6 +138,14 @@ final class OneVsRestModel private[ml] (
138138
@Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]])
139139
extends Model[OneVsRestModel] with OneVsRestParams with MLWritable {
140140

141+
require(models.nonEmpty, "OneVsRestModel requires at least one model for one class")
142+
143+
@Since("2.4.0")
144+
val numClasses: Int = models.length
145+
146+
@Since("2.4.0")
147+
val numFeatures: Int = models.head.numFeatures
148+
141149
/** @group setParam */
142150
@Since("2.1.0")
143151
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
@@ -146,6 +154,10 @@ final class OneVsRestModel private[ml] (
146154
@Since("2.1.0")
147155
def setPredictionCol(value: String): this.type = set(predictionCol, value)
148156

157+
/** @group setParam */
158+
@Since("2.4.0")
159+
def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value)
160+
149161
@Since("1.4.0")
150162
override def transformSchema(schema: StructType): StructType = {
151163
validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType)
@@ -181,6 +193,7 @@ final class OneVsRestModel private[ml] (
181193
val updateUDF = udf { (predictions: Map[Int, Double], prediction: Vector) =>
182194
predictions + ((index, prediction(1)))
183195
}
196+
184197
model.setFeaturesCol($(featuresCol))
185198
val transformedDataset = model.transform(df).select(columns: _*)
186199
val updatedDataset = transformedDataset
@@ -195,15 +208,34 @@ final class OneVsRestModel private[ml] (
195208
newDataset.unpersist()
196209
}
197210

198-
// output the index of the classifier with highest confidence as prediction
199-
val labelUDF = udf { (predictions: Map[Int, Double]) =>
200-
predictions.maxBy(_._2)._1.toDouble
201-
}
211+
if (getRawPredictionCol != "") {
212+
val numClass = models.length
202213

203-
// output label and label metadata as prediction
204-
aggregatedDataset
205-
.withColumn($(predictionCol), labelUDF(col(accColName)), labelMetadata)
206-
.drop(accColName)
214+
// output the RawPrediction as vector
215+
val rawPredictionUDF = udf { (predictions: Map[Int, Double]) =>
216+
val predArray = Array.fill[Double](numClass)(0.0)
217+
predictions.foreach { case (idx, value) => predArray(idx) = value }
218+
Vectors.dense(predArray)
219+
}
220+
221+
// output the index of the classifier with highest confidence as prediction
222+
val labelUDF = udf { (rawPredictions: Vector) => rawPredictions.argmax.toDouble }
223+
224+
// output confidence as raw prediction, label and label metadata as prediction
225+
aggregatedDataset
226+
.withColumn(getRawPredictionCol, rawPredictionUDF(col(accColName)))
227+
.withColumn(getPredictionCol, labelUDF(col(getRawPredictionCol)), labelMetadata)
228+
.drop(accColName)
229+
} else {
230+
// output the index of the classifier with highest confidence as prediction
231+
val labelUDF = udf { (predictions: Map[Int, Double]) =>
232+
predictions.maxBy(_._2)._1.toDouble
233+
}
234+
// output label and label metadata as prediction
235+
aggregatedDataset
236+
.withColumn(getPredictionCol, labelUDF(col(accColName)), labelMetadata)
237+
.drop(accColName)
238+
}
207239
}
208240

209241
@Since("1.4.1")
@@ -297,6 +329,10 @@ final class OneVsRest @Since("1.4.0") (
297329
@Since("1.5.0")
298330
def setPredictionCol(value: String): this.type = set(predictionCol, value)
299331

332+
/** @group setParam */
333+
@Since("2.4.0")
334+
def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value)
335+
300336
/**
301337
* The implementation of parallel one vs. rest runs the classification for
302338
* each class in a separate thread.

mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,12 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest {
7272
.setClassifier(new LogisticRegression)
7373
assert(ova.getLabelCol === "label")
7474
assert(ova.getPredictionCol === "prediction")
75+
assert(ova.getRawPredictionCol === "rawPrediction")
7576
val ovaModel = ova.fit(dataset)
7677

7778
MLTestingUtils.checkCopyAndUids(ova, ovaModel)
7879

79-
assert(ovaModel.models.length === numClasses)
80+
assert(ovaModel.numClasses === numClasses)
8081
val transformedDataset = ovaModel.transform(dataset)
8182

8283
// check for label metadata in prediction col
@@ -179,6 +180,7 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest {
179180
val dataset2 = dataset.select(col("label").as("y"), col("features").as("fea"))
180181
ovaModel.setFeaturesCol("fea")
181182
ovaModel.setPredictionCol("pred")
183+
ovaModel.setRawPredictionCol("")
182184
val transformedDataset = ovaModel.transform(dataset2)
183185
val outputFields = transformedDataset.schema.fieldNames.toSet
184186
assert(outputFields === Set("y", "fea", "pred"))
@@ -190,7 +192,8 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest {
190192
val ovr = new OneVsRest()
191193
.setClassifier(logReg)
192194
val output = ovr.fit(dataset).transform(dataset)
193-
assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction"))
195+
assert(output.schema.fieldNames.toSet
196+
=== Set("label", "features", "prediction", "rawPrediction"))
194197
}
195198

196199
test("SPARK-21306: OneVsRest should support setWeightCol") {

0 commit comments

Comments
 (0)