17
17
18
18
package org .apache .spark .ml .feature
19
19
20
- import scala .collection .mutable .ArrayBuilder
20
+ import java .util .NoSuchElementException
21
+
22
+ import scala .collection .mutable
23
+ import scala .language .existentials
21
24
22
25
import org .apache .spark .SparkException
23
26
import org .apache .spark .annotation .Since
24
27
import org .apache .spark .ml .Transformer
25
28
import org .apache .spark .ml .attribute .{Attribute , AttributeGroup , NumericAttribute , UnresolvedAttribute }
26
29
import org .apache .spark .ml .linalg .{Vector , Vectors , VectorUDT }
27
- import org .apache .spark .ml .param .ParamMap
30
+ import org .apache .spark .ml .param .{ Param , ParamMap , ParamValidators }
28
31
import org .apache .spark .ml .param .shared ._
29
32
import org .apache .spark .ml .util ._
30
33
import org .apache .spark .sql .{DataFrame , Dataset , Row }
@@ -33,10 +36,14 @@ import org.apache.spark.sql.types._
33
36
34
37
/**
35
38
* A feature transformer that merges multiple columns into a vector column.
39
+ *
40
+ * This requires one pass over the entire dataset. In case we need to infer column lengths from the
41
+ * data we require an additional call to the 'first' Dataset method, see 'handleInvalid' parameter.
36
42
*/
37
43
@ Since (" 1.4.0" )
38
44
class VectorAssembler @ Since (" 1.4.0" ) (@ Since (" 1.4.0" ) override val uid : String )
39
- extends Transformer with HasInputCols with HasOutputCol with DefaultParamsWritable {
45
+ extends Transformer with HasInputCols with HasOutputCol with HasHandleInvalid
46
+ with DefaultParamsWritable {
40
47
41
48
/** Constructs a VectorAssembler with a freshly generated, unique uid. */
@Since("1.4.0")
def this() = this(Identifiable.randomUID("vecAssembler"))
@@ -49,32 +56,63 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
49
56
/** @group setParam */
@Since("1.4.0")
def setOutputCol(value: String): this.type = set(outputCol, value)
51
58
59
/** @group setParam */
@Since("2.4.0")
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
62
+
63
/**
 * Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with
 * invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the
 * output). Column lengths are taken from the size of ML Attribute Group, which can be set using
 * `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred
 * from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'.
 * Default: "error"
 * @group param
 */
@Since("2.4.0")
override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
  """Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with
    |invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the
    |output). Column lengths are taken from the size of ML Attribute Group, which can be set using
    |`VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred
    |from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'.
    |""".stripMargin.replaceAll("\n", " "),
  ParamValidators.inArray(VectorAssembler.supportedHandleInvalids))

setDefault(handleInvalid, VectorAssembler.ERROR_INVALID)
83
+
52
84
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
  transformSchema(dataset.schema, logging = true)
  // Schema transformation.
  val schema = dataset.schema

  // Only vector-typed input columns need an explicit length; scalar columns always
  // contribute exactly one feature each.
  val vectorCols = $(inputCols).filter { c =>
    schema(c).dataType match {
      case _: VectorUDT => true
      case _ => false
    }
  }
  // Lengths come from ML attribute metadata when available; otherwise they may be inferred
  // from the first row (disallowed for handleInvalid = "keep", see getLengths).
  val vectorColsLengths = VectorAssembler.getLengths(dataset, vectorCols, $(handleInvalid))

  // Build per-input-column feature attributes for the output metadata.
  val featureAttributesMap = $(inputCols).map { c =>
    val field = schema(c)
    field.dataType match {
      case DoubleType =>
        val attribute = Attribute.fromStructField(field)
        attribute match {
          case UnresolvedAttribute =>
            // No ML attribute on the input column: assume numeric.
            Seq(NumericAttribute.defaultAttr.withName(c))
          case _ =>
            Seq(attribute.withName(c))
        }
      case _: NumericType | BooleanType =>
        // Compatible scalar type: assume numeric.
        Seq(NumericAttribute.defaultAttr.withName(c))
      case _: VectorUDT =>
        val attributeGroup = AttributeGroup.fromStructField(field)
        if (attributeGroup.attributes.isDefined) {
          // Attributes are defined: copy them with prefixed names.
          attributeGroup.attributes.get.zipWithIndex.toSeq.map { case (attr, i) =>
            if (attr.name.isDefined) {
              // TODO: Define a rigorous naming scheme.
              attr.withName(c + "_" + attr.name.get)
            } else {
              attr.withName(c + "_" + i)
            }
          }
        } else {
          // Otherwise, treat all slots as numeric using the resolved column length.
          (0 until vectorColsLengths(c)).map { i =>
            NumericAttribute.defaultAttr.withName(c + "_" + i)
          }
        }
      case otherType =>
        throw new SparkException(s"VectorAssembler does not support the $otherType type")
    }
  }
  val featureAttributes = featureAttributesMap.flatten[Attribute].toArray
  // Per-input-column feature counts; the assemble closure uses these to emit NaN runs
  // for null cells when keepInvalid is set.
  val lengths = featureAttributesMap.map(a => a.length).toArray
  val metadata = new AttributeGroup($(outputCol), featureAttributes).toMetadata()
  val (filteredDataset, keepInvalid) = $(handleInvalid) match {
    case VectorAssembler.SKIP_INVALID => (dataset.na.drop($(inputCols)), false)
    case VectorAssembler.KEEP_INVALID => (dataset, true)
    case VectorAssembler.ERROR_INVALID => (dataset, false)
  }
  // Data transformation.
  val assembleFunc = udf { r: Row =>
    VectorAssembler.assemble(lengths, keepInvalid)(r.toSeq: _*)
  }.asNondeterministic()
  val args = $(inputCols).map { c =>
    schema(c).dataType match {
      case DoubleType => dataset(c)
      case _: VectorUDT => dataset(c)
      case _: NumericType | BooleanType => dataset(c).cast(DoubleType).as(s"${c}_double_$uid")
    }
  }

  filteredDataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata))
}
111
156
112
157
@ Since (" 1.4.0" )
@@ -136,34 +181,117 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
136
181
@ Since (" 1.6.0" )
137
182
object VectorAssembler extends DefaultParamsReadable [VectorAssembler ] {
138
183
184
// Supported values of the handleInvalid param.
private[feature] val SKIP_INVALID: String = "skip"
private[feature] val ERROR_INVALID: String = "error"
private[feature] val KEEP_INVALID: String = "keep"
private[feature] val supportedHandleInvalids: Array[String] =
  Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID)
189
+
190
/**
 * Infers lengths of vector columns from the first row of the dataset.
 * @param dataset the dataset
 * @param columns name of vector columns whose lengths need to be inferred
 * @return map of column names to lengths
 */
private[feature] def getVectorLengthsFromFirstRow(
    dataset: Dataset[_],
    columns: Seq[String]): Map[String, Int] = {
  try {
    val firstRow = dataset.toDF().select(columns.map(col): _*).first()
    // Pair each requested column with the size of the vector found in the first row.
    columns.zip(firstRow.toSeq).map {
      case (c, x) => c -> x.asInstanceOf[Vector].size
    }.toMap
  } catch {
    // A null vector cell makes the cast above blow up with an NPE.
    case e: NullPointerException => throw new NullPointerException(
      s"""Encountered null value while inferring lengths from the first row. Consider using
         |VectorSizeHint to add metadata for columns: ${columns.mkString("[", ", ", "]")}. """
        .stripMargin.replaceAll("\n", " ") + e.toString)
    // first() throws NoSuchElementException on an empty dataframe.
    case e: NoSuchElementException => throw new NoSuchElementException(
      s"""Encountered empty dataframe while inferring lengths from the first row. Consider using
         |VectorSizeHint to add metadata for columns: ${columns.mkString("[", ", ", "]")}. """
        .stripMargin.replaceAll("\n", " ") + e.toString)
  }
}
215
+
216
/**
 * Resolves the length of every given vector column, preferring ML attribute-group metadata
 * and falling back to first-row inference where metadata is absent (only legal for
 * handleInvalid = "error" or "skip").
 * @param dataset the dataset
 * @param columns names of vector columns whose lengths are needed
 * @param handleInvalid current value of the handleInvalid param
 * @return map of column names to lengths
 */
private[feature] def getLengths(
    dataset: Dataset[_],
    columns: Seq[String],
    handleInvalid: String): Map[String, Int] = {
  // Sizes recorded in ML attribute metadata; -1 marks an unknown size.
  val groupSizes = columns.map { c =>
    c -> AttributeGroup.fromStructField(dataset.schema(c)).size
  }.toMap
  val missingColumns = groupSizes.filter(_._2 == -1).keys.toSeq
  val firstSizes = (missingColumns.nonEmpty, handleInvalid) match {
    case (true, VectorAssembler.ERROR_INVALID) =>
      getVectorLengthsFromFirstRow(dataset, missingColumns)
    case (true, VectorAssembler.SKIP_INVALID) =>
      // Drop rows with nulls in the unresolved columns before peeking at the first row.
      getVectorLengthsFromFirstRow(dataset.na.drop(missingColumns), missingColumns)
    case (true, VectorAssembler.KEEP_INVALID) => throw new RuntimeException(
      s"""Can not infer column lengths with handleInvalid = "keep". Consider using VectorSizeHint
         |to add metadata for columns: ${columns.mkString("[", ", ", "]")}. """
        .stripMargin.replaceAll("\n", " "))
    case (_, _) => Map.empty
  }
  // Inferred sizes override the -1 placeholders.
  groupSizes ++ firstSizes
}
237
+
238
+
139
239
@Since("1.6.0")
override def load(path: String): VectorAssembler = super.load(path)
141
241
142
- private [feature] def assemble (vv : Any * ): Vector = {
143
- val indices = ArrayBuilder .make[Int ]
144
- val values = ArrayBuilder .make[Double ]
145
- var cur = 0
242
/**
 * Returns a function that has the required information to assemble each row.
 * @param lengths an array of lengths of input columns, whose size should be equal to the number
 *                of cells in the row (vv)
 * @param keepInvalid indicate whether to throw an error or not on seeing a null in the rows
 * @return a udf that can be applied on each row
 */
private[feature] def assemble(lengths: Array[Int], keepInvalid: Boolean)(vv: Any*): Vector = {
  val indices = mutable.ArrayBuilder.make[Int]
  val values = mutable.ArrayBuilder.make[Double]
  // Next free slot in the assembled feature vector.
  var featureIndex = 0

  // Index of the current input cell; only consulted to look up `lengths` for null cells.
  var inputColumnIndex = 0
  vv.foreach {
    case v: Double =>
      if (v.isNaN && !keepInvalid) {
        throw new SparkException(
          s"""Encountered NaN while assembling a row with handleInvalid = "error". Consider
             |removing NaNs from dataset or using handleInvalid = "keep" or "skip"."""
            .stripMargin)
      } else if (v != 0.0) {
        // Sparse representation: only non-zero entries are recorded.
        indices += featureIndex
        values += v
      }
      inputColumnIndex += 1
      featureIndex += 1
    case vec: Vector =>
      vec.foreachActive { case (i, v) =>
        if (v != 0.0) {
          indices += featureIndex + i
          values += v
        }
      }
      inputColumnIndex += 1
      featureIndex += vec.size
    case null =>
      if (keepInvalid) {
        // Fill the column's declared width with NaNs so row alignment is preserved.
        val length: Int = lengths(inputColumnIndex)
        Array.range(0, length).foreach { i =>
          indices += featureIndex + i
          values += Double.NaN
        }
        inputColumnIndex += 1
        featureIndex += length
      } else {
        // Fix: this branch runs when keepInvalid == false, i.e. handleInvalid = "error";
        // the message previously claimed "keep", which misdescribed the failing mode.
        throw new SparkException(
          s"""Encountered null while assembling a row with handleInvalid = "error". Consider
             |removing nulls from dataset or using handleInvalid = "keep" or "skip"."""
            .stripMargin)
      }
    case o =>
      throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.")
  }
  Vectors.sparse(featureIndex, indices.result(), values.result()).compressed
}
169
297
}
0 commit comments