Skip to content
This repository was archived by the owner on Jan 9, 2020. It is now read-only.

Commit 3b5c2a8

Browse files
WeichenXu123 authored and srowen committed
[SPARK-21770][ML] ProbabilisticClassificationModel fix corner case: normalization of all-zero raw predictions
## What changes were proposed in this pull request?

Fix ProbabilisticClassificationModel corner case: normalization of all-zero raw predictions now throws an IllegalArgumentException with a description.

## How was this patch tested?

Test case added.

Author: WeichenXu <[email protected]>

Closes apache#19106 from WeichenXu123/SPARK-21770.
1 parent af8a34c commit 3b5c2a8

File tree

2 files changed

+29
-9
lines changed

2 files changed

+29
-9
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -230,21 +230,23 @@ private[ml] object ProbabilisticClassificationModel {
230230
* Normalize a vector of raw predictions to be a multinomial probability vector, in place.
231231
*
232232
* The input raw predictions should be nonnegative.
233-
* The output vector sums to 1, unless the input vector is all-0 (in which case the output is
234-
* all-0 too).
233+
* The output vector sums to 1.
235234
*
236235
* NOTE: This is NOT applicable to all models, only ones which effectively use class
237236
* instance counts for raw predictions.
237+
*
238+
* @throws IllegalArgumentException if the input vector is all-0 or including negative values
238239
*/
239240
def normalizeToProbabilitiesInPlace(v: DenseVector): Unit = {
241+
v.values.foreach(value => require(value >= 0,
242+
"The input raw predictions should be nonnegative."))
240243
val sum = v.values.sum
241-
if (sum != 0) {
242-
var i = 0
243-
val size = v.size
244-
while (i < size) {
245-
v.values(i) /= sum
246-
i += 1
247-
}
244+
require(sum > 0, "Can't normalize the 0-vector.")
245+
var i = 0
246+
val size = v.size
247+
while (i < size) {
248+
v.values(i) /= sum
249+
i += 1
248250
}
249251
}
250252
}

mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,24 @@ class ProbabilisticClassifierSuite extends SparkFunSuite {
8080
new TestProbabilisticClassificationModel("myuid", 2, 2).setThresholds(Array(-0.1, 0.1))
8181
}
8282
}
83+
84+
test("normalizeToProbabilitiesInPlace") {
85+
val vec1 = Vectors.dense(1.0, 2.0, 3.0).toDense
86+
ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(vec1)
87+
assert(vec1 ~== Vectors.dense(1.0 / 6, 2.0 / 6, 3.0 / 6) relTol 1e-3)
88+
89+
// all-0 input test
90+
val vec2 = Vectors.dense(0.0, 0.0, 0.0).toDense
91+
intercept[IllegalArgumentException] {
92+
ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(vec2)
93+
}
94+
95+
// negative input test
96+
val vec3 = Vectors.dense(1.0, -1.0, 2.0).toDense
97+
intercept[IllegalArgumentException] {
98+
ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(vec3)
99+
}
100+
}
83101
}
84102

85103
object ProbabilisticClassifierSuite {

0 commit comments

Comments (0)