
Commit 5a7aa6f

parthchandra and dongjoon-hyun authored and committed
[SPARK-27100][SQL] Use Array instead of Seq in FilePartition to prevent StackOverflowError
## What changes were proposed in this pull request?

ShuffleMapTask's partition field is a FilePartition, and FilePartition's 'files' field is a Stream$cons, which is essentially a linked list and is therefore serialized recursively. If a partition contains, say, 10000 files, recursing into a linked list of length 10000 overflows the stack. The problem occurs only for bucketed partitions; the corresponding implementation for non-bucketed partitions uses a StreamBuffer. The proposed change applies the same approach to bucketed partitions.

## How was this patch tested?

Existing unit tests, plus a new unit test that fails without the patch. Also verified by manual testing on the dataset used to reproduce the problem.

Closes apache#24865 from parthchandra/SPARK-27100.

Lead-authored-by: Parth Chandra <[email protected]>
Co-authored-by: Dongjoon Hyun <[email protected]>
Signed-off-by: DB Tsai <[email protected]>
1 parent 929d313 commit 5a7aa6f
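
The root cause is easy to reproduce outside Spark. Below is a minimal, standalone Scala sketch (not part of this commit; the object and value names are made up for illustration) that serializes the same elements once as an Array and once as a materialized Stream using plain Java serialization. Because a Stream is a linked list of cons cells, the serializer recurses one level per element, and with a typical default JVM stack size the Stream case is likely to fail with StackOverflowError, while the Array case is not.

```scala
import java.io.{ByteArrayOutputStream, ObjectOutputStream}

// Standalone illustration of the SPARK-27100 failure mode; not Spark code.
object StreamSerializationDemo {

  // Serialize with plain Java serialization and report the payload size,
  // or the StackOverflowError if the object graph recurses too deeply.
  private def trySerialize(label: String, obj: AnyRef): Unit = {
    val buffer = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(buffer)
    try {
      out.writeObject(obj)
      out.close()
      println(s"$label: serialized ${buffer.size()} bytes")
    } catch {
      case _: StackOverflowError => println(s"$label: StackOverflowError")
    }
  }

  def main(args: Array[String]): Unit = {
    val n = 100000

    // An Array is a flat object graph and is written iteratively,
    // so stack depth does not grow with the number of elements.
    trySerialize("Array", Array.tabulate(n)(_.toString))

    // A Stream is a linked list of cons cells; default serialization
    // walks it recursively, deepening the stack for every element.
    val stream = Stream.tabulate(n)(_.toString)
    stream.force // materialize every cell, like a fully listed file collection
    trySerialize("Stream", stream)
  }
}
```

The patch takes the same route for bucketed reads that the non-bucketed path already takes: keep the files in an Array, which serializes with bounded stack depth regardless of partition size.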

4 files changed: +59, -9 lines


sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
Lines changed: 6 additions & 6 deletions

@@ -176,7 +176,7 @@ case class FileSourceScanExec(
       metrics.filter(e => driverMetrics.contains(e._1)).values.toSeq)
   }
 
-  @transient private lazy val selectedPartitions: Seq[PartitionDirectory] = {
+  @transient private lazy val selectedPartitions: Array[PartitionDirectory] = {
     val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L)
     val startTime = System.nanoTime()
     val ret = relation.location.listFiles(partitionFilters, dataFilters)
@@ -185,7 +185,7 @@ case class FileSourceScanExec(
       (System.nanoTime() - startTime) + optimizerMetadataTimeNs)
     driverMetrics("metadataTime") = timeTakenMs
     ret
-  }
+  }.toArray
 
   /**
    * [[partitionFilters]] can contain subqueries whose results are available only at runtime so
@@ -377,7 +377,7 @@ case class FileSourceScanExec(
   private def createBucketedReadRDD(
       bucketSpec: BucketSpec,
       readFile: (PartitionedFile) => Iterator[InternalRow],
-      selectedPartitions: Seq[PartitionDirectory],
+      selectedPartitions: Array[PartitionDirectory],
       fsRelation: HadoopFsRelation): RDD[InternalRow] = {
     logInfo(s"Planning with ${bucketSpec.numBuckets} buckets")
     val filesGroupedToBuckets =
@@ -401,7 +401,7 @@ case class FileSourceScanExec(
     }
 
     val filePartitions = Seq.tabulate(bucketSpec.numBuckets) { bucketId =>
-      FilePartition(bucketId, prunedFilesGroupedToBuckets.getOrElse(bucketId, Nil))
+      FilePartition(bucketId, prunedFilesGroupedToBuckets.getOrElse(bucketId, Array.empty))
     }
 
     new FileScanRDD(fsRelation.sparkSession, readFile, filePartitions)
@@ -417,7 +417,7 @@ case class FileSourceScanExec(
    */
   private def createNonBucketedReadRDD(
       readFile: (PartitionedFile) => Iterator[InternalRow],
-      selectedPartitions: Seq[PartitionDirectory],
+      selectedPartitions: Array[PartitionDirectory],
       fsRelation: HadoopFsRelation): RDD[InternalRow] = {
     val openCostInBytes = fsRelation.sparkSession.sessionState.conf.filesOpenCostInBytes
     val maxSplitBytes =
@@ -440,7 +440,7 @@ case class FileSourceScanExec(
           partitionValues = partition.values
         )
       }
-    }.toArray.sortBy(_.length)(implicitly[Ordering[Long]].reverse)
+    }.sortBy(_.length)(implicitly[Ordering[Long]].reverse)
 
     val partitions =
       FilePartition.getFilePartitions(relation.sparkSession, splitFiles, maxSplitBytes)

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FilePartition.scala
Lines changed: 2 additions & 2 deletions

@@ -28,7 +28,7 @@ import org.apache.spark.sql.sources.v2.reader.InputPartition
  * A collection of file blocks that should be read as a single task
  * (possibly from multiple partitioned directories).
  */
-case class FilePartition(index: Int, files: Seq[PartitionedFile])
+case class FilePartition(index: Int, files: Array[PartitionedFile])
   extends Partition with InputPartition {
   override def preferredLocations(): Array[String] = {
     // Computes total number of bytes can be retrieved from each host.
@@ -62,7 +62,7 @@ object FilePartition extends Logging {
     def closePartition(): Unit = {
       if (currentFiles.nonEmpty) {
         // Copy to a new Array.
-        val newPartition = FilePartition(partitions.size, currentFiles.toArray.toSeq)
+        val newPartition = FilePartition(partitions.size, currentFiles.toArray)
         partitions += newPartition
       }
       currentFiles.clear()
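
Given the new signature, a quick way to convince yourself the recursion is gone is to round-trip a large FilePartition through Java serialization. The following is a hedged sketch, not part of the patch: it assumes the classes from this diff are accessible on your classpath (they live in Spark's internal org.apache.spark.sql.execution.datasources package) and reuses the PartitionedFile constructor shape seen in the FileSourceStrategySuite change below.

```scala
import java.io.{ByteArrayOutputStream, ObjectOutputStream}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.{FilePartition, PartitionedFile}

// Hypothetical check, not part of the commit: serialize a FilePartition that
// holds many files and confirm the Array-backed field does not blow the stack.
object FilePartitionSerializationCheck {
  def main(args: Array[String]): Unit = {
    // Roughly the file count reported in SPARK-27100.
    val files = Array.tabulate(10000) { i =>
      PartitionedFile(InternalRow.empty, s"fakePath$i", 0, 10, Array("host0"))
    }
    val partition = FilePartition(0, files)

    val out = new ObjectOutputStream(new ByteArrayOutputStream())
    // With the old Seq-backed field (when backed by a Stream) this write could
    // recurse once per file; with Array the elements are written iteratively.
    out.writeObject(partition)
    out.close()
    println(s"Serialized a partition with ${partition.files.length} files")
  }
}
```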

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
Lines changed: 1 addition & 1 deletion

@@ -279,7 +279,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
   }
 
   test("Locality support for FileScanRDD") {
-    val partition = FilePartition(0, Seq(
+    val partition = FilePartition(0, Array(
       PartitionedFile(InternalRow.empty, "fakePath0", 0, 10, Array("host0", "host1")),
       PartitionedFile(InternalRow.empty, "fakePath0", 10, 20, Array("host1", "host2")),
       PartitionedFile(InternalRow.empty, "fakePath1", 0, 5, Array("host3")),

sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
Lines changed: 50 additions & 0 deletions

@@ -735,4 +735,54 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
         df1.groupBy("j").agg(max("k")))
     }
   }
+
+  // A test with a partition where the number of files in the partition is
+  // large. tests for the condition where the serialization of such a task may result in a stack
+  // overflow if the files list is stored in a recursive data structure
+  // This test is ignored because it takes long to run (~3 min)
+  ignore("SPARK-27100 stack overflow: read data with large partitions") {
+    val nCount = 20000
+    // reshuffle data so that many small files are created
+    val nShufflePartitions = 10000
+    // and with one table partition, should result in 10000 files in one partition
+    val nPartitions = 1
+    val nBuckets = 2
+    val dfPartitioned = (0 until nCount)
+      .map(i => (i % nPartitions, i % nBuckets, i.toString)).toDF("i", "j", "k")
+
+    // non-bucketed tables. This part succeeds without the fix for SPARK-27100
+    try {
+      withTable("non_bucketed_table") {
+        dfPartitioned.repartition(nShufflePartitions)
+          .write
+          .format("parquet")
+          .partitionBy("i")
+          .saveAsTable("non_bucketed_table")
+
+        val table = spark.table("non_bucketed_table")
+        val nValues = table.select("j", "k").count()
+        assert(nValues == nCount)
+      }
+    } catch {
+      case e: Exception => fail("Failed due to exception: " + e)
+    }
+    // bucketed tables. This fails without the fix for SPARK-27100
+    try {
+      withTable("bucketed_table") {
+        dfPartitioned.repartition(nShufflePartitions)
+          .write
+          .format("parquet")
+          .partitionBy("i")
+          .bucketBy(nBuckets, "j")
+          .saveAsTable("bucketed_table")
+
+        val table = spark.table("bucketed_table")
+        val nValues = table.select("j", "k").count()
+        assert(nValues == nCount)
+      }
+    } catch {
+      case e: Exception => fail("Failed due to exception: " + e)
+    }
+  }
+
 }
