Skip to content

Commit ec3f4c5

Browse files
Yogesh GargRobert Kruszewski
authored andcommitted
[SPARK-23870][ML] Forward RFormula handleInvalid Param to VectorAssembler to handle invalid values in non-string columns
## What changes were proposed in this pull request? `handleInvalid` Param was forwarded to the VectorAssembler used by RFormula. ## How was this patch tested? added a test and ran all tests for RFormula and VectorAssembler Author: Yogesh Garg <yogesh(dot)garg()databricks(dot)com> Closes apache#20970 from yogeshg/spark_23562.
1 parent 62a6983 commit ec3f4c5

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
278278
encoderStages += new VectorAssembler(uid)
279279
.setInputCols(encodedTerms.toArray)
280280
.setOutputCol($(featuresCol))
281+
.setHandleInvalid($(handleInvalid))
281282
encoderStages += new VectorAttributeRewriter($(featuresCol), prefixesToRewrite.toMap)
282283
encoderStages += new ColumnPruner(tempColumns.toSet)
283284

mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.ml.feature
1919

20+
import org.apache.spark.SparkException
2021
import org.apache.spark.ml.attribute._
2122
import org.apache.spark.ml.linalg.{Vector, Vectors}
2223
import org.apache.spark.ml.param.ParamsSuite
@@ -592,4 +593,26 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
592593
assert(features.toArray === a +: b.toArray)
593594
}
594595
}
596+
597+
test("SPARK-23562 RFormula handleInvalid should handle invalid values in non-string columns.") {
598+
val d1 = Seq(
599+
(1001L, "a"),
600+
(1002L, "b")).toDF("id1", "c1")
601+
val d2 = Seq[(java.lang.Long, String)](
602+
(20001L, "x"),
603+
(20002L, "y"),
604+
(null, null)).toDF("id2", "c2")
605+
val dataset = d1.crossJoin(d2)
606+
607+
def get_output(mode: String): DataFrame = {
608+
val formula = new RFormula().setFormula("c1 ~ id2").setHandleInvalid(mode)
609+
formula.fit(dataset).transform(dataset).select("features", "label")
610+
}
611+
612+
assert(intercept[SparkException](get_output("error").collect())
613+
.getMessage.contains("Encountered null while assembling a row"))
614+
assert(get_output("skip").count() == 4)
615+
assert(get_output("keep").count() == 6)
616+
}
617+
595618
}

0 commit comments

Comments
 (0)