Skip to content

Commit a8a5cd2

Browse files
zhengruifeng authored and srowen committed
[SPARK-22009][ML] Using treeAggregate improve some algs
## What changes were proposed in this pull request? I test on a dataset of about 13M instances, and found that using `treeAggregate` give a speedup in following algs: |Algs| SpeedUp | |------|-----------| |OneHotEncoder| 5% | |StatFunctions.calculateCov| 7% | |StatFunctions.multipleApproxQuantiles| 9% | |RegressionEvaluator| 8% | ## How was this patch tested? existing tests Author: Zheng RuiFeng <[email protected]> Closes apache#19232 from zhengruifeng/use_treeAggregate.
1 parent b21b806 commit a8a5cd2

File tree

4 files changed

+5
-5
lines changed

4 files changed

+5
-5
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ class OneHotEncoder @Since("1.4.0") (@Since("1.4.0") override val uid: String) e
142142
if (outputAttrGroup.size < 0) {
143143
// If the number of attributes is unknown, we check the values from the input column.
144144
val numAttrs = dataset.select(col(inputColName).cast(DoubleType)).rdd.map(_.getDouble(0))
145-
.aggregate(0.0)(
145+
.treeAggregate(0.0)(
146146
(m, x) => {
147147
assert(x <= Int.MaxValue,
148148
s"OneHotEncoder only supports up to ${Int.MaxValue} indices, but got $x")

mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ class RegressionMetrics @Since("2.0.0") (
5454
private lazy val summary: MultivariateStatisticalSummary = {
5555
val summary: MultivariateStatisticalSummary = predictionAndObservations.map {
5656
case (prediction, observation) => Vectors.dense(observation, observation - prediction)
57-
}.aggregate(new MultivariateOnlineSummarizer())(
57+
}.treeAggregate(new MultivariateOnlineSummarizer())(
5858
(summary, v) => summary.add(v),
5959
(sum1, sum2) => sum1.merge(sum2)
6060
)

sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ object FrequentItems extends Logging {
9595
(name, originalSchema.fields(index).dataType)
9696
}.toArray
9797

98-
val freqItems = df.select(cols.map(Column(_)) : _*).rdd.aggregate(countMaps)(
98+
val freqItems = df.select(cols.map(Column(_)) : _*).rdd.treeAggregate(countMaps)(
9999
seqOp = (counts, row) => {
100100
var i = 0
101101
while (i < numCols) {

sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ object StatFunctions extends Logging {
9999
sum2: Array[QuantileSummaries]): Array[QuantileSummaries] = {
100100
sum1.zip(sum2).map { case (s1, s2) => s1.compress().merge(s2.compress()) }
101101
}
102-
val summaries = df.select(columns: _*).rdd.aggregate(emptySummaries)(apply, merge)
102+
val summaries = df.select(columns: _*).rdd.treeAggregate(emptySummaries)(apply, merge)
103103

104104
summaries.map { summary => probabilities.flatMap(summary.query) }
105105
}
@@ -160,7 +160,7 @@ object StatFunctions extends Logging {
160160
s"for columns with dataType ${data.get.dataType} not supported.")
161161
}
162162
val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType)))
163-
df.select(columns: _*).queryExecution.toRdd.aggregate(new CovarianceCounter)(
163+
df.select(columns: _*).queryExecution.toRdd.treeAggregate(new CovarianceCounter)(
164164
seqOp = (counter, row) => {
165165
counter.add(row.getDouble(0), row.getDouble(1))
166166
},

0 commit comments

Comments (0)