Skip to content

Commit b1857a4

Browse files
Venkata krishnan Sowrirajanmaropu
authored andcommitted
[SPARK-26894][SQL] Handle Alias as well in AggregateEstimation to propagate child stats
## What changes were proposed in this pull request? Currently aliases are not handled in the Aggregate Estimation due to which stats are not getting propagated. This causes CBO join-reordering to not give optimal join plans. ProjectEstimation is already taking care of aliases, we need same logic for AggregateEstimation as well to properly propagate stats when CBO is enabled. ## How was this patch tested? This patch is manually tested using the query Q83 of TPCDS benchmark (scale 1000) Closes apache#23803 from venkata91/aggstats. Authored-by: Venkata krishnan Sowrirajan <[email protected]> Signed-off-by: Takeshi Yamamuro <[email protected]>
1 parent c26379b commit b1857a4

File tree

4 files changed

+45
-10
lines changed

4 files changed

+45
-10
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
1919

20-
import org.apache.spark.sql.catalyst.expressions.Attribute
20+
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
2121
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Statistics}
2222

2323

@@ -52,7 +52,10 @@ object AggregateEstimation {
5252
outputRows.min(childStats.rowCount.get)
5353
}
5454

55-
val outputAttrStats = getOutputMap(childStats.attributeStats, agg.output)
55+
val aliasStats = EstimationUtils.getAliasStats(agg.expressions, childStats.attributeStats)
56+
57+
val outputAttrStats = getOutputMap(
58+
AttributeMap(childStats.attributeStats.toSeq ++ aliasStats), agg.output)
5659
Some(Statistics(
5760
sizeInBytes = getOutputSize(agg.output, outputRows, outputAttrStats),
5861
rowCount = Some(outputRows),

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
2020
import scala.collection.mutable.ArrayBuffer
2121
import scala.math.BigDecimal.RoundingMode
2222

23-
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
23+
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, Expression}
2424
import org.apache.spark.sql.catalyst.plans.logical._
2525
import org.apache.spark.sql.types.{DecimalType, _}
2626

@@ -71,6 +71,18 @@ object EstimationUtils {
7171
AttributeMap(output.flatMap(a => inputMap.get(a).map(a -> _)))
7272
}
7373

74+
/**
75+
* Returns the stats for aliases of child's attributes
76+
*/
77+
def getAliasStats(
78+
expressions: Seq[Expression],
79+
attributeStats: AttributeMap[ColumnStat]): Seq[(Attribute, ColumnStat)] = {
80+
expressions.collect {
81+
case alias @ Alias(attr: Attribute, _) if attributeStats.contains(attr) =>
82+
alias.toAttribute -> attributeStats(attr)
83+
}
84+
}
85+
7486
def getSizePerRow(
7587
attributes: Seq[Attribute],
7688
attrStats: AttributeMap[ColumnStat] = AttributeMap(Nil)): BigInt = {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,10 @@ object ProjectEstimation {
2626
def estimate(project: Project): Option[Statistics] = {
2727
if (rowCountsExist(project.child)) {
2828
val childStats = project.child.stats
29-
val inputAttrStats = childStats.attributeStats
30-
// Match alias with its child's column stat
31-
val aliasStats = project.expressions.collect {
32-
case alias @ Alias(attr: Attribute, _) if inputAttrStats.contains(attr) =>
33-
alias.toAttribute -> inputAttrStats(attr)
34-
}
29+
val aliasStats = EstimationUtils.getAliasStats(project.expressions, childStats.attributeStats)
30+
3531
val outputAttrStats =
36-
getOutputMap(AttributeMap(inputAttrStats.toSeq ++ aliasStats), project.output)
32+
getOutputMap(AttributeMap(childStats.attributeStats.toSeq ++ aliasStats), project.output)
3733
Some(childStats.copy(
3834
sizeInBytes = getOutputSize(project.output, childStats.rowCount.get, outputAttrStats),
3935
attributeStats = outputAttrStats))

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,30 @@ class AggregateEstimationSuite extends StatsEstimationTestBase with PlanTest {
4545
private val nameToColInfo: Map[String, (Attribute, ColumnStat)] =
4646
columnInfo.map(kv => kv._1.name -> kv)
4747

48+
test("SPARK-26894: propagate child stats for aliases in Aggregate") {
49+
val tableColumns = Seq("key11", "key12")
50+
val groupByColumns = Seq("key11")
51+
val attributes = groupByColumns.map(nameToAttr)
52+
53+
val rowCount = 2
54+
val child = StatsTestPlan(
55+
outputList = tableColumns.map(nameToAttr),
56+
rowCount,
57+
// rowCount * (overhead + column size)
58+
size = Some(4 * (8 + 4)),
59+
attributeStats = AttributeMap(tableColumns.map(nameToColInfo)))
60+
61+
val testAgg = Aggregate(
62+
groupingExpressions = attributes,
63+
aggregateExpressions = Seq(Alias(nameToAttr("key12"), "abc")()),
64+
child)
65+
66+
val expectedColStats = Seq("abc" -> nameToColInfo("key12")._2)
67+
val expectedAttrStats = toAttributeMap(expectedColStats, testAgg)
68+
69+
assert(testAgg.stats.attributeStats == expectedAttrStats)
70+
}
71+
4872
test("set an upper bound if the product of ndv's of group-by columns is too large") {
4973
// Suppose table1 (key11 int, key12 int) has 4 records: (1, 10), (1, 20), (2, 30), (2, 40)
5074
checkAggStats(

0 commit comments

Comments
 (0)