Skip to content

Commit f3676d6

Browse files
YY-OnCall authored and yanboliang committed
[SPARK-21108][ML] convert LinearSVC to aggregator framework
## What changes were proposed in this pull request?

Convert LinearSVC to the new aggregator framework.

## How was this patch tested?

Existing unit tests.

Author: Yuhao Yang <[email protected]>

Closes apache#18315 from hhbyyh/svcAggregator.
1 parent 05af2de commit f3676d6

File tree

5 files changed

+286
-195
lines changed

5 files changed

+286
-195
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala

Lines changed: 14 additions & 190 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,11 @@ import org.apache.hadoop.fs.Path
2525

2626
import org.apache.spark.SparkException
2727
import org.apache.spark.annotation.{Experimental, Since}
28-
import org.apache.spark.broadcast.Broadcast
2928
import org.apache.spark.internal.Logging
3029
import org.apache.spark.ml.feature.Instance
3130
import org.apache.spark.ml.linalg._
32-
import org.apache.spark.ml.linalg.BLAS._
31+
import org.apache.spark.ml.optim.aggregator.HingeAggregator
32+
import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction}
3333
import org.apache.spark.ml.param._
3434
import org.apache.spark.ml.param.shared._
3535
import org.apache.spark.ml.util._
@@ -214,10 +214,20 @@ class LinearSVC @Since("2.2.0") (
214214
}
215215

216216
val featuresStd = summarizer.variance.toArray.map(math.sqrt)
217+
val getFeaturesStd = (j: Int) => featuresStd(j)
217218
val regParamL2 = $(regParam)
218219
val bcFeaturesStd = instances.context.broadcast(featuresStd)
219-
val costFun = new LinearSVCCostFun(instances, $(fitIntercept),
220-
$(standardization), bcFeaturesStd, regParamL2, $(aggregationDepth))
220+
val regularization = if (regParamL2 != 0.0) {
221+
val shouldApply = (idx: Int) => idx >= 0 && idx < numFeatures
222+
Some(new L2Regularization(regParamL2, shouldApply,
223+
if ($(standardization)) None else Some(getFeaturesStd)))
224+
} else {
225+
None
226+
}
227+
228+
val getAggregatorFunc = new HingeAggregator(bcFeaturesStd, $(fitIntercept))(_)
229+
val costFun = new RDDLossFunction(instances, getAggregatorFunc, regularization,
230+
$(aggregationDepth))
221231

222232
def regParamL1Fun = (index: Int) => 0D
223233
val optimizer = new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol))
@@ -372,189 +382,3 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] {
372382
}
373383
}
374384
}
375-
376-
/**
377-
* LinearSVCCostFun implements Breeze's DiffFunction[T] for hinge loss function
378-
*/
379-
private class LinearSVCCostFun(
380-
instances: RDD[Instance],
381-
fitIntercept: Boolean,
382-
standardization: Boolean,
383-
bcFeaturesStd: Broadcast[Array[Double]],
384-
regParamL2: Double,
385-
aggregationDepth: Int) extends DiffFunction[BDV[Double]] {
386-
387-
override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
388-
val coeffs = Vectors.fromBreeze(coefficients)
389-
val bcCoeffs = instances.context.broadcast(coeffs)
390-
val featuresStd = bcFeaturesStd.value
391-
val numFeatures = featuresStd.length
392-
393-
val svmAggregator = {
394-
val seqOp = (c: LinearSVCAggregator, instance: Instance) => c.add(instance)
395-
val combOp = (c1: LinearSVCAggregator, c2: LinearSVCAggregator) => c1.merge(c2)
396-
397-
instances.treeAggregate(
398-
new LinearSVCAggregator(bcCoeffs, bcFeaturesStd, fitIntercept)
399-
)(seqOp, combOp, aggregationDepth)
400-
}
401-
402-
val totalGradientArray = svmAggregator.gradient.toArray
403-
// regVal is the sum of coefficients squares excluding intercept for L2 regularization.
404-
val regVal = if (regParamL2 == 0.0) {
405-
0.0
406-
} else {
407-
var sum = 0.0
408-
coeffs.foreachActive { case (index, value) =>
409-
// We do not apply regularization to the intercepts
410-
if (index != numFeatures) {
411-
// The following code will compute the loss of the regularization; also
412-
// the gradient of the regularization, and add back to totalGradientArray.
413-
sum += {
414-
if (standardization) {
415-
totalGradientArray(index) += regParamL2 * value
416-
value * value
417-
} else {
418-
if (featuresStd(index) != 0.0) {
419-
// If `standardization` is false, we still standardize the data
420-
// to improve the rate of convergence; as a result, we have to
421-
// perform this reverse standardization by penalizing each component
422-
// differently to get effectively the same objective function when
423-
// the training dataset is not standardized.
424-
val temp = value / (featuresStd(index) * featuresStd(index))
425-
totalGradientArray(index) += regParamL2 * temp
426-
value * temp
427-
} else {
428-
0.0
429-
}
430-
}
431-
}
432-
}
433-
}
434-
0.5 * regParamL2 * sum
435-
}
436-
bcCoeffs.destroy(blocking = false)
437-
438-
(svmAggregator.loss + regVal, new BDV(totalGradientArray))
439-
}
440-
}
441-
442-
/**
443-
* LinearSVCAggregator computes the gradient and loss for hinge loss function, as used
444-
* in binary classification for instances in sparse or dense vector in an online fashion.
445-
*
446-
* Two LinearSVCAggregator can be merged together to have a summary of loss and gradient of
447-
* the corresponding joint dataset.
448-
*
449-
* This class standardizes feature values during computation using bcFeaturesStd.
450-
*
451-
* @param bcCoefficients The coefficients corresponding to the features.
452-
* @param fitIntercept Whether to fit an intercept term.
453-
* @param bcFeaturesStd The standard deviation values of the features.
454-
*/
455-
private class LinearSVCAggregator(
456-
bcCoefficients: Broadcast[Vector],
457-
bcFeaturesStd: Broadcast[Array[Double]],
458-
fitIntercept: Boolean) extends Serializable {
459-
460-
private val numFeatures: Int = bcFeaturesStd.value.length
461-
private val numFeaturesPlusIntercept: Int = if (fitIntercept) numFeatures + 1 else numFeatures
462-
private var weightSum: Double = 0.0
463-
private var lossSum: Double = 0.0
464-
@transient private lazy val coefficientsArray = bcCoefficients.value match {
465-
case DenseVector(values) => values
466-
case _ => throw new IllegalArgumentException(s"coefficients only supports dense vector" +
467-
s" but got type ${bcCoefficients.value.getClass}.")
468-
}
469-
private lazy val gradientSumArray = new Array[Double](numFeaturesPlusIntercept)
470-
471-
/**
472-
* Add a new training instance to this LinearSVCAggregator, and update the loss and gradient
473-
* of the objective function.
474-
*
475-
* @param instance The instance of data point to be added.
476-
* @return This LinearSVCAggregator object.
477-
*/
478-
def add(instance: Instance): this.type = {
479-
instance match { case Instance(label, weight, features) =>
480-
481-
if (weight == 0.0) return this
482-
val localFeaturesStd = bcFeaturesStd.value
483-
val localCoefficients = coefficientsArray
484-
val localGradientSumArray = gradientSumArray
485-
486-
val dotProduct = {
487-
var sum = 0.0
488-
features.foreachActive { (index, value) =>
489-
if (localFeaturesStd(index) != 0.0 && value != 0.0) {
490-
sum += localCoefficients(index) * value / localFeaturesStd(index)
491-
}
492-
}
493-
if (fitIntercept) sum += localCoefficients(numFeaturesPlusIntercept - 1)
494-
sum
495-
}
496-
// Our loss function with {0, 1} labels is max(0, 1 - (2y - 1) (f_w(x)))
497-
// Therefore the gradient is -(2y - 1)*x
498-
val labelScaled = 2 * label - 1.0
499-
val loss = if (1.0 > labelScaled * dotProduct) {
500-
weight * (1.0 - labelScaled * dotProduct)
501-
} else {
502-
0.0
503-
}
504-
505-
if (1.0 > labelScaled * dotProduct) {
506-
val gradientScale = -labelScaled * weight
507-
features.foreachActive { (index, value) =>
508-
if (localFeaturesStd(index) != 0.0 && value != 0.0) {
509-
localGradientSumArray(index) += value * gradientScale / localFeaturesStd(index)
510-
}
511-
}
512-
if (fitIntercept) {
513-
localGradientSumArray(localGradientSumArray.length - 1) += gradientScale
514-
}
515-
}
516-
517-
lossSum += loss
518-
weightSum += weight
519-
this
520-
}
521-
}
522-
523-
/**
524-
* Merge another LinearSVCAggregator, and update the loss and gradient
525-
* of the objective function.
526-
* (Note that it's in place merging; as a result, `this` object will be modified.)
527-
*
528-
* @param other The other LinearSVCAggregator to be merged.
529-
* @return This LinearSVCAggregator object.
530-
*/
531-
def merge(other: LinearSVCAggregator): this.type = {
532-
533-
if (other.weightSum != 0.0) {
534-
weightSum += other.weightSum
535-
lossSum += other.lossSum
536-
537-
var i = 0
538-
val localThisGradientSumArray = this.gradientSumArray
539-
val localOtherGradientSumArray = other.gradientSumArray
540-
val len = localThisGradientSumArray.length
541-
while (i < len) {
542-
localThisGradientSumArray(i) += localOtherGradientSumArray(i)
543-
i += 1
544-
}
545-
}
546-
this
547-
}
548-
549-
def loss: Double = if (weightSum != 0) lossSum / weightSum else 0.0
550-
551-
def gradient: Vector = {
552-
if (weightSum != 0) {
553-
val result = Vectors.dense(gradientSumArray.clone())
554-
scal(1.0 / weightSum, result)
555-
result
556-
} else {
557-
Vectors.dense(new Array[Double](numFeaturesPlusIntercept))
558-
}
559-
}
560-
}
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.ml.optim.aggregator
19+
20+
import org.apache.spark.broadcast.Broadcast
21+
import org.apache.spark.ml.feature.Instance
22+
import org.apache.spark.ml.linalg._
23+
24+
/**
25+
* HingeAggregator computes the gradient and loss for the hinge loss function, as used in
26+
* binary classification, for instances given as sparse or dense vectors, in an online fashion.
27+
*
28+
* Two HingeAggregators can be merged together to have a summary of loss and gradient of
29+
* the corresponding joint dataset.
30+
*
31+
* This class standardizes feature values during computation using bcFeaturesStd.
32+
*
33+
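* @param bcCoefficients The coefficients corresponding to the features, followed by the intercept as the last element when fitIntercept is true. Must be a dense vector; any other vector type causes an IllegalArgumentException when the coefficients are first read.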
34+
* @param fitIntercept Whether to fit an intercept term.
35+
* @param bcFeaturesStd The standard deviation values of the features.
36+
*/
37+
private[ml] class HingeAggregator(
38+
bcFeaturesStd: Broadcast[Array[Double]],
39+
fitIntercept: Boolean)(bcCoefficients: Broadcast[Vector])
40+
extends DifferentiableLossAggregator[Instance, HingeAggregator] {
41+
42+
private val numFeatures: Int = bcFeaturesStd.value.length
43+
private val numFeaturesPlusIntercept: Int = if (fitIntercept) numFeatures + 1 else numFeatures
44+
@transient private lazy val coefficientsArray = bcCoefficients.value match {
45+
case DenseVector(values) => values
46+
case _ => throw new IllegalArgumentException(s"coefficients only supports dense vector" +
47+
s" but got type ${bcCoefficients.value.getClass}.")
48+
}
49+
protected override val dim: Int = numFeaturesPlusIntercept
50+
51+
/**
52+
* Add a new training instance to this HingeAggregator, and update the loss and gradient
53+
* of the objective function.
54+
*
55+
* @param instance The instance of data point to be added.
56+
* @return This HingeAggregator object.
57+
*/
58+
def add(instance: Instance): this.type = {
59+
instance match { case Instance(label, weight, features) =>
60+
require(numFeatures == features.size, s"Dimensions mismatch when adding new instance." +
61+
s" Expecting $numFeatures but got ${features.size}.")
62+
require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")
63+
64+
if (weight == 0.0) return this
65+
val localFeaturesStd = bcFeaturesStd.value
66+
val localCoefficients = coefficientsArray
67+
val localGradientSumArray = gradientSumArray
68+
69+
val dotProduct = {
70+
var sum = 0.0
71+
features.foreachActive { (index, value) =>
72+
if (localFeaturesStd(index) != 0.0 && value != 0.0) {
73+
sum += localCoefficients(index) * value / localFeaturesStd(index)
74+
}
75+
}
76+
if (fitIntercept) sum += localCoefficients(numFeaturesPlusIntercept - 1)
77+
sum
78+
}
79+
// Our loss function with {0, 1} labels is max(0, 1 - (2y - 1) (f_w(x)))
80+
// Therefore the (sub)gradient is -(2y - 1)*x when 1 - (2y - 1) f_w(x) > 0, and 0 otherwise
81+
val labelScaled = 2 * label - 1.0
82+
val loss = if (1.0 > labelScaled * dotProduct) {
83+
(1.0 - labelScaled * dotProduct) * weight
84+
} else {
85+
0.0
86+
}
87+
88+
if (1.0 > labelScaled * dotProduct) {
89+
val gradientScale = -labelScaled * weight
90+
features.foreachActive { (index, value) =>
91+
if (localFeaturesStd(index) != 0.0 && value != 0.0) {
92+
localGradientSumArray(index) += value * gradientScale / localFeaturesStd(index)
93+
}
94+
}
95+
if (fitIntercept) {
96+
localGradientSumArray(localGradientSumArray.length - 1) += gradientScale
97+
}
98+
}
99+
100+
lossSum += loss
101+
weightSum += weight
102+
this
103+
}
104+
}
105+
}

mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ import org.apache.spark.SparkFunSuite
2525
import org.apache.spark.ml.classification.LinearSVCSuite._
2626
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
2727
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
28-
import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
28+
import org.apache.spark.ml.optim.aggregator.HingeAggregator
29+
import org.apache.spark.ml.param.ParamsSuite
2930
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
3031
import org.apache.spark.ml.util.TestingUtils._
3132
import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -170,10 +171,10 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
170171
assert(model2.intercept !== 0.0)
171172
}
172173

173-
test("sparse coefficients in SVCAggregator") {
174+
test("sparse coefficients in HingeAggregator") {
174175
val bcCoefficients = spark.sparkContext.broadcast(Vectors.sparse(2, Array(0), Array(1.0)))
175176
val bcFeaturesStd = spark.sparkContext.broadcast(Array(1.0))
176-
val agg = new LinearSVCAggregator(bcCoefficients, bcFeaturesStd, true)
177+
val agg = new HingeAggregator(bcFeaturesStd, true)(bcCoefficients)
177178
val thrown = withClue("LinearSVCAggregator cannot handle sparse coefficients") {
178179
intercept[IllegalArgumentException] {
179180
agg.add(Instance(1.0, 1.0, Vectors.dense(1.0)))

0 commit comments

Comments
 (0)