Skip to content

Commit f9927a5

Browse files
Yogesh Garg authored and Robert Kruszewski committed
[SPARK-23690][ML] Add handleinvalid to VectorAssembler
## What changes were proposed in this pull request? Introduce `handleInvalid` parameter in `VectorAssembler` that can take in `"keep", "skip", "error"` options. "error" throws an error on seeing a row containing a `null`, "skip" filters out all such rows, and "keep" adds relevant number of NaN. "keep" figures out an example to find out what this number of NaN s should be added and throws an error when no such number could be found. ## How was this patch tested? Unit tests are added to check the behavior of `assemble` on specific rows and the transformer is called on `DataFrame`s of different configurations to test different corner cases. Author: Yogesh Garg <yogesh(dot)garg()databricks(dot)com> Author: Bago Amirbekian <[email protected]> Author: Yogesh Garg <[email protected]> Closes apache#20829 from yogeshg/rformula_handleinvalid.
1 parent 83c1da3 commit f9927a5

File tree

3 files changed

+284
-47
lines changed

3 files changed

+284
-47
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ class StringIndexerModel (
234234
val metadata = NominalAttribute.defaultAttr
235235
.withName($(outputCol)).withValues(filteredLabels).toMetadata()
236236
// If we are skipping invalid records, filter them out.
237-
val (filteredDataset, keepInvalid) = getHandleInvalid match {
237+
val (filteredDataset, keepInvalid) = $(handleInvalid) match {
238238
case StringIndexer.SKIP_INVALID =>
239239
val filterer = udf { label: String =>
240240
labelToIndex.contains(label)

mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala

Lines changed: 163 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,17 @@
1717

1818
package org.apache.spark.ml.feature
1919

20-
import scala.collection.mutable.ArrayBuilder
20+
import java.util.NoSuchElementException
21+
22+
import scala.collection.mutable
23+
import scala.language.existentials
2124

2225
import org.apache.spark.SparkException
2326
import org.apache.spark.annotation.Since
2427
import org.apache.spark.ml.Transformer
2528
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute}
2629
import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
27-
import org.apache.spark.ml.param.ParamMap
30+
import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
2831
import org.apache.spark.ml.param.shared._
2932
import org.apache.spark.ml.util._
3033
import org.apache.spark.sql.{DataFrame, Dataset, Row}
@@ -33,10 +36,14 @@ import org.apache.spark.sql.types._
3336

3437
/**
3538
* A feature transformer that merges multiple columns into a vector column.
39+
*
40+
* This requires one pass over the entire dataset. In case we need to infer column lengths from the
41+
* data we require an additional call to the 'first' Dataset method, see 'handleInvalid' parameter.
3642
*/
3743
@Since("1.4.0")
3844
class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
39-
extends Transformer with HasInputCols with HasOutputCol with DefaultParamsWritable {
45+
extends Transformer with HasInputCols with HasOutputCol with HasHandleInvalid
46+
with DefaultParamsWritable {
4047

4148
@Since("1.4.0")
4249
def this() = this(Identifiable.randomUID("vecAssembler"))
@@ -49,32 +56,63 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
4956
@Since("1.4.0")
5057
def setOutputCol(value: String): this.type = set(outputCol, value)
5158

59+
/** @group setParam */
@Since("2.4.0")
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)

/**
 * Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with
 * invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the
 * output). Column lengths are taken from the size of ML Attribute Group, which can be set using
 * `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred
 * from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'.
 * Default: "error"
 * @group param
 */
@Since("2.4.0")
override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
  """Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with
    |invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the
    |output). Column lengths are taken from the size of ML Attribute Group, which can be set using
    |`VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred
    |from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'.
    |""".stripMargin.replaceAll("\n", " "),
  ParamValidators.inArray(VectorAssembler.supportedHandleInvalids))

// "error" preserves VectorAssembler's historical behavior, so it remains the default.
setDefault(handleInvalid, VectorAssembler.ERROR_INVALID)
83+
5284
@Since("2.0.0")
5385
override def transform(dataset: Dataset[_]): DataFrame = {
5486
transformSchema(dataset.schema, logging = true)
5587
// Schema transformation.
5688
val schema = dataset.schema
57-
lazy val first = dataset.toDF.first()
58-
val attrs = $(inputCols).flatMap { c =>
89+
90+
val vectorCols = $(inputCols).filter { c =>
91+
schema(c).dataType match {
92+
case _: VectorUDT => true
93+
case _ => false
94+
}
95+
}
96+
val vectorColsLengths = VectorAssembler.getLengths(dataset, vectorCols, $(handleInvalid))
97+
98+
val featureAttributesMap = $(inputCols).map { c =>
5999
val field = schema(c)
60-
val index = schema.fieldIndex(c)
61100
field.dataType match {
62101
case DoubleType =>
63-
val attr = Attribute.fromStructField(field)
64-
// If the input column doesn't have ML attribute, assume numeric.
65-
if (attr == UnresolvedAttribute) {
66-
Some(NumericAttribute.defaultAttr.withName(c))
67-
} else {
68-
Some(attr.withName(c))
102+
val attribute = Attribute.fromStructField(field)
103+
attribute match {
104+
case UnresolvedAttribute =>
105+
Seq(NumericAttribute.defaultAttr.withName(c))
106+
case _ =>
107+
Seq(attribute.withName(c))
69108
}
70109
case _: NumericType | BooleanType =>
71110
// If the input column type is a compatible scalar type, assume numeric.
72-
Some(NumericAttribute.defaultAttr.withName(c))
111+
Seq(NumericAttribute.defaultAttr.withName(c))
73112
case _: VectorUDT =>
74-
val group = AttributeGroup.fromStructField(field)
75-
if (group.attributes.isDefined) {
76-
// If attributes are defined, copy them with updated names.
77-
group.attributes.get.zipWithIndex.map { case (attr, i) =>
113+
val attributeGroup = AttributeGroup.fromStructField(field)
114+
if (attributeGroup.attributes.isDefined) {
115+
attributeGroup.attributes.get.zipWithIndex.toSeq.map { case (attr, i) =>
78116
if (attr.name.isDefined) {
79117
// TODO: Define a rigorous naming scheme.
80118
attr.withName(c + "_" + attr.name.get)
@@ -85,18 +123,25 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
85123
} else {
86124
// Otherwise, treat all attributes as numeric. If we cannot get the number of attributes
87125
// from metadata, check the first row.
88-
val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size)
89-
Array.tabulate(numAttrs)(i => NumericAttribute.defaultAttr.withName(c + "_" + i))
126+
(0 until vectorColsLengths(c)).map { i =>
127+
NumericAttribute.defaultAttr.withName(c + "_" + i)
128+
}
90129
}
91130
case otherType =>
92131
throw new SparkException(s"VectorAssembler does not support the $otherType type")
93132
}
94133
}
95-
val metadata = new AttributeGroup($(outputCol), attrs).toMetadata()
96-
134+
val featureAttributes = featureAttributesMap.flatten[Attribute].toArray
135+
val lengths = featureAttributesMap.map(a => a.length).toArray
136+
val metadata = new AttributeGroup($(outputCol), featureAttributes).toMetadata()
137+
val (filteredDataset, keepInvalid) = $(handleInvalid) match {
138+
case VectorAssembler.SKIP_INVALID => (dataset.na.drop($(inputCols)), false)
139+
case VectorAssembler.KEEP_INVALID => (dataset, true)
140+
case VectorAssembler.ERROR_INVALID => (dataset, false)
141+
}
97142
// Data transformation.
98143
val assembleFunc = udf { r: Row =>
99-
VectorAssembler.assemble(r.toSeq: _*)
144+
VectorAssembler.assemble(lengths, keepInvalid)(r.toSeq: _*)
100145
}.asNondeterministic()
101146
val args = $(inputCols).map { c =>
102147
schema(c).dataType match {
@@ -106,7 +151,7 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
106151
}
107152
}
108153

109-
dataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata))
154+
filteredDataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata))
110155
}
111156

112157
@Since("1.4.0")
@@ -136,34 +181,117 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
136181
@Since("1.6.0")
137182
object VectorAssembler extends DefaultParamsReadable[VectorAssembler] {
138183

184+
// Strategies for handling rows that contain null values; see the `handleInvalid` param.
private[feature] val SKIP_INVALID: String = "skip"   // filter out rows containing nulls
private[feature] val ERROR_INVALID: String = "error" // throw on nulls (the default)
private[feature] val KEEP_INVALID: String = "keep"   // emit NaN entries for null values
// All accepted values of `handleInvalid`, used by its ParamValidators.inArray validator.
private[feature] val supportedHandleInvalids: Array[String] =
  Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID)
189+
190+
/**
 * Infers lengths of vector columns from the first row of the dataset.
 *
 * Note: this triggers a Spark job (`first()`), so it is only called when attribute-group
 * metadata does not already provide the sizes.
 *
 * @param dataset the dataset
 * @param columns names of vector columns whose lengths need to be inferred
 * @return map of column names to lengths
 * @throws NullPointerException if one of the columns is null in the first row
 * @throws NoSuchElementException if the dataset is empty
 */
private[feature] def getVectorLengthsFromFirstRow(
    dataset: Dataset[_],
    columns: Seq[String]): Map[String, Int] = {
  try {
    // Select only the needed columns before collecting a single row to the driver.
    val firstRow = dataset.toDF().select(columns.map(col): _*).first()
    columns.zip(firstRow.toSeq).map {
      case (c, x) => c -> x.asInstanceOf[Vector].size
    }.toMap
  } catch {
    // Rethrow with actionable guidance: the user can avoid this pass entirely by
    // attaching size metadata with VectorSizeHint.
    case e: NullPointerException => throw new NullPointerException(
      s"""Encountered null value while inferring lengths from the first row. Consider using
         |VectorSizeHint to add metadata for columns: ${columns.mkString("[", ", ", "]")}. """
        .stripMargin.replaceAll("\n", " ") + e.toString)
    case e: NoSuchElementException => throw new NoSuchElementException(
      s"""Encountered empty dataframe while inferring lengths from the first row. Consider using
         |VectorSizeHint to add metadata for columns: ${columns.mkString("[", ", ", "]")}. """
        .stripMargin.replaceAll("\n", " ") + e.toString)
  }
}
215+
216+
/**
 * Returns the lengths of the given vector columns, taking sizes from ML attribute-group
 * metadata where available and inferring the rest from the first row of the dataset.
 *
 * @param dataset the dataset
 * @param columns names of the vector columns whose lengths are needed
 * @param handleInvalid one of "error", "skip", "keep"; inference from the first row is only
 *                      safe for "error" and "skip", so "keep" with missing metadata fails fast
 * @return map of column names to lengths
 */
private[feature] def getLengths(
    dataset: Dataset[_],
    columns: Seq[String],
    handleInvalid: String): Map[String, Int] = {
  // Sizes recorded in metadata; AttributeGroup.size is -1 when unknown.
  val groupSizes = columns.map { c =>
    c -> AttributeGroup.fromStructField(dataset.schema(c)).size
  }.toMap
  val missingColumns = groupSizes.filter(_._2 == -1).keys.toSeq
  val firstSizes = (missingColumns.nonEmpty, handleInvalid) match {
    case (true, VectorAssembler.ERROR_INVALID) =>
      getVectorLengthsFromFirstRow(dataset, missingColumns)
    case (true, VectorAssembler.SKIP_INVALID) =>
      // Drop rows with nulls in the unsized columns first so `first()` sees a valid row.
      getVectorLengthsFromFirstRow(dataset.na.drop(missingColumns), missingColumns)
    case (true, VectorAssembler.KEEP_INVALID) => throw new RuntimeException(
      // With "keep" a null row may legitimately appear first, so inference is unsafe.
      // Point the user at the specific columns that actually lack size metadata.
      s"""Can not infer column lengths with handleInvalid = "keep". Consider using VectorSizeHint
         |to add metadata for columns: ${missingColumns.mkString("[", ", ", "]")}."""
        .stripMargin.replaceAll("\n", " "))
    case (_, _) => Map.empty
  }
  // Inferred sizes fill in (and override the -1 placeholders of) the metadata sizes.
  groupSizes ++ firstSizes
}
237+
238+
139239
@Since("1.6.0")
// Loads a previously persisted VectorAssembler; delegates to DefaultParamsReadable.
override def load(path: String): VectorAssembler = super.load(path)
141241

/**
 * Returns a function that has the required information to assemble each row.
 *
 * @param lengths an array of lengths of input columns, whose size should be equal to the number
 *                of cells in the row (vv); only consulted for null cells when keepInvalid is true
 * @param keepInvalid when true, null cells become runs of NaN and NaN values are passed through;
 *                    when false (handleInvalid = "error"), nulls and NaNs raise SparkException
 * @return a function that can be applied to each row to produce the assembled sparse vector
 */
private[feature] def assemble(lengths: Array[Int], keepInvalid: Boolean)(vv: Any*): Vector = {
  val indices = mutable.ArrayBuilder.make[Int]
  val values = mutable.ArrayBuilder.make[Double]
  // Position in the output feature vector.
  var featureIndex = 0
  // Position in the input row; indexes into `lengths` for null cells.
  var inputColumnIndex = 0
  vv.foreach {
    case v: Double =>
      if (v.isNaN && !keepInvalid) {
        throw new SparkException(
          s"""Encountered NaN while assembling a row with handleInvalid = "error". Consider
             |removing NaNs from dataset or using handleInvalid = "keep" or "skip"."""
            .stripMargin)
      } else if (v != 0.0) {
        // Only store non-zeros; the result is built as a sparse vector.
        indices += featureIndex
        values += v
      }
      inputColumnIndex += 1
      featureIndex += 1
    case vec: Vector =>
      vec.foreachActive { case (i, v) =>
        if (v != 0.0) {
          indices += featureIndex + i
          values += v
        }
      }
      inputColumnIndex += 1
      featureIndex += vec.size
    case null =>
      if (keepInvalid) {
        // Substitute a run of NaNs matching the declared length of this input column.
        val length: Int = lengths(inputColumnIndex)
        Array.range(0, length).foreach { i =>
          indices += featureIndex + i
          values += Double.NaN
        }
        inputColumnIndex += 1
        featureIndex += length
      } else {
        // This branch is reached only when handleInvalid = "error" (keepInvalid is false),
        // so the message must say "error", not "keep".
        throw new SparkException(
          s"""Encountered null while assembling a row with handleInvalid = "error". Consider
             |removing nulls from dataset or using handleInvalid = "keep" or "skip"."""
            .stripMargin)
      }
    case o =>
      throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.")
  }
  Vectors.sparse(featureIndex, indices.result(), values.result()).compressed
}
169297
}

0 commit comments

Comments
 (0)