Skip to content

Commit 6ad8d4c

Browse files
mgaido91 and srowen
authored and committed
[SPARK-25289][ML] Avoid exception in ChiSqSelector with FDR when no feature is selected
## What changes were proposed in this pull request? Currently, when FDR is used for `ChiSqSelector` and no feature is selected, an exception is thrown because the `max` operation fails on an empty collection. This PR fixes the problem by handling that case and returning an empty array, as sklearn (which was the reference for the initial implementation of FDR) does. ## How was this patch tested? Added a unit test. Closes apache#22303 from mgaido91/SPARK-25289. Authored-by: Marco Gaido <[email protected]> Signed-off-by: Sean Owen <[email protected]>
1 parent 7c36ee4 commit 6ad8d4c

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import org.apache.spark.annotation.Since
2828
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
2929
import org.apache.spark.mllib.regression.LabeledPoint
3030
import org.apache.spark.mllib.stat.Statistics
31+
import org.apache.spark.mllib.stat.test.ChiSqTestResult
3132
import org.apache.spark.mllib.util.{Loader, Saveable}
3233
import org.apache.spark.rdd.RDD
3334
import org.apache.spark.sql.{Row, SparkSession}
@@ -272,13 +273,16 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
272273
// https://en.wikipedia.org/wiki/False_discovery_rate#Benjamini.E2.80.93Hochberg_procedure
273274
val tempRes = chiSqTestResult
274275
.sortBy { case (res, _) => res.pValue }
275-
val maxIndex = tempRes
276+
val selected = tempRes
276277
.zipWithIndex
277278
.filter { case ((res, _), index) =>
278279
res.pValue <= fdr * (index + 1) / chiSqTestResult.length }
279-
.map { case (_, index) => index }
280-
.max
281-
tempRes.take(maxIndex + 1)
280+
if (selected.isEmpty) {
281+
Array.empty[(ChiSqTestResult, Int)]
282+
} else {
283+
val maxIndex = selected.map(_._2).max
284+
tempRes.take(maxIndex + 1)
285+
}
282286
case ChiSqSelector.FWE =>
283287
chiSqTestResult
284288
.filter { case (res, _) => res.pValue < fwe / chiSqTestResult.length }

mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,17 @@ class ChiSqSelectorSuite extends MLTest with DefaultReadWriteTest {
163163
}
164164
}
165165

166+
test("SPARK-25289: ChiSqSelector should not fail when selecting no features with FDR") {
167+
val labeledPoints = (0 to 1).map { n =>
168+
val v = Vectors.dense((1 to 3).map(_ => n * 1.0).toArray)
169+
(n.toDouble, v)
170+
}
171+
val inputDF = spark.createDataFrame(labeledPoints).toDF("label", "features")
172+
val selector = new ChiSqSelector().setSelectorType("fdr").setFdr(0.05)
173+
val model = selector.fit(inputDF)
174+
assert(model.selectedFeatures.isEmpty)
175+
}
176+
166177
private def testSelector(selector: ChiSqSelector, data: Dataset[_]): ChiSqSelectorModel = {
167178
val selectorModel = selector.fit(data)
168179
testTransformer[(Double, Vector, Vector)](data.toDF(), selectorModel,

0 commit comments

Comments
 (0)