Skip to content

Commit 65a4bc1

Browse files
dilipbiswal authored and
gatorsmile committed
[SPARK-21274][SQL] Implement INTERSECT ALL clause
## What changes were proposed in this pull request? Implements INTERSECT ALL clause through query rewrites using existing operators in Spark. Please refer to [Link](https://drive.google.com/open?id=1nyW0T0b_ajUduQoPgZLAsyHK8s3_dko3ulQuxaLpUXE) for the design. Input Query ``` SQL SELECT c1 FROM ut1 INTERSECT ALL SELECT c1 FROM ut2 ``` Rewritten Query ```SQL SELECT c1 FROM ( SELECT replicate_row(min_count, c1) FROM ( SELECT c1, IF (vcol1_cnt > vcol2_cnt, vcol2_cnt, vcol1_cnt) AS min_count FROM ( SELECT c1, count(vcol1) as vcol1_cnt, count(vcol2) as vcol2_cnt FROM ( SELECT c1, true as vcol1, null as vcol2 FROM ut1 UNION ALL SELECT c1, null as vcol1, true as vcol2 FROM ut2 ) AS union_all GROUP BY c1 HAVING vcol1_cnt >= 1 AND vcol2_cnt >= 1 ) ) ) ``` ## How was this patch tested? Added test cases in SQLQueryTestSuite, DataFrameSuite, SetOperationSuite Author: Dilip Biswal <[email protected]> Closes apache#21886 from dilipbiswal/dkb_intersect_all_final.
1 parent 6690924 commit 65a4bc1

File tree

15 files changed

+599
-12
lines changed

15 files changed

+599
-12
lines changed

python/pyspark/sql/dataframe.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1500,6 +1500,28 @@ def intersect(self, other):
15001500
"""
15011501
return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx)
15021502

1503+
@since(2.4)
1504+
def intersectAll(self, other):
1505+
""" Return a new :class:`DataFrame` containing rows in both this dataframe and other
1506+
dataframe while preserving duplicates.
1507+
1508+
This is equivalent to `INTERSECT ALL` in SQL.
1509+
>>> df1 = spark.createDataFrame([("a", 1), ("a", 1), ("b", 3), ("c", 4)], ["C1", "C2"])
1510+
>>> df2 = spark.createDataFrame([("a", 1), ("a", 1), ("b", 3)], ["C1", "C2"])
1511+
1512+
>>> df1.intersectAll(df2).sort("C1", "C2").show()
1513+
+---+---+
1514+
| C1| C2|
1515+
+---+---+
1516+
| a| 1|
1517+
| a| 1|
1518+
| b| 3|
1519+
+---+---+
1520+
1521+
Also as standard in SQL, this function resolves columns by position (not by name).
1522+
"""
1523+
return DataFrame(self._jdf.intersectAll(other._jdf), self.sql_ctx)
1524+
15031525
@since(1.3)
15041526
def subtract(self, other):
15051527
""" Return a new :class:`DataFrame` containing rows in this frame

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -914,7 +914,7 @@ class Analyzer(
914914
// To resolve duplicate expression IDs for Join and Intersect
915915
case j @ Join(left, right, _, _) if !j.duplicateResolved =>
916916
j.copy(right = dedupRight(left, right))
917-
case i @ Intersect(left, right) if !i.duplicateResolved =>
917+
case i @ Intersect(left, right, _) if !i.duplicateResolved =>
918918
i.copy(right = dedupRight(left, right))
919919
case e @ Except(left, right, _) if !e.duplicateResolved =>
920920
e.copy(right = dedupRight(left, right))

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -325,11 +325,11 @@ object TypeCoercion {
325325
assert(newChildren.length == 2)
326326
Except(newChildren.head, newChildren.last, isAll)
327327

328-
case s @ Intersect(left, right) if s.childrenResolved &&
328+
case s @ Intersect(left, right, isAll) if s.childrenResolved &&
329329
left.output.length == right.output.length && !s.resolved =>
330330
val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
331331
assert(newChildren.length == 2)
332-
Intersect(newChildren.head, newChildren.last)
332+
Intersect(newChildren.head, newChildren.last, isAll)
333333

334334
case s: Union if s.childrenResolved &&
335335
s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ object UnsupportedOperationChecker {
309309
case Except(left, right, _) if right.isStreaming =>
310310
throwError("Except on a streaming DataFrame/Dataset on the right is not supported")
311311

312-
case Intersect(left, right) if left.isStreaming && right.isStreaming =>
312+
case Intersect(left, right, _) if left.isStreaming && right.isStreaming =>
313313
throwError("Intersect between two streaming DataFrames/Datasets is not supported")
314314

315315
case GroupingSets(_, _, child, _) if child.isStreaming =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
136136
OptimizeSubqueries) ::
137137
Batch("Replace Operators", fixedPoint,
138138
RewriteExcepAll,
139+
RewriteIntersectAll,
139140
ReplaceIntersectWithSemiJoin,
140141
ReplaceExceptWithFilter,
141142
ReplaceExceptWithAntiJoin,
@@ -1402,7 +1403,7 @@ object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] {
14021403
*/
14031404
object ReplaceIntersectWithSemiJoin extends Rule[LogicalPlan] {
14041405
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
1405-
case Intersect(left, right) =>
1406+
case Intersect(left, right, false) =>
14061407
assert(left.output.size == right.output.size)
14071408
val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) }
14081409
Distinct(Join(left, right, LeftSemi, joinCond.reduceLeftOption(And)))
@@ -1488,6 +1489,84 @@ object RewriteExcepAll extends Rule[LogicalPlan] {
14881489
}
14891490
}
14901491

1492+
/**
1493+
* Replaces a logical [[Intersect]] operator that has `isAll = true` with a combination of
1494+
* Union, Aggregate and Generate operators.
1495+
*
1496+
* Input Query :
1497+
* {{{
1498+
* SELECT c1 FROM ut1 INTERSECT ALL SELECT c1 FROM ut2
1499+
* }}}
1500+
*
1501+
* Rewritten Query:
1502+
* {{{
1503+
* SELECT c1
1504+
* FROM (
1505+
* SELECT replicate_row(min_count, c1)
1506+
* FROM (
1507+
* SELECT c1, If (vcol1_cnt > vcol2_cnt, vcol2_cnt, vcol1_cnt) AS min_count
1508+
* FROM (
1509+
* SELECT c1, count(vcol1) as vcol1_cnt, count(vcol2) as vcol2_cnt
1510+
* FROM (
1511+
* SELECT true as vcol1, null as vcol2, c1 FROM ut1
1512+
* UNION ALL
1513+
* SELECT null as vcol1, true as vcol2, c1 FROM ut2
1514+
* ) AS union_all
1515+
* GROUP BY c1
1516+
* HAVING vcol1_cnt >= 1 AND vcol2_cnt >= 1
1517+
* )
1518+
* )
1519+
* )
1520+
* }}}
1521+
*/
1522+
object RewriteIntersectAll extends Rule[LogicalPlan] {
1523+
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
1524+
case Intersect(left, right, true) =>
1525+
assert(left.output.size == right.output.size)
1526+
1527+
val trueVcol1 = Alias(Literal(true), "vcol1")()
1528+
val nullVcol1 = Alias(Literal(null, BooleanType), "vcol1")()
1529+
1530+
val trueVcol2 = Alias(Literal(true), "vcol2")()
1531+
val nullVcol2 = Alias(Literal(null, BooleanType), "vcol2")()
1532+
1533+
// Add a projection on top of the left and right plans that prepends the
1534+
// additional virtual columns (vcol1, vcol2) to each child's output.
1535+
val leftPlanWithAddedVirtualCols = Project(Seq(trueVcol1, nullVcol2) ++ left.output, left)
1536+
val rightPlanWithAddedVirtualCols = Project(Seq(nullVcol1, trueVcol2) ++ right.output, right)
1537+
1538+
val unionPlan = Union(leftPlanWithAddedVirtualCols, rightPlanWithAddedVirtualCols)
1539+
1540+
// Expressions to compute count and minimum of both the counts.
1541+
val vCol1AggrExpr =
1542+
Alias(AggregateExpression(Count(unionPlan.output(0)), Complete, false), "vcol1_count")()
1543+
val vCol2AggrExpr =
1544+
Alias(AggregateExpression(Count(unionPlan.output(1)), Complete, false), "vcol2_count")()
1545+
val ifExpression = Alias(If(
1546+
GreaterThan(vCol1AggrExpr.toAttribute, vCol2AggrExpr.toAttribute),
1547+
vCol2AggrExpr.toAttribute,
1548+
vCol1AggrExpr.toAttribute
1549+
), "min_count")()
1550+
1551+
val aggregatePlan = Aggregate(left.output,
1552+
Seq(vCol1AggrExpr, vCol2AggrExpr) ++ left.output, unionPlan)
1553+
val filterPlan = Filter(And(GreaterThanOrEqual(vCol1AggrExpr.toAttribute, Literal(1L)),
1554+
GreaterThanOrEqual(vCol2AggrExpr.toAttribute, Literal(1L))), aggregatePlan)
1555+
val projectMinPlan = Project(left.output ++ Seq(ifExpression), filterPlan)
1556+
1557+
// Apply the replicator to replicate rows based on min_count
1558+
val genRowPlan = Generate(
1559+
ReplicateRows(Seq(ifExpression.toAttribute) ++ left.output),
1560+
unrequiredChildIndex = Nil,
1561+
outer = false,
1562+
qualifier = None,
1563+
left.output,
1564+
projectMinPlan
1565+
)
1566+
Project(left.output, genRowPlan)
1567+
}
1568+
}
1569+
14911570
/**
14921571
* Removes literals from group expressions in [[Aggregate]], as they have no effect to the result
14931572
* but only makes the grouping key bigger.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -533,7 +533,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
533533
case SqlBaseParser.UNION =>
534534
Distinct(Union(left, right))
535535
case SqlBaseParser.INTERSECT if all =>
536-
throw new ParseException("INTERSECT ALL is not supported.", ctx)
536+
Intersect(left, right, isAll = true)
537537
case SqlBaseParser.INTERSECT =>
538538
Intersect(left, right)
539539
case SqlBaseParser.EXCEPT if all =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,12 @@ object SetOperation {
164164
def unapply(p: SetOperation): Option[(LogicalPlan, LogicalPlan)] = Some((p.left, p.right))
165165
}
166166

167-
case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
167+
case class Intersect(
168+
left: LogicalPlan,
169+
right: LogicalPlan,
170+
isAll: Boolean = false) extends SetOperation(left, right) {
171+
172+
override def nodeName: String = getClass.getSimpleName + ( if ( isAll ) "All" else "" )
168173

169174
override def output: Seq[Attribute] =
170175
left.output.zip(right.output).map { case (leftAttr, rightAttr) =>

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,11 @@ package org.apache.spark.sql.catalyst.optimizer
2020
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
2121
import org.apache.spark.sql.catalyst.dsl.expressions._
2222
import org.apache.spark.sql.catalyst.dsl.plans._
23-
import org.apache.spark.sql.catalyst.expressions.{Alias, GreaterThan, Literal, ReplicateRows}
23+
import org.apache.spark.sql.catalyst.expressions.{Alias, And, GreaterThan, GreaterThanOrEqual, If, Literal, ReplicateRows}
2424
import org.apache.spark.sql.catalyst.plans.PlanTest
2525
import org.apache.spark.sql.catalyst.plans.logical._
2626
import org.apache.spark.sql.catalyst.rules._
27+
import org.apache.spark.sql.types.BooleanType
2728

2829
class SetOperationSuite extends PlanTest {
2930
object Optimize extends RuleExecutor[LogicalPlan] {
@@ -166,4 +167,33 @@ class SetOperationSuite extends PlanTest {
166167
))
167168
comparePlans(expectedPlan, rewrittenPlan)
168169
}
170+
171+
test("INTERSECT ALL rewrite") {
172+
val input = Intersect(testRelation, testRelation2, isAll = true)
173+
val rewrittenPlan = RewriteIntersectAll(input)
174+
val leftRelation = testRelation
175+
.select(Literal(true).as("vcol1"), Literal(null, BooleanType).as("vcol2"), 'a, 'b, 'c)
176+
val rightRelation = testRelation2
177+
.select(Literal(null, BooleanType).as("vcol1"), Literal(true).as("vcol2"), 'd, 'e, 'f)
178+
val planFragment = leftRelation.union(rightRelation)
179+
.groupBy('a, 'b, 'c)(count('vcol1).as("vcol1_count"),
180+
count('vcol2).as("vcol2_count"), 'a, 'b, 'c)
181+
.where(And(GreaterThanOrEqual('vcol1_count, Literal(1L)),
182+
GreaterThanOrEqual('vcol2_count, Literal(1L))))
183+
.select('a, 'b, 'c,
184+
If(GreaterThan('vcol1_count, 'vcol2_count), 'vcol2_count, 'vcol1_count).as("min_count"))
185+
.analyze
186+
val multiplerAttr = planFragment.output.last
187+
val output = planFragment.output.dropRight(1)
188+
val expectedPlan = Project(output,
189+
Generate(
190+
ReplicateRows(Seq(multiplerAttr) ++ output),
191+
Nil,
192+
false,
193+
None,
194+
output,
195+
planFragment
196+
))
197+
comparePlans(expectedPlan, rewrittenPlan)
198+
}
169199
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@ class PlanParserSuite extends AnalysisTest {
7070
intercept("select * from a minus all select * from b", "MINUS ALL is not supported.")
7171
assertEqual("select * from a minus distinct select * from b", a.except(b))
7272
assertEqual("select * from a intersect select * from b", a.intersect(b))
73-
intercept("select * from a intersect all select * from b", "INTERSECT ALL is not supported.")
7473
assertEqual("select * from a intersect distinct select * from b", a.intersect(b))
7574
}
7675

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1934,6 +1934,23 @@ class Dataset[T] private[sql](
19341934
Intersect(planWithBarrier, other.planWithBarrier)
19351935
}
19361936

1937+
/**
1938+
* Returns a new Dataset containing rows only in both this Dataset and another Dataset while
1939+
* preserving the duplicates.
1940+
* This is equivalent to `INTERSECT ALL` in SQL.
1941+
*
1942+
* @note Equality checking is performed directly on the encoded representation of the data
1943+
* and thus is not affected by a custom `equals` function defined on `T`. Also as standard
1944+
* in SQL, this function resolves columns by position (not by name).
1945+
*
1946+
* @group typedrel
1947+
* @since 2.4.0
1948+
*/
1949+
def intersectAll(other: Dataset[T]): Dataset[T] = withSetOperator {
1950+
Intersect(logicalPlan, other.logicalPlan, isAll = true)
1951+
}
1952+
1953+
19371954
/**
19381955
* Returns a new Dataset containing rows in this Dataset but not in another Dataset.
19391956
* This is equivalent to `EXCEPT DISTINCT` in SQL.
@@ -1961,7 +1978,7 @@ class Dataset[T] private[sql](
19611978
* @since 2.4.0
19621979
*/
19631980
def exceptAll(other: Dataset[T]): Dataset[T] = withSetOperator {
1964-
Except(planWithBarrier, other.planWithBarrier, isAll = true)
1981+
Except(logicalPlan, other.logicalPlan, isAll = true)
19651982
}
19661983

19671984
/**

0 commit comments

Comments
 (0)