
Commit 770add8

[SPARK-23004][SS] Ensure StateStore.commit is called only once in a streaming aggregation task
## What changes were proposed in this pull request?

A structured streaming query with a streaming aggregation can throw the following error in rare cases.

```
java.lang.IllegalStateException: Cannot commit after already committed or aborted
	at org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider.org$apache$spark$sql$execution$streaming$state$HDFSBackedStateStoreProvider$$verify(HDFSBackedStateStoreProvider.scala:643)
	at org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider$HDFSBackedStateStore.commit(HDFSBackedStateStoreProvider.scala:135)
	at org.apache.spark.sql.execution.streaming.StateStoreSaveExec$$anonfun$doExecute$3$$anon$2$$anonfun$hasNext$2.apply$mcV$sp(statefulOperators.scala:359)
	at org.apache.spark.sql.execution.streaming.StateStoreWriter$class.timeTakenMs(statefulOperators.scala:102)
	at org.apache.spark.sql.execution.streaming.StateStoreSaveExec.timeTakenMs(statefulOperators.scala:251)
	at org.apache.spark.sql.execution.streaming.StateStoreSaveExec$$anonfun$doExecute$3$$anon$2.hasNext(statefulOperators.scala:359)
	at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.processInputs(ObjectAggregationIterator.scala:188)
	at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.<init>(ObjectAggregationIterator.scala:78)
	at org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec$$anonfun$doExecute$1$$anonfun$2.apply(ObjectHashAggregateExec.scala:114)
	at org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec$$anonfun$doExecute$1$$anonfun$2.apply(ObjectHashAggregateExec.scala:105)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsWithIndexInternal$1$$anonfun$apply$24.apply(RDD.scala:830)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsWithIndexInternal$1$$anonfun$apply$24.apply(RDD.scala:830)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:42)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:336)
```

This can happen when the following conditions are hit together.

- A streaming aggregation whose aggregation function is a subclass of [`TypedImperativeAggregate`](https://github.com/apache/spark/blob/76b8b840ddc951ee6203f9cccd2c2b9671c1b5e8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala#L473) (for example, `collect_set`, `collect_list`, `percentile`, etc.).
- The query runs in `update` mode.
- After the shuffle, a partition has exactly 128 records.

This causes `StateStore.commit` to be called twice. See the [JIRA](https://issues.apache.org/jira/browse/SPARK-23004) for a more detailed explanation. The solution is to use `NextIterator` or `CompletionIterator`, each of which has a flag that prevents the "onCompletion" task from being called more than once. In this PR, I chose to implement it using `NextIterator`.

## How was this patch tested?

Added a unit test that I have confirmed will fail without the fix.

Author: Tathagata Das <[email protected]>

Closes apache#21124 from tdas/SPARK-23004.
1 parent 448d248 commit 770add8
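
Why the old code could double-commit: the previous `Iterator` in `StateStoreSaveExec` ran the commit logic inside `hasNext` whenever the underlying data was exhausted, but a downstream consumer such as `ObjectAggregationIterator` may probe `hasNext` more than once after exhaustion (which is what happens when the sort-based fallback triggers at exactly 128 records). The following is a minimal, self-contained sketch of that failure mode, not Spark source; the object and variable names are made up for illustration.

```scala
// Minimal illustration of the double-commit problem (not Spark source).
object DoubleCommitDemo extends App {
  var commits = 0
  val rows = Iterator(1, 2, 3)

  // Mirrors the old pattern: run the "commit" side effect inside hasNext
  // whenever the underlying iterator is exhausted.
  val unsafeCommitOnExhaustIterator = new Iterator[Int] {
    override def hasNext: Boolean = {
      if (!rows.hasNext) {
        commits += 1 // stands in for store.commit()
        false
      } else {
        true
      }
    }
    override def next(): Int = rows.next()
  }

  unsafeCommitOnExhaustIterator.foreach(_ => ()) // drains the data; "commits" once
  unsafeCommitOnExhaustIterator.hasNext          // an extra probe by the consumer
  println(commits)                               // prints 2: the commit logic ran twice
}
```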

2 files changed: +44, −21 lines


sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala

Lines changed: 19 additions & 21 deletions

```diff
@@ -340,37 +340,35 @@ case class StateStoreSaveExec(
         // Update and output modified rows from the StateStore.
         case Some(Update) =>

-          val updatesStartTimeNs = System.nanoTime
-
-          new Iterator[InternalRow] {
-
+          new NextIterator[InternalRow] {
             // Filter late date using watermark if specified
             private[this] val baseIterator = watermarkPredicateForData match {
               case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row))
               case None => iter
             }
+            private val updatesStartTimeNs = System.nanoTime

-            override def hasNext: Boolean = {
-              if (!baseIterator.hasNext) {
-                allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs)
-
-                // Remove old aggregates if watermark specified
-                allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) }
-                commitTimeMs += timeTakenMs { store.commit() }
-                setStoreMetrics(store)
-                false
+            override protected def getNext(): InternalRow = {
+              if (baseIterator.hasNext) {
+                val row = baseIterator.next().asInstanceOf[UnsafeRow]
+                val key = getKey(row)
+                store.put(key, row)
+                numOutputRows += 1
+                numUpdatedStateRows += 1
+                row
               } else {
-                true
+                finished = true
+                null
               }
             }

-            override def next(): InternalRow = {
-              val row = baseIterator.next().asInstanceOf[UnsafeRow]
-              val key = getKey(row)
-              store.put(key, row)
-              numOutputRows += 1
-              numUpdatedStateRows += 1
-              row
+            override protected def close(): Unit = {
+              allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs)
+
+              // Remove old aggregates if watermark specified
+              allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) }
+              commitTimeMs += timeTakenMs { store.commit() }
+              setStoreMetrics(store)
             }
           }
```

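The fix relies on the once-only completion guard in `org.apache.spark.util.NextIterator`. The sketch below is a simplified approximation of that contract written for this discussion (the class and member names here are illustrative, and the real class differs in detail): `getNext()` produces elements and sets `finished` when exhausted, while `close()` is funneled through a flag so it runs at most once.

```scala
// Trimmed-down sketch of a once-only completion guard, approximating NextIterator.
abstract class GuardedIterator[U] extends Iterator[U] {
  private var gotNext = false
  private var nextValue: U = _
  private var closed = false

  // Subclasses set this to true inside getNext() when there is no more data.
  protected var finished = false

  /** Produce the next element, or set `finished = true` when exhausted. */
  protected def getNext(): U

  /** End-of-data work (e.g. StateStore.commit); guaranteed to run at most once. */
  protected def close(): Unit

  private def closeIfNeeded(): Unit = {
    if (!closed) {
      closed = true
      close()
    }
  }

  override def hasNext: Boolean = {
    if (!finished && !gotNext) {
      nextValue = getNext()
      if (finished) {
        // Later hasNext() calls short-circuit on `finished` and never reach close() again.
        closeIfNeeded()
      }
      gotNext = true
    }
    !finished
  }

  override def next(): U = {
    if (!hasNext) throw new NoSuchElementException("End of stream")
    gotNext = false
    nextValue
  }
}
```

With that shape, `StateStoreSaveExec` moves all end-of-data work (watermark cleanup, `store.commit()`, metric updates) into `close()`, so correctness no longer depends on how many times the downstream operator probes `hasNext`.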
sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala

Lines changed: 25 additions & 0 deletions

```diff
@@ -536,6 +536,31 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     )
   }

+  test("SPARK-23004: Ensure that TypedImperativeAggregate functions do not throw errors") {
+    // See the JIRA SPARK-23004 for more details. In short, this test reproduces the error
+    // by ensuring the following.
+    // - A streaming query with a streaming aggregation.
+    // - Aggregation function 'collect_list' that is a subclass of TypedImperativeAggregate.
+    // - Post shuffle partition has exactly 128 records (i.e. the threshold at which
+    //   ObjectHashAggregateExec falls back to sort-based aggregation). This is done by having a
+    //   micro-batch with 128 records that shuffle to a single partition.
+    // This test throws the exact error reported in SPARK-23004 without the corresponding fix.
+    withSQLConf("spark.sql.shuffle.partitions" -> "1") {
+      val input = MemoryStream[Int]
+      val df = input.toDF().toDF("value")
+        .selectExpr("value as group", "value")
+        .groupBy("group")
+        .agg(collect_list("value"))
+      testStream(df, outputMode = OutputMode.Update)(
+        AddData(input, (1 to spark.sqlContext.conf.objectAggSortBasedFallbackThreshold): _*),
+        AssertOnQuery { q =>
+          q.processAllAvailable()
+          true
+        }
+      )
+    }
+  }
+
   /** Add blocks of data to the `BlockRDDBackedSource`. */
   case class AddBlockData(source: BlockRDDBackedSource, data: Seq[Int]*) extends AddData {
     override def addData(query: Option[StreamExecution]): (Source, Offset) = {
```

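For readers who want to try the reproduction outside the test harness, here is a rough standalone sketch. It is only illustrative and is not part of this commit: the `local[1]` master, the console sink, the helper object name, and the hard-coded 128 (the sort-based fallback threshold mentioned in the test comment) are assumptions of the sketch.

```scala
// Illustrative standalone reproduction of SPARK-23004; not part of this commit.
// Assumes a Spark 2.x build; without the fix, processAllAvailable() is expected
// to fail with "Cannot commit after already committed or aborted".
import org.apache.spark.sql.{SQLContext, SparkSession}
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.functions.collect_list

object Spark23004Repro extends App {
  val spark = SparkSession.builder()
    .master("local[1]")
    .appName("spark-23004-repro")
    .config("spark.sql.shuffle.partitions", "1") // shuffle everything into one partition
    .getOrCreate()
  import spark.implicits._
  implicit val sqlContext: SQLContext = spark.sqlContext

  val input = MemoryStream[Int]
  val aggregated = input.toDF().toDF("value")
    .selectExpr("value as group", "value")
    .groupBy("group")
    .agg(collect_list("value")) // collect_list is a TypedImperativeAggregate

  val query = aggregated.writeStream
    .format("console")
    .outputMode("update")
    .start()

  input.addData(1 to 128) // exactly the sort-based fallback threshold
  query.processAllAvailable()
  query.stop()
  spark.stop()
}
```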