
Commit 280ff52

[SPARK-21977] SinglePartition optimizations break certain Streaming Stateful Aggregation requirements
## What changes were proposed in this pull request?

This is a bit hard to explain as there are several issues here; I'll try my best. Here are the requirements:

1. A Structured Streaming source that can generate empty RDDs with 0 partitions.
2. A Structured Streaming query that uses the above source, performs a stateful aggregation (`mapGroupsWithState`, `groupBy.count`, ...), and coalesces to 1 partition (`coalesce(1)`).

The crux of the problem is that when a dataset has a `coalesce(1)` call, it receives a `SinglePartition` partitioning scheme. This scheme satisfies most required distributions used for aggregations such as `HashAggregateExec`. This causes a world of problems:

Symptom 1: If the input RDD has 0 partitions, the whole lineage will receive 0 partitions, nothing will be executed, and the state store will not create any delta files. When this happens, the next trigger fails, because the StateStore fails to load the delta file for the previous trigger.

Symptom 2: Let's say that there was data. If you then stop your stream, change `coalesce(1)` to `coalesce(2)`, and restart your stream, the stream will fail, because `spark.sql.shuffle.partitions - 1` of the StateStores will fail to find their delta files.

To fix the issues above, we must check that the partitioning of the child of a `StatefulOperator` satisfies:

If the grouping expressions are empty:
  a) an AllTuples distribution
  b) a single physical partition

If the grouping expressions are non-empty:
  a) a ClusteredDistribution
  b) `spark.sql.shuffle.partitions` number of partitions

whether or not `coalesce(1)` exists in the plan, and whether or not the input RDD for the trigger has any data.

Once you fix the above problem by adding an Exchange to the plan, you come across the following bug: if you call `coalesce(1).groupBy().count()` on a streaming DataFrame and a trigger has no data, `StateStoreRestoreExec` doesn't return the prior state. However, for this specific aggregation, the `HashAggregateExec` after the restore returns a (0, 0) row, since we're performing a count and there is no data. This row then gets stored by `StateStoreSaveExec`, causing the previous counts to be overwritten and lost.

## How was this patch tested?

Regression tests.

Author: Burak Yavuz <[email protected]>

Closes apache#19196 from brkyvz/sa-0.
1 parent c6ff59a commit 280ff52
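
For context, here is a minimal sketch (not part of this commit) of the query shape described above: a global stateful aggregation whose plan has been coalesced to a single partition. It assumes Spark's `MemoryStream` test source; any source that can produce a trigger with an empty, 0-partition RDD exercises the same code path.

// Illustrative only: a global stateful aggregation coalesced to one partition.
// Before this patch, a trigger with no data could leave the StateStore without
// a delta file, and restarting after swapping coalesce(1) for coalesce(2)
// could fail to find delta files for the extra partitions.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.streaming.MemoryStream

val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
import spark.implicits._
implicit val sqlCtx = spark.sqlContext

val input = MemoryStream[Int]
val counts = input.toDF().coalesce(1).groupBy().count()

val query = counts.writeStream
  .format("memory")
  .queryName("counts")
  .outputMode("complete")
  .start()

input.addData(1, 2, 3)
query.processAllAvailable()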

File tree: 6 files changed, +395 −21 lines

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala

Lines changed: 31 additions & 3 deletions

@@ -21,11 +21,13 @@ import java.util.UUID
 import java.util.concurrent.atomic.AtomicInteger
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{SparkSession, Strategy}
+import org.apache.spark.sql.{AnalysisException, SparkSession, Strategy}
 import org.apache.spark.sql.catalyst.expressions.CurrentBatchTimestamp
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, HashPartitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode}
+import org.apache.spark.sql.execution.exchange.ShuffleExchange
 import org.apache.spark.sql.streaming.OutputMode
 
 /**
@@ -89,7 +91,7 @@ class IncrementalExecution(
     override def apply(plan: SparkPlan): SparkPlan = plan transform {
       case StateStoreSaveExec(keys, None, None, None,
              UnaryExecNode(agg,
-               StateStoreRestoreExec(keys2, None, child))) =>
+               StateStoreRestoreExec(_, None, child))) =>
         val aggStateInfo = nextStatefulOperationStateInfo
         StateStoreSaveExec(
           keys,
@@ -117,8 +119,34 @@ class IncrementalExecution(
     }
   }
 
-  override def preparations: Seq[Rule[SparkPlan]] = state +: super.preparations
+  override def preparations: Seq[Rule[SparkPlan]] =
+    Seq(state, EnsureStatefulOpPartitioning) ++ super.preparations
 
   /** No need assert supported, as this check has already been done */
   override def assertSupported(): Unit = { }
 }
+
+object EnsureStatefulOpPartitioning extends Rule[SparkPlan] {
+  // Needs to be transformUp to avoid extra shuffles
+  override def apply(plan: SparkPlan): SparkPlan = plan transformUp {
+    case so: StatefulOperator =>
+      val numPartitions = plan.sqlContext.sessionState.conf.numShufflePartitions
+      val distributions = so.requiredChildDistribution
+      val children = so.children.zip(distributions).map { case (child, reqDistribution) =>
+        val expectedPartitioning = reqDistribution match {
+          case AllTuples => SinglePartition
+          case ClusteredDistribution(keys) => HashPartitioning(keys, numPartitions)
+          case _ => throw new AnalysisException("Unexpected distribution expected for " +
+            s"Stateful Operator: $so. Expect AllTuples or ClusteredDistribution but got " +
+            s"$reqDistribution.")
+        }
+        if (child.outputPartitioning.guarantees(expectedPartitioning) &&
+            child.execute().getNumPartitions == expectedPartitioning.numPartitions) {
+          child
+        } else {
+          ShuffleExchange(expectedPartitioning, child)
+        }
+      }
+      so.withNewChildren(children)
+  }
+}
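
As an aside, here is a small sketch (not part of the diff) of how the guarantees/partition-count check in `EnsureStatefulOpPartitioning` above plays out for the two `coalesce(1)` scenarios from the commit message. The grouping key, the helper name, and the partition count of 200 (the default `spark.sql.shuffle.partitions`) are illustrative assumptions.

// Illustrative only: evaluating the same condition the rule uses to decide
// whether to insert a ShuffleExchange.
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, SinglePartition}
import org.apache.spark.sql.types.IntegerType

object PartitioningCheckSketch {
  // Mirrors the rule's condition: keep the child only if its partitioning
  // guarantees the expected one and the partition counts match.
  def needsShuffle(child: Partitioning, childNumPartitions: Int, expected: Partitioning): Boolean =
    !(child.guarantees(expected) && childNumPartitions == expected.numPartitions)

  def main(args: Array[String]): Unit = {
    val key = AttributeReference("key", IntegerType)()  // hypothetical grouping key

    // Global aggregation under coalesce(1): AllTuples maps to SinglePartition,
    // which the single-partition child already provides, so no exchange.
    println(needsShuffle(SinglePartition, 1, SinglePartition))                 // false

    // Keyed aggregation under coalesce(1): ClusteredDistribution maps to
    // HashPartitioning over the shuffle partitions, so an exchange is added.
    println(needsShuffle(SinglePartition, 1, HashPartitioning(Seq(key), 200))) // true
  }
}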

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala

Lines changed: 1 addition & 0 deletions

@@ -829,6 +829,7 @@ class StreamExecution(
     if (streamDeathCause != null) {
       throw streamDeathCause
     }
+    if (!isActive) return
     awaitBatchLock.lock()
     try {
       noNewData = false

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala

Lines changed: 31 additions & 6 deletions

@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate}
 import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
-import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning}
+import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
@@ -200,18 +200,35 @@ case class StateStoreRestoreExec(
       sqlContext.sessionState,
       Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) =>
         val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
-        iter.flatMap { row =>
-          val key = getKey(row)
-          val savedState = store.get(key)
-          numOutputRows += 1
-          row +: Option(savedState).toSeq
+        val hasInput = iter.hasNext
+        if (!hasInput && keyExpressions.isEmpty) {
+          // If our `keyExpressions` are empty, we're getting a global aggregation. In that case
+          // the `HashAggregateExec` will output a 0 value for the partial merge. We need to
+          // restore the value, so that we don't overwrite our state with a 0 value, but rather
+          // merge the 0 with existing state.
+          store.iterator().map(_.value)
+        } else {
+          iter.flatMap { row =>
+            val key = getKey(row)
+            val savedState = store.get(key)
+            numOutputRows += 1
+            row +: Option(savedState).toSeq
+          }
         }
       }
   }
 
   override def output: Seq[Attribute] = child.output
 
   override def outputPartitioning: Partitioning = child.outputPartitioning
+
+  override def requiredChildDistribution: Seq[Distribution] = {
+    if (keyExpressions.isEmpty) {
+      AllTuples :: Nil
+    } else {
+      ClusteredDistribution(keyExpressions) :: Nil
+    }
+  }
 }
 
 /**
@@ -351,6 +368,14 @@ case class StateStoreSaveExec(
   override def output: Seq[Attribute] = child.output
 
   override def outputPartitioning: Partitioning = child.outputPartitioning
+
+  override def requiredChildDistribution: Seq[Distribution] = {
+    if (keyExpressions.isEmpty) {
+      AllTuples :: Nil
+    } else {
+      ClusteredDistribution(keyExpressions) :: Nil
+    }
+  }
 }
 
 /** Physical operator for executing streaming Deduplicate. */
sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala

Lines changed: 132 additions & 0 deletions

@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import java.util.UUID
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest, UnaryExecNode}
+import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchange}
+import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata, StatefulOperator, StatefulOperatorStateInfo}
+import org.apache.spark.sql.test.SharedSQLContext
+
+class EnsureStatefulOpPartitioningSuite extends SparkPlanTest with SharedSQLContext {
+
+  import testImplicits._
+  super.beforeAll()
+
+  private val baseDf = Seq((1, "A"), (2, "b")).toDF("num", "char")
+
+  testEnsureStatefulOpPartitioning(
+    "ClusteredDistribution generates Exchange with HashPartitioning",
+    baseDf.queryExecution.sparkPlan,
+    requiredDistribution = keys => ClusteredDistribution(keys),
+    expectedPartitioning =
+      keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions),
+    expectShuffle = true)
+
+  testEnsureStatefulOpPartitioning(
+    "ClusteredDistribution with coalesce(1) generates Exchange with HashPartitioning",
+    baseDf.coalesce(1).queryExecution.sparkPlan,
+    requiredDistribution = keys => ClusteredDistribution(keys),
+    expectedPartitioning =
+      keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions),
+    expectShuffle = true)
+
+  testEnsureStatefulOpPartitioning(
+    "AllTuples generates Exchange with SinglePartition",
+    baseDf.queryExecution.sparkPlan,
+    requiredDistribution = _ => AllTuples,
+    expectedPartitioning = _ => SinglePartition,
+    expectShuffle = true)
+
+  testEnsureStatefulOpPartitioning(
+    "AllTuples with coalesce(1) doesn't need Exchange",
+    baseDf.coalesce(1).queryExecution.sparkPlan,
+    requiredDistribution = _ => AllTuples,
+    expectedPartitioning = _ => SinglePartition,
+    expectShuffle = false)
+
+  /**
+   * For `StatefulOperator` with the given `requiredChildDistribution`, and child SparkPlan
+   * `inputPlan`, ensures that the incremental planner adds exchanges, if required, in order to
+   * ensure the expected partitioning.
+   */
+  private def testEnsureStatefulOpPartitioning(
+      testName: String,
+      inputPlan: SparkPlan,
+      requiredDistribution: Seq[Attribute] => Distribution,
+      expectedPartitioning: Seq[Attribute] => Partitioning,
+      expectShuffle: Boolean): Unit = {
+    test(testName) {
+      val operator = TestStatefulOperator(inputPlan, requiredDistribution(inputPlan.output.take(1)))
+      val executed = executePlan(operator, OutputMode.Complete())
+      if (expectShuffle) {
+        val exchange = executed.children.find(_.isInstanceOf[Exchange])
+        if (exchange.isEmpty) {
+          fail(s"Was expecting an exchange but didn't get one in:\n$executed")
+        }
+        assert(exchange.get ===
+          ShuffleExchange(expectedPartitioning(inputPlan.output.take(1)), inputPlan),
+          s"Exchange didn't have expected properties:\n${exchange.get}")
+      } else {
+        assert(!executed.children.exists(_.isInstanceOf[Exchange]),
+          s"Unexpected exchange found in:\n$executed")
+      }
+    }
+  }
+
+  /** Executes a SparkPlan using the IncrementalPlanner used for Structured Streaming. */
+  private def executePlan(
+      p: SparkPlan,
+      outputMode: OutputMode = OutputMode.Append()): SparkPlan = {
+    val execution = new IncrementalExecution(
+      spark,
+      null,
+      OutputMode.Complete(),
+      "chk",
+      UUID.randomUUID(),
+      0L,
+      OffsetSeqMetadata()) {
+      override lazy val sparkPlan: SparkPlan = p transform {
+        case plan: SparkPlan =>
+          val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
+          plan transformExpressions {
+            case UnresolvedAttribute(Seq(u)) =>
+              inputMap.getOrElse(u,
+                sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
+          }
+      }
+    }
+    execution.executedPlan
+  }
+}
+
+/** Used to emulate a `StatefulOperator` with the given requiredDistribution. */
+case class TestStatefulOperator(
+    child: SparkPlan,
+    requiredDist: Distribution) extends UnaryExecNode with StatefulOperator {
+  override def output: Seq[Attribute] = child.output
+  override def doExecute(): RDD[InternalRow] = child.execute()
+  override def requiredChildDistribution: Seq[Distribution] = requiredDist :: Nil
+  override def stateInfo: Option[StatefulOperatorStateInfo] = None
+}

sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala

Lines changed: 12 additions & 4 deletions

@@ -167,7 +167,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
   case class StartStream(
       trigger: Trigger = Trigger.ProcessingTime(0),
       triggerClock: Clock = new SystemClock,
-      additionalConfs: Map[String, String] = Map.empty)
+      additionalConfs: Map[String, String] = Map.empty,
+      checkpointLocation: String = null)
     extends StreamAction
 
   /** Advance the trigger clock's time manually. */
@@ -349,20 +350,22 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
           """.stripMargin)
     }
 
-    val metadataRoot = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
     var manualClockExpectedTime = -1L
+    val defaultCheckpointLocation =
+      Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
     try {
      startedTest.foreach { action =>
        logInfo(s"Processing test stream action: $action")
        action match {
-          case StartStream(trigger, triggerClock, additionalConfs) =>
+          case StartStream(trigger, triggerClock, additionalConfs, checkpointLocation) =>
            verify(currentStream == null, "stream already running")
            verify(triggerClock.isInstanceOf[SystemClock]
              || triggerClock.isInstanceOf[StreamManualClock],
              "Use either SystemClock or StreamManualClock to start the stream")
            if (triggerClock.isInstanceOf[StreamManualClock]) {
              manualClockExpectedTime = triggerClock.asInstanceOf[StreamManualClock].getTimeMillis()
            }
+            val metadataRoot = Option(checkpointLocation).getOrElse(defaultCheckpointLocation)
 
            additionalConfs.foreach(pair => {
              val value =
@@ -479,7 +482,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
            verify(currentStream != null || lastStream != null,
              "cannot assert when no stream has been started")
            val streamToAssert = Option(currentStream).getOrElse(lastStream)
-            verify(a.condition(streamToAssert), s"Assert on query failed: ${a.message}")
+            try {
+              verify(a.condition(streamToAssert), s"Assert on query failed: ${a.message}")
+            } catch {
+              case NonFatal(e) =>
+                failTest(s"Assert on query failed: ${a.message}", e)
+            }
 
          case a: Assert =>
            val streamToAssert = Option(currentStream).getOrElse(lastStream)
