Skip to content
This repository was archived by the owner on Jan 9, 2020. It is now read-only.

Commit 514a7e6

Browse files
committed
[SPARK-20929][ML] LinearSVC should use its own threshold param
## What changes were proposed in this pull request? LinearSVC should use its own threshold param, rather than the shared one, since it applies to rawPrediction instead of probability. This PR changes the param in the Scala, Python and R APIs. ## How was this patch tested? New unit test to make sure the threshold can be set to any Double value. Author: Joseph K. Bradley <[email protected]> Closes apache#18151 from jkbradley/ml-2.2-linearsvc-cleanup. (cherry picked from commit cc67bd5) Signed-off-by: Joseph K. Bradley <[email protected]>
1 parent 8bf7f1e commit 514a7e6

File tree

4 files changed

+79
-5
lines changed

4 files changed

+79
-5
lines changed

R/pkg/R/mllib_classification.R

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ setClass("NaiveBayesModel", representation(jobj = "jobj"))
6262
#' of models will be always returned on the original scale, so it will be transparent for
6363
#' users. Note that with/without standardization, the models should be always converged
6464
#' to the same solution when no regularization is applied.
65-
#' @param threshold The threshold in binary classification, in range [0, 1].
65+
#' @param threshold The threshold in binary classification applied to the linear model prediction.
66+
#' This threshold can be any real number, where Inf will make all predictions 0.0
67+
#' and -Inf will make all predictions 1.0.
6668
#' @param weightCol The weight column name.
6769
#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features
6870
#' or the number of partitions are large, this param could be adjusted to a larger size.

mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,23 @@ import org.apache.spark.sql.functions.{col, lit}
4242
/** Params for linear SVM Classifier. */
4343
private[classification] trait LinearSVCParams extends ClassifierParams with HasRegParam
4444
with HasMaxIter with HasFitIntercept with HasTol with HasStandardization with HasWeightCol
45-
with HasThreshold with HasAggregationDepth
45+
with HasAggregationDepth {
46+
47+
/**
48+
* Param for threshold in binary classification prediction.
49+
* For LinearSVC, this threshold is applied to the rawPrediction, rather than a probability.
50+
* This threshold can be any real number, where Inf will make all predictions 0.0
51+
* and -Inf will make all predictions 1.0.
52+
* Default: 0.0
53+
*
54+
* @group param
55+
*/
56+
final val threshold: DoubleParam = new DoubleParam(this, "threshold",
57+
"threshold in binary classification prediction applied to rawPrediction")
58+
59+
/** @group getParam */
60+
def getThreshold: Double = $(threshold)
61+
}
4662

4763
/**
4864
* :: Experimental ::
@@ -126,7 +142,7 @@ class LinearSVC @Since("2.2.0") (
126142
def setWeightCol(value: String): this.type = set(weightCol, value)
127143

128144
/**
129-
* Set threshold in binary classification, in range [0, 1].
145+
* Set threshold in binary classification.
130146
*
131147
* @group setParam
132148
*/
@@ -284,6 +300,7 @@ class LinearSVCModel private[classification] (
284300

285301
@Since("2.2.0")
286302
def setThreshold(value: Double): this.type = set(threshold, value)
303+
setDefault(threshold, 0.0)
287304

288305
@Since("2.2.0")
289306
def setWeightCol(value: Double): this.type = set(threshold, value)
@@ -301,6 +318,10 @@ class LinearSVCModel private[classification] (
301318
Vectors.dense(-m, m)
302319
}
303320

321+
override protected def raw2prediction(rawPrediction: Vector): Double = {
322+
if (rawPrediction(1) > $(threshold)) 1.0 else 0.0
323+
}
324+
304325
@Since("2.2.0")
305326
override def copy(extra: ParamMap): LinearSVCModel = {
306327
copyValues(new LinearSVCModel(uid, coefficients, intercept), extra).setParent(parent)

mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import org.apache.spark.SparkFunSuite
2525
import org.apache.spark.ml.classification.LinearSVCSuite._
2626
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
2727
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
28-
import org.apache.spark.ml.param.ParamsSuite
28+
import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
2929
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
3030
import org.apache.spark.ml.util.TestingUtils._
3131
import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -127,6 +127,39 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
127127
MLTestingUtils.checkCopyAndUids(lsvc, model)
128128
}
129129

130+
test("LinearSVC threshold acts on rawPrediction") {
131+
val lsvc =
132+
new LinearSVCModel(uid = "myLSVCM", coefficients = Vectors.dense(1.0), intercept = 0.0)
133+
val df = spark.createDataFrame(Seq(
134+
(1, Vectors.dense(1e-7)),
135+
(0, Vectors.dense(0.0)),
136+
(-1, Vectors.dense(-1e-7)))).toDF("id", "features")
137+
138+
def checkOneResult(
139+
model: LinearSVCModel,
140+
threshold: Double,
141+
expected: Set[(Int, Double)]): Unit = {
142+
model.setThreshold(threshold)
143+
val results = model.transform(df).select("id", "prediction").collect()
144+
.map(r => (r.getInt(0), r.getDouble(1)))
145+
.toSet
146+
assert(results === expected, s"Failed for threshold = $threshold")
147+
}
148+
149+
def checkResults(threshold: Double, expected: Set[(Int, Double)]): Unit = {
150+
// Check via code path using Classifier.raw2prediction
151+
lsvc.setRawPredictionCol("rawPrediction")
152+
checkOneResult(lsvc, threshold, expected)
153+
// Check via code path using Classifier.predict
154+
lsvc.setRawPredictionCol("")
155+
checkOneResult(lsvc, threshold, expected)
156+
}
157+
158+
checkResults(0.0, Set((1, 1.0), (0, 0.0), (-1, 0.0)))
159+
checkResults(Double.PositiveInfinity, Set((1, 0.0), (0, 0.0), (-1, 0.0)))
160+
checkResults(Double.NegativeInfinity, Set((1, 1.0), (0, 1.0), (-1, 1.0)))
161+
}
162+
130163
test("linear svc doesn't fit intercept when fitIntercept is off") {
131164
val lsvc = new LinearSVC().setFitIntercept(false).setMaxIter(5)
132165
val model = lsvc.fit(smallBinaryDataset)

python/pyspark/ml/classification.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def numClasses(self):
6363
@inherit_doc
6464
class LinearSVC(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
6565
HasRegParam, HasTol, HasRawPredictionCol, HasFitIntercept, HasStandardization,
66-
HasThreshold, HasWeightCol, HasAggregationDepth, JavaMLWritable, JavaMLReadable):
66+
HasWeightCol, HasAggregationDepth, JavaMLWritable, JavaMLReadable):
6767
"""
6868
.. note:: Experimental
6969
@@ -109,6 +109,12 @@ class LinearSVC(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, Ha
109109
.. versionadded:: 2.2.0
110110
"""
111111

112+
threshold = Param(Params._dummy(), "threshold",
113+
"The threshold in binary classification applied to the linear model"
114+
" prediction. This threshold can be any real number, where Inf will make"
115+
" all predictions 0.0 and -Inf will make all predictions 1.0.",
116+
typeConverter=TypeConverters.toFloat)
117+
112118
@keyword_only
113119
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
114120
maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction",
@@ -147,6 +153,18 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
147153
def _create_model(self, java_model):
148154
return LinearSVCModel(java_model)
149155

156+
def setThreshold(self, value):
157+
"""
158+
Sets the value of :py:attr:`threshold`.
159+
"""
160+
return self._set(threshold=value)
161+
162+
def getThreshold(self):
163+
"""
164+
Gets the value of threshold or its default value.
165+
"""
166+
return self.getOrDefault(self.threshold)
167+
150168

151169
class LinearSVCModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable):
152170
"""

0 commit comments

Comments
 (0)