
Commit 15298b9

[SPARK-23827][SS] StreamingJoinExec should ensure that input data is partitioned into specific number of partitions
## What changes were proposed in this pull request?

Currently, the requiredChildDistribution does not specify the number of partitions. This can cause weird corner cases where the child's distribution is `SinglePartition`, which satisfies the required distribution of `ClusteredDistribution(no-num-partition-requirement)`, thus eliminating the shuffle needed to repartition the input data into the required number of partitions (i.e. the same number as the state stores). That can lead to "file not found" errors on the state store delta files, as the micro-batch-with-no-shuffle will not run certain tasks and therefore not generate the expected state store delta files.

This PR adds the required constraint on the number of partitions.

## How was this patch tested?

Modified the test harness to always check that ANY stateful operator has a constraint on the number of partitions. As part of that, the existing opt-in checks on child output partitioning were removed, as they are redundant.

Author: Tathagata Das <[email protected]>

Closes apache#20941 from tdas/SPARK-23827.
1 parent ae91720 commit 15298b9
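To make the corner case above concrete, here is a minimal, self-contained Scala sketch of the satisfaction check. The types below are simplified stand-ins for Catalyst's `Partitioning` and `Distribution` classes, not the real API; only the shape of the check is intended to match.

```scala
// Simplified stand-ins for Catalyst's physical-plan abstractions (illustrative only).
sealed trait Partitioning { def numPartitions: Int }
case object SinglePartition extends Partitioning { val numPartitions = 1 }
case class HashPartitioning(keys: Seq[String], numPartitions: Int) extends Partitioning

// After this patch, the requirement can also carry an expected partition count.
case class ClusteredDistribution(keys: Seq[String], requiredNumPartitions: Option[Int] = None)

object DistributionSketch {
  def satisfies(p: Partitioning, d: ClusteredDistribution): Boolean = p match {
    // A single partition trivially co-locates all rows for any keys, so it
    // satisfies any clustering requirement -- unless a partition count is demanded.
    case SinglePartition => d.requiredNumPartitions.forall(_ == 1)
    case HashPartitioning(keys, n) =>
      keys == d.keys && d.requiredNumPartitions.forall(_ == n)
  }

  def main(args: Array[String]): Unit = {
    // Before the fix: no count required, SinglePartition passes, shuffle skipped.
    println(satisfies(SinglePartition, ClusteredDistribution(Seq("key"))))             // true
    // After the fix: the state store count (e.g. 200) is required, shuffle restored.
    println(satisfies(SinglePartition, ClusteredDistribution(Seq("key"), Some(200))))  // false
  }
}
```

With `requiredNumPartitions` unset, the single-partition micro-batch passes the check and only one task runs, so delta files for the other state store partitions are never written — hence the "file not found" errors.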

File tree

7 files changed: +25 −65 lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -62,7 +62,7 @@ class IncrementalExecution(
       StreamingDeduplicationStrategy :: Nil
   }
 
-  private val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key)
+  private[sql] val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key)
     .map(SQLConf.SHUFFLE_PARTITIONS.valueConverter)
     .getOrElse(sparkSession.sessionState.conf.numShufflePartitions)
```

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala

Lines changed: 2 additions & 1 deletion
```diff
@@ -167,7 +167,8 @@ case class StreamingSymmetricHashJoinExec(
   val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length)
 
   override def requiredChildDistribution: Seq[Distribution] =
-    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+    ClusteredDistribution(leftKeys, stateInfo.map(_.numPartitions)) ::
+      ClusteredDistribution(rightKeys, stateInfo.map(_.numPartitions)) :: Nil
 
   override def output: Seq[Attribute] = joinType match {
     case _: InnerLike => left.output ++ right.output
```
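For context, the planner-side effect of the extra argument can be sketched as follows. This is a simplified model of the decision Catalyst's `EnsureRequirements` rule makes; the predicate name and signature here are invented for illustration, not the actual API.

```scala
object EnsureRequirementsSketch {
  // Hypothetical predicate: does the child plan need an exchange (shuffle)
  // inserted before the stateful join? The real logic lives in EnsureRequirements.
  def needsShuffle(
      childNumPartitions: Int,
      childClusteredOnKeys: Boolean,
      requiredNumPartitions: Option[Int]): Boolean = {
    // A child not clustered on the join keys always needs a shuffle; otherwise
    // a shuffle is needed only when a specific partition count is demanded
    // and the child's count differs from it.
    !childClusteredOnKeys || requiredNumPartitions.exists(_ != childNumPartitions)
  }

  def main(args: Array[String]): Unit = {
    // Before the patch: no count required, so a 1-partition child skipped the shuffle.
    println(needsShuffle(childNumPartitions = 1, childClusteredOnKeys = true,
      requiredNumPartitions = None))       // false -- the bug
    // After: stateInfo.map(_.numPartitions) (say 200 state stores) is threaded in.
    println(needsShuffle(childNumPartitions = 1, childClusteredOnKeys = true,
      requiredNumPartitions = Some(200)))  // true -- shuffle restored
  }
}
```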

sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala

Lines changed: 1 addition & 7 deletions
```diff
@@ -25,9 +25,7 @@ import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingDeduplic
 import org.apache.spark.sql.execution.streaming.state.StateStore
 import org.apache.spark.sql.functions._
 
-class DeduplicateSuite extends StateStoreMetricsTest
-    with BeforeAndAfterAll
-    with StatefulOperatorTest {
+class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll {
 
   import testImplicits._
 
@@ -44,8 +42,6 @@ class DeduplicateSuite extends StateStoreMetricsTest
       AddData(inputData, "a"),
       CheckLastBatch("a"),
       assertNumStateRows(total = 1, updated = 1),
-      AssertOnQuery(sq =>
-        checkChildOutputHashPartitioning[StreamingDeduplicateExec](sq, Seq("value"))),
       AddData(inputData, "a"),
       CheckLastBatch(),
       assertNumStateRows(total = 1, updated = 0),
@@ -63,8 +59,6 @@ class DeduplicateSuite extends StateStoreMetricsTest
       AddData(inputData, "a" -> 1),
       CheckLastBatch("a" -> 1),
       assertNumStateRows(total = 1, updated = 1),
-      AssertOnQuery(sq =>
-        checkChildOutputHashPartitioning[StreamingDeduplicateExec](sq, Seq("_1"))),
       AddData(inputData, "a" -> 2), // Dropped
       CheckLastBatch(),
       assertNumStateRows(total = 1, updated = 0),
```

sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala

Lines changed: 1 addition & 4 deletions
```diff
@@ -42,8 +42,7 @@ case class RunningCount(count: Long)
 case class Result(key: Long, count: Int)
 
 class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
-    with BeforeAndAfterAll
-    with StatefulOperatorTest {
+    with BeforeAndAfterAll {
 
   import testImplicits._
   import GroupStateImpl._
@@ -618,8 +617,6 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
       AddData(inputData, "a"),
       CheckLastBatch(("a", "1")),
       assertNumStateRows(total = 1, updated = 1),
-      AssertOnQuery(sq => checkChildOutputHashPartitioning[FlatMapGroupsWithStateExec](
-        sq, Seq("value"))),
       AddData(inputData, "a", "b"),
       CheckLastBatch(("a", "2"), ("b", "1")),
       assertNumStateRows(total = 2, updated = 2),
```

sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala

Lines changed: 0 additions & 49 deletions
This file was deleted.

sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala

Lines changed: 19 additions & 0 deletions
```diff
@@ -37,6 +37,7 @@ import org.apache.spark.SparkEnv
 import org.apache.spark.sql.{Dataset, Encoder, QueryTest, Row}
 import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.physical.AllTuples
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
 import org.apache.spark.sql.execution.streaming._
@@ -444,6 +445,24 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
         }
       }
 
+      val lastExecution = currentStream.lastExecution
+      if (currentStream.isInstanceOf[MicroBatchExecution] && lastExecution != null) {
+        // Verify if stateful operators have correct metadata and distribution
+        // This can often catch hard to debug errors when developing stateful operators
+        lastExecution.executedPlan.collect { case s: StatefulOperator => s }.foreach { s =>
+          assert(s.stateInfo.map(_.numPartitions).contains(lastExecution.numStateStores))
+          s.requiredChildDistribution.foreach { d =>
+            withClue(s"$s specifies incorrect # partitions in requiredChildDistribution $d") {
+              assert(d.requiredNumPartitions.isDefined)
+              assert(d.requiredNumPartitions.get >= 1)
+              if (d != AllTuples) {
+                assert(d.requiredNumPartitions.get == s.stateInfo.get.numPartitions)
+              }
+            }
+          }
+        }
+      }
+
       val (latestBatchData, allData) = sink match {
         case s: MemorySink => (s.latestBatchData, s.allData)
         case s: MemorySinkV2 => (s.latestBatchData, s.allData)
```
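Because this check sits on the harness's data-fetching path, it runs after every micro-batch of every `testStream` run: any stateful operator whose `stateInfo` or `requiredChildDistribution` fails to pin the state store partition count now breaks the ordinary streaming suites, which is why the opt-in `StatefulOperatorTest` trait above could be deleted.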

sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala

Lines changed: 1 addition & 3 deletions
```diff
@@ -44,7 +44,7 @@ object FailureSingleton {
 }
 
 class StreamingAggregationSuite extends StateStoreMetricsTest
-    with BeforeAndAfterAll with Assertions with StatefulOperatorTest {
+    with BeforeAndAfterAll with Assertions {
 
   override def afterAll(): Unit = {
     super.afterAll()
@@ -281,8 +281,6 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
       AddData(inputData, 0L, 5L, 5L, 10L),
       AdvanceManualClock(10 * 1000),
       CheckLastBatch((0L, 1), (5L, 2), (10L, 1)),
-      AssertOnQuery(sq =>
-        checkChildOutputHashPartitioning[StateStoreRestoreExec](sq, Seq("value"))),
 
       // advance clock to 20 seconds, should retain keys >= 10
       AddData(inputData, 15L, 15L, 20L),
```
