Skip to content
This repository was archived by the owner on Jan 9, 2020. It is now read-only.

Commit 0fa5b7c

Browse files
zhengruifeng authored and yanboliang committed
[SPARK-21690][ML] one-pass imputer
## What changes were proposed in this pull request? parallelize the computation of all columns performance tests: |numColums| Mean(Old) | Median(Old) | Mean(RDD) | Median(RDD) | Mean(DF) | Median(DF) | |------|----------|------------|----------|------------|----------|------------| |1|0.0771394713|0.0658712813|0.080779802|0.048165981499999996|0.10525509870000001|0.0499620203| |10|0.7234340630999999|0.5954440414|0.0867935197|0.13263428659999998|0.09255724889999999|0.1573943635| |100|7.3756451568|6.2196631259|0.1911931552|0.8625376817000001|0.5557462431|1.7216837982000002| ## How was this patch tested? existing tests Author: Zheng RuiFeng <[email protected]> Closes apache#18902 from zhengruifeng/parallelize_imputer.
1 parent ca00cc7 commit 0fa5b7c

File tree

1 file changed

+41
-15
lines changed

1 file changed

+41
-15
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala

Lines changed: 41 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -133,23 +133,49 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String)
133133
override def fit(dataset: Dataset[_]): ImputerModel = {
134134
transformSchema(dataset.schema, logging = true)
135135
val spark = dataset.sparkSession
136-
import spark.implicits._
137-
val surrogates = $(inputCols).map { inputCol =>
138-
val ic = col(inputCol)
139-
val filtered = dataset.select(ic.cast(DoubleType))
140-
.filter(ic.isNotNull && ic =!= $(missingValue) && !ic.isNaN)
141-
if(filtered.take(1).length == 0) {
142-
throw new SparkException(s"surrogate cannot be computed. " +
143-
s"All the values in $inputCol are Null, Nan or missingValue(${$(missingValue)})")
144-
}
145-
val surrogate = $(strategy) match {
146-
case Imputer.mean => filtered.select(avg(inputCol)).as[Double].first()
147-
case Imputer.median => filtered.stat.approxQuantile(inputCol, Array(0.5), 0.001).head
148-
}
149-
surrogate
136+
137+
val cols = $(inputCols).map { inputCol =>
138+
when(col(inputCol).equalTo($(missingValue)), null)
139+
.when(col(inputCol).isNaN, null)
140+
.otherwise(col(inputCol))
141+
.cast("double")
142+
.as(inputCol)
143+
}
144+
145+
val results = $(strategy) match {
146+
case Imputer.mean =>
147+
// Function avg will ignore null automatically.
148+
// For a column only containing null, avg will return null.
149+
val row = dataset.select(cols.map(avg): _*).head()
150+
Array.range(0, $(inputCols).length).map { i =>
151+
if (row.isNullAt(i)) {
152+
Double.NaN
153+
} else {
154+
row.getDouble(i)
155+
}
156+
}
157+
158+
case Imputer.median =>
159+
// Function approxQuantile will ignore null automatically.
160+
// For a column only containing null, approxQuantile will return an empty array.
161+
dataset.select(cols: _*).stat.approxQuantile($(inputCols), Array(0.5), 0.001)
162+
.map { array =>
163+
if (array.isEmpty) {
164+
Double.NaN
165+
} else {
166+
array.head
167+
}
168+
}
169+
}
170+
171+
val emptyCols = $(inputCols).zip(results).filter(_._2.isNaN).map(_._1)
172+
if (emptyCols.nonEmpty) {
173+
throw new SparkException(s"surrogate cannot be computed. " +
174+
s"All the values in ${emptyCols.mkString(",")} are Null, Nan or " +
175+
s"missingValue(${$(missingValue)})")
150176
}
151177

152-
val rows = spark.sparkContext.parallelize(Seq(Row.fromSeq(surrogates)))
178+
val rows = spark.sparkContext.parallelize(Seq(Row.fromSeq(results)))
153179
val schema = StructType($(inputCols).map(col => StructField(col, DoubleType, nullable = false)))
154180
val surrogateDF = spark.createDataFrame(rows, schema)
155181
copyValues(new ImputerModel(uid, surrogateDF).setParent(this))

0 commit comments

Comments
 (0)