
Commit 10f1f19

dilipbiswal authored and gatorsmile committed
[SPARK-21274][SQL] Implement EXCEPT ALL clause.
## What changes were proposed in this pull request?

Implements the EXCEPT ALL clause through query rewrites using existing operators in Spark. In this PR, an internal UDTF (`replicate_rows`) is added to aid in preserving duplicate rows. Please refer to [this document](https://drive.google.com/open?id=1nyW0T0b_ajUduQoPgZLAsyHK8s3_dko3ulQuxaLpUXE) for the design.

**Note:** The proposed UDTF is kept as an internal function, used purely to aid this particular rewrite, which gives us the flexibility to change to a more generalized UDTF in the future.

Input query:

```sql
SELECT c1 FROM ut1
EXCEPT ALL
SELECT c1 FROM ut2
```

Rewritten query:

```sql
SELECT c1
FROM (
  SELECT replicate_rows(sum_val, c1)
  FROM (
    SELECT c1, sum_val
    FROM (
      SELECT c1, sum(vcol) AS sum_val
      FROM (
        SELECT 1L AS vcol, c1 FROM ut1
        UNION ALL
        SELECT -1L AS vcol, c1 FROM ut2
      ) AS union_all
      GROUP BY union_all.c1
    )
    WHERE sum_val > 0
  )
)
```

## How was this patch tested?

Added test cases in SQLQueryTestSuite, DataFrameSuite and SetOperationSuite.

Author: Dilip Biswal <[email protected]>

Closes apache#21857 from dilipbiswal/dkb_except_all_final.
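For concreteness, the new semantics can be exercised from spark-shell (where `spark` is predefined); the table and column names below follow the commit message, and the data is illustrative:

```scala
import spark.implicits._

Seq("a", "a", "a", "b").toDF("c1").createOrReplaceTempView("ut1")
Seq("a").toDF("c1").createOrReplaceTempView("ut2")

// EXCEPT ALL preserves duplicates: "a" survives twice (3 on the left
// minus 1 on the right) and "b" once, unlike plain EXCEPT.
spark.sql("SELECT c1 FROM ut1 EXCEPT ALL SELECT c1 FROM ut2").show()
```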
1 parent 5828f41 commit 10f1f19

File tree: 17 files changed, +708 −19 lines

python/pyspark/sql/dataframe.py

Lines changed: 25 additions & 0 deletions
```diff
@@ -293,6 +293,31 @@ def explain(self, extended=False):
         else:
             print(self._jdf.queryExecution().simpleString())
 
+    @since(2.4)
+    def exceptAll(self, other):
+        """Return a new :class:`DataFrame` containing rows in this :class:`DataFrame` but
+        not in another :class:`DataFrame` while preserving duplicates.
+
+        This is equivalent to `EXCEPT ALL` in SQL.
+
+        >>> df1 = spark.createDataFrame(
+        ...         [("a", 1), ("a", 1), ("a", 1), ("a", 2), ("b", 3), ("c", 4)], ["C1", "C2"])
+        >>> df2 = spark.createDataFrame([("a", 1), ("b", 3)], ["C1", "C2"])
+
+        >>> df1.exceptAll(df2).show()
+        +---+---+
+        | C1| C2|
+        +---+---+
+        |  a|  1|
+        |  a|  1|
+        |  a|  2|
+        |  c|  4|
+        +---+---+
+
+        Also as standard in SQL, this function resolves columns by position (not by name).
+        """
+        return DataFrame(self._jdf.exceptAll(other._jdf), self.sql_ctx)
+
     @since(1.3)
     def isLocal(self):
         """Returns ``True`` if the :func:`collect` and :func:`take` methods can be run locally
```

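On the JVM side this Python method simply delegates to `Dataset.exceptAll`; a Scala sketch of the same example as the doctest above (spark-shell, illustrative data):

```scala
import spark.implicits._

val df1 = Seq(("a", 1), ("a", 1), ("a", 1), ("a", 2), ("b", 3), ("c", 4)).toDF("C1", "C2")
val df2 = Seq(("a", 1), ("b", 3)).toDF("C1", "C2")

// Expected rows, matching the Python doctest: ("a", 1) twice, ("a", 2), ("c", 4).
df1.exceptAll(df2).show()
```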
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 2 additions & 3 deletions
```diff
@@ -916,9 +916,8 @@ class Analyzer(
           j.copy(right = dedupRight(left, right))
         case i @ Intersect(left, right) if !i.duplicateResolved =>
           i.copy(right = dedupRight(left, right))
-        case i @ Except(left, right) if !i.duplicateResolved =>
-          i.copy(right = dedupRight(left, right))
-
+        case e @ Except(left, right, _) if !e.duplicateResolved =>
+          e.copy(right = dedupRight(left, right))
         // When resolve `SortOrder`s in Sort based on child, don't report errors as
         // we still have chance to resolve it based on its descendants
         case s @ Sort(ordering, global, child) if child.resolved && !s.resolved =>
```
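`dedupRight` matters when the same plan appears on both sides of the set operation, where the two children would otherwise share attribute IDs; a minimal spark-shell illustration (assuming the session from the first sketch):

```scala
import spark.implicits._

// Self EXCEPT ALL: the analyzer rewrites the right child with fresh
// attribute IDs via dedupRight so the plan resolves; every row cancels out.
val d = Seq(1, 2, 2).toDF("c1")
d.exceptAll(d).show()  // empty result
```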

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala

Lines changed: 9 additions & 3 deletions
```diff
@@ -319,11 +319,17 @@ object TypeCoercion {
   object WidenSetOperationTypes extends Rule[LogicalPlan] {
 
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-      case s @ SetOperation(left, right) if s.childrenResolved &&
-          left.output.length == right.output.length && !s.resolved =>
+      case s @ Except(left, right, isAll) if s.childrenResolved &&
+        left.output.length == right.output.length && !s.resolved =>
         val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
         assert(newChildren.length == 2)
-        s.makeCopy(Array(newChildren.head, newChildren.last))
+        Except(newChildren.head, newChildren.last, isAll)
+
+      case s @ Intersect(left, right) if s.childrenResolved &&
+        left.output.length == right.output.length && !s.resolved =>
+        val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
+        assert(newChildren.length == 2)
+        Intersect(newChildren.head, newChildren.last)
 
       case s: Union if s.childrenResolved &&
           s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved =>
```
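With this change, `WidenSetOperationTypes` rebuilds the `Except` node explicitly so the `isAll` flag is carried through, rather than relying on `makeCopy`. A quick illustration of the widening itself (spark-shell):

```scala
// The INT branch is widened to BIGINT so the two sides line up;
// the EXCEPT ALL flag survives because the rule rebuilds Except(..., isAll).
spark.sql(
  "SELECT CAST(1 AS INT) AS c1 EXCEPT ALL SELECT CAST(2 AS BIGINT) AS c1"
).printSchema()  // c1: long
```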

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -306,7 +306,7 @@ object UnsupportedOperationChecker {
       case u: Union if u.children.map(_.isStreaming).distinct.size == 2 =>
         throwError("Union between streaming and batch DataFrames/Datasets is not supported")
 
-      case Except(left, right) if right.isStreaming =>
+      case Except(left, right, _) if right.isStreaming =>
         throwError("Except on a streaming DataFrame/Dataset on the right is not supported")
 
       case Intersect(left, right) if left.isStreaming && right.isStreaming =>
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala

Lines changed: 26 additions & 0 deletions
```diff
@@ -223,6 +223,32 @@ case class Stack(children: Seq[Expression]) extends Generator {
   }
 }
 
+/**
+ * Replicate the row N times. N is specified as the first argument to the function.
+ * This is an internal function solely used by optimizer to rewrite EXCEPT ALL AND
+ * INTERSECT ALL queries.
+ */
+case class ReplicateRows(children: Seq[Expression]) extends Generator with CodegenFallback {
+  private lazy val numColumns = children.length - 1 // remove the multiplier value from output.
+
+  override def elementSchema: StructType =
+    StructType(children.tail.zipWithIndex.map {
+      case (e, index) => StructField(s"col$index", e.dataType)
+    })
+
+  override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
+    val numRows = children.head.eval(input).asInstanceOf[Long]
+    val values = children.tail.map(_.eval(input)).toArray
+    Range.Long(0, numRows, 1).map { _ =>
+      val fields = new Array[Any](numColumns)
+      for (col <- 0 until numColumns) {
+        fields.update(col, values(col))
+      }
+      InternalRow(fields: _*)
+    }
+  }
+}
+
 /**
  * Wrapper around another generator to specify outer behavior. This is used to implement functions
  * such as explode_outer. This expression gets replaced during analysis.
```
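Stripped of Spark's expression machinery, the generator's behavior is just "emit the payload columns N times, where N is the first column"; a self-contained sketch (the names here are illustrative, not part of the Spark API):

```scala
// Mirrors ReplicateRows.eval: the first value is the Long multiplier,
// the rest are the payload that gets emitted that many times.
def replicateRows(row: Seq[Any]): Seq[Seq[Any]] = {
  val count = row.head.asInstanceOf[Long]
  val values = row.tail
  Seq.fill(count.toInt)(values)
}

// replicateRows(Seq(2L, "a", 1)) == Seq(Seq("a", 1), Seq("a", 1))
```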

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 60 additions & 1 deletion
```diff
@@ -135,6 +135,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
     Batch("Subquery", Once,
       OptimizeSubqueries) ::
     Batch("Replace Operators", fixedPoint,
+      RewriteExcepAll,
       ReplaceIntersectWithSemiJoin,
       ReplaceExceptWithFilter,
       ReplaceExceptWithAntiJoin,
@@ -1422,13 +1423,71 @@ object ReplaceIntersectWithSemiJoin extends Rule[LogicalPlan] {
  */
 object ReplaceExceptWithAntiJoin extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case Except(left, right) =>
+    case Except(left, right, false) =>
       assert(left.output.size == right.output.size)
       val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) }
       Distinct(Join(left, right, LeftAnti, joinCond.reduceLeftOption(And)))
   }
 }
 
+/**
+ * Replaces logical [[Except]] operator using a combination of Union, Aggregate
+ * and Generate operator.
+ *
+ * Input Query :
+ * {{{
+ *    SELECT c1 FROM ut1 EXCEPT ALL SELECT c1 FROM ut2
+ * }}}
+ *
+ * Rewritten Query:
+ * {{{
+ *   SELECT c1
+ *   FROM (
+ *     SELECT replicate_rows(sum_val, c1)
+ *     FROM (
+ *       SELECT c1, sum_val
+ *       FROM (
+ *         SELECT c1, sum(vcol) AS sum_val
+ *         FROM (
+ *           SELECT 1L as vcol, c1 FROM ut1
+ *           UNION ALL
+ *           SELECT -1L as vcol, c1 FROM ut2
+ *         ) AS union_all
+ *         GROUP BY union_all.c1
+ *       )
+ *       WHERE sum_val > 0
+ *     )
+ *   )
+ * }}}
+ */
+object RewriteExcepAll extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case Except(left, right, true) =>
+      assert(left.output.size == right.output.size)
+
+      val newColumnLeft = Alias(Literal(1L), "vcol")()
+      val newColumnRight = Alias(Literal(-1L), "vcol")()
+      val modifiedLeftPlan = Project(Seq(newColumnLeft) ++ left.output, left)
+      val modifiedRightPlan = Project(Seq(newColumnRight) ++ right.output, right)
+      val unionPlan = Union(modifiedLeftPlan, modifiedRightPlan)
+      val aggSumCol =
+        Alias(AggregateExpression(Sum(unionPlan.output.head.toAttribute), Complete, false), "sum")()
+      val aggOutputColumns = left.output ++ Seq(aggSumCol)
+      val aggregatePlan = Aggregate(left.output, aggOutputColumns, unionPlan)
+      val filteredAggPlan = Filter(GreaterThan(aggSumCol.toAttribute, Literal(0L)), aggregatePlan)
+      val genRowPlan = Generate(
+        ReplicateRows(Seq(aggSumCol.toAttribute) ++ left.output),
+        unrequiredChildIndex = Nil,
+        outer = false,
+        qualifier = None,
+        left.output,
+        filteredAggPlan
+      )
+      Project(left.output, genRowPlan)
+  }
+}
+
 /**
  * Removes literals from group expressions in [[Aggregate]], as they have no effect to the result
  * but only makes the grouping key bigger.
```
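To see what `RewriteExcepAll` builds, here is the same rewrite performed by hand with the public DataFrame API. Since `replicate_rows` is internal, `explode(array_repeat(...))` stands in for the Generate step; all names are illustrative (spark-shell, reusing the `ut1`/`ut2` views from the first sketch):

```scala
import org.apache.spark.sql.functions._
import spark.implicits._

// Tag each side (+1L / -1L), sum the tags per group, keep positive counts...
val tagged = spark.table("ut1").select(lit(1L).as("vcol"), $"c1")
  .union(spark.table("ut2").select(lit(-1L).as("vcol"), $"c1"))
val counted = tagged.groupBy($"c1").agg(sum($"vcol").as("sum_val"))
  .where($"sum_val" > 0)

// ...then replicate each surviving row sum_val times, as ReplicateRows would.
val result = counted
  .withColumn("dummy", explode(array_repeat(lit(1), $"sum_val".cast("int"))))
  .select($"c1")
result.show()  // "a" twice, "b" once
```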

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -46,7 +46,7 @@ object ReplaceExceptWithFilter extends Rule[LogicalPlan] {
     }
 
     plan.transform {
-      case e @ Except(left, right) if isEligible(left, right) =>
+      case e @ Except(left, right, false) if isEligible(left, right) =>
         val newCondition = transformCondition(left, skipProject(right))
         newCondition.map { c =>
           Distinct(Filter(Not(c), left))
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -537,7 +537,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
       case SqlBaseParser.INTERSECT =>
         Intersect(left, right)
       case SqlBaseParser.EXCEPT if all =>
-        throw new ParseException("EXCEPT ALL is not supported.", ctx)
+        Except(left, right, isAll = true)
       case SqlBaseParser.EXCEPT =>
         Except(left, right)
       case SqlBaseParser.SETMINUS if all =>
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 5 additions & 2 deletions
```diff
@@ -183,8 +183,11 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation
   }
 }
 
-case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
-
+case class Except(
+    left: LogicalPlan,
+    right: LogicalPlan,
+    isAll: Boolean = false) extends SetOperation(left, right) {
+  override def nodeName: String = getClass.getSimpleName + ( if ( isAll ) "All" else "" )
   /** We don't use right.output because those rows get excluded from the set. */
   override def output: Seq[Attribute] = left.output
```

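A small side effect of the `nodeName` override is that plan strings distinguish the two variants: the analyzed plan of an `EXCEPT ALL` query should print an `ExceptAll` node where a plain `EXCEPT` prints `Except` (spark-shell, output shape approximate):

```scala
// The Except node renders via nodeName in the analyzed plan
// (the optimizer rewrites it away later).
println(spark.sql("SELECT 1 EXCEPT ALL SELECT 2").queryExecution.analyzed.treeString)
```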
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala

Lines changed: 23 additions & 1 deletion
```diff
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
 import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.expressions.{Alias, GreaterThan, Literal, ReplicateRows}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
@@ -144,4 +144,26 @@ class SetOperationSuite extends PlanTest {
       Distinct(Union(query3 :: query4 :: Nil))).analyze
     comparePlans(distinctUnionCorrectAnswer2, optimized2)
   }
+
+  test("EXCEPT ALL rewrite") {
+    val input = Except(testRelation, testRelation2, isAll = true)
+    val rewrittenPlan = RewriteExcepAll(input)
+
+    val planFragment = testRelation.select(Literal(1L).as("vcol"), 'a, 'b, 'c)
+      .union(testRelation2.select(Literal(-1L).as("vcol"), 'd, 'e, 'f))
+      .groupBy('a, 'b, 'c)('a, 'b, 'c, sum('vcol).as("sum"))
+      .where(GreaterThan('sum, Literal(0L))).analyze
+    val multiplerAttr = planFragment.output.last
+    val output = planFragment.output.dropRight(1)
+    val expectedPlan = Project(output,
+      Generate(
+        ReplicateRows(Seq(multiplerAttr) ++ output),
+        Nil,
+        false,
+        None,
+        output,
+        planFragment
+      ))
+    comparePlans(expectedPlan, rewrittenPlan)
+  }
 }
```
