Skip to content

Commit 0456b40

Browse files
WeichenXu123 authored and srowen committed
[SPARK-21818][ML][MLLIB] Fix MultivariateOnlineSummarizer.variance generating negative results
## What changes were proposed in this pull request? Because of numerical error, `MultivariateOnlineSummarizer.variance` can return a negative variance. **This is a serious bug because many algorithms in MLlib** **compute the standard deviation as** `sqrt(variance)`**, which then yields NaN and crashes the whole algorithm.** The bug can be reproduced with the following code: ``` val summarizer1 = (new MultivariateOnlineSummarizer) .add(Vectors.dense(3.0), 0.7) val summarizer2 = (new MultivariateOnlineSummarizer) .add(Vectors.dense(3.0), 0.4) val summarizer3 = (new MultivariateOnlineSummarizer) .add(Vectors.dense(3.0), 0.5) val summarizer4 = (new MultivariateOnlineSummarizer) .add(Vectors.dense(3.0), 0.4) val summarizer = summarizer1 .merge(summarizer2) .merge(summarizer3) .merge(summarizer4) println(summarizer.variance(0)) ``` This PR fixes the bug in `mllib.stat.MultivariateOnlineSummarizer.variance` and `ml.stat.SummarizerBuffer.variance`, and in several places in `WeightedLeastSquares`, by clamping the computed variance to be non-negative. ## How was this patch tested? Test cases were added. Author: WeichenXu <[email protected]> Closes apache#19029 from WeichenXu123/fix_summarizer_var_bug.
1 parent 07142cf commit 0456b40

File tree

5 files changed

+51
-7
lines changed

5 files changed

+51
-7
lines changed

mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,11 @@ private[ml] object WeightedLeastSquares {
440440
/**
441441
* Weighted population standard deviation of labels.
442442
*/
443-
def bStd: Double = math.sqrt(bbSum / wSum - bBar * bBar)
443+
def bStd: Double = {
444+
// We prevent variance from negative value caused by numerical error.
445+
val variance = math.max(bbSum / wSum - bBar * bBar, 0.0)
446+
math.sqrt(variance)
447+
}
444448

445449
/**
446450
* Weighted mean of (label * features).
@@ -471,7 +475,8 @@ private[ml] object WeightedLeastSquares {
471475
while (i < triK) {
472476
val l = j - 2
473477
val aw = aSum(l) / wSum
474-
std(l) = math.sqrt(aaValues(i) / wSum - aw * aw)
478+
// We prevent variance from negative value caused by numerical error.
479+
std(l) = math.sqrt(math.max(aaValues(i) / wSum - aw * aw, 0.0))
475480
i += j
476481
j += 1
477482
}
@@ -489,7 +494,8 @@ private[ml] object WeightedLeastSquares {
489494
while (i < triK) {
490495
val l = j - 2
491496
val aw = aSum(l) / wSum
492-
variance(l) = aaValues(i) / wSum - aw * aw
497+
// We prevent variance from negative value caused by numerical error.
498+
variance(l) = math.max(aaValues(i) / wSum - aw * aw, 0.0)
493499
i += j
494500
j += 1
495501
}

mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -436,8 +436,9 @@ private[ml] object SummaryBuilderImpl extends Logging {
436436
var i = 0
437437
val len = currM2n.length
438438
while (i < len) {
439-
realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
440-
(totalWeightSum - weightSum(i)) / totalWeightSum) / denominator
439+
// We prevent variance from negative value caused by numerical error.
440+
realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
441+
(totalWeightSum - weightSum(i)) / totalWeightSum) / denominator, 0.0)
441442
i += 1
442443
}
443444
}

mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
213213
var i = 0
214214
val len = currM2n.length
215215
while (i < len) {
216-
realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
217-
(totalWeightSum - weightSum(i)) / totalWeightSum) / denominator
216+
// We prevent variance from negative value caused by numerical error.
217+
realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
218+
(totalWeightSum - weightSum(i)) / totalWeightSum) / denominator, 0.0)
218219
i += 1
219220
}
220221
}

mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,24 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
402402
assert(summarizer.count === 6)
403403
}
404404

405+
test("summarizer buffer zero variance test (SPARK-21818)") {
406+
val summarizer1 = new SummarizerBuffer()
407+
.add(Vectors.dense(3.0), 0.7)
408+
val summarizer2 = new SummarizerBuffer()
409+
.add(Vectors.dense(3.0), 0.4)
410+
val summarizer3 = new SummarizerBuffer()
411+
.add(Vectors.dense(3.0), 0.5)
412+
val summarizer4 = new SummarizerBuffer()
413+
.add(Vectors.dense(3.0), 0.4)
414+
415+
val summarizer = summarizer1
416+
.merge(summarizer2)
417+
.merge(summarizer3)
418+
.merge(summarizer4)
419+
420+
assert(summarizer.variance(0) >= 0.0)
421+
}
422+
405423
test("summarizer buffer merging summarizer with empty summarizer") {
406424
// If one of two is non-empty, this should return the non-empty summarizer.
407425
// If both of them are empty, then just return the empty summarizer.

mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,4 +270,22 @@ class MultivariateOnlineSummarizerSuite extends SparkFunSuite {
270270
assert(summarizer3.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14)
271271
assert(summarizer3.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14)
272272
}
273+
274+
test ("test zero variance (SPARK-21818)") {
275+
val summarizer1 = (new MultivariateOnlineSummarizer)
276+
.add(Vectors.dense(3.0), 0.7)
277+
val summarizer2 = (new MultivariateOnlineSummarizer)
278+
.add(Vectors.dense(3.0), 0.4)
279+
val summarizer3 = (new MultivariateOnlineSummarizer)
280+
.add(Vectors.dense(3.0), 0.5)
281+
val summarizer4 = (new MultivariateOnlineSummarizer)
282+
.add(Vectors.dense(3.0), 0.4)
283+
284+
val summarizer = summarizer1
285+
.merge(summarizer2)
286+
.merge(summarizer3)
287+
.merge(summarizer4)
288+
289+
assert(summarizer.variance(0) >= 0.0)
290+
}
273291
}

0 commit comments

Comments
 (0)