Skip to content

Commit f61d599

Browse files
LuciferYang authored and Ngone51 committed
[SPARK-36242][CORE] Ensure spill file closed before set success = true in ExternalSorter.spillMemoryIteratorToDisk method
### What changes were proposed in this pull request?

The main change of this PR is to move `writer.close()` before `success = true`, ensuring the spill file is closed before `success` is set to `true` in the `ExternalSorter.spillMemoryIteratorToDisk` method.

### Why are the changes needed?

To avoid setting `success = true` first and then failing to close the spill file.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

- Pass the Jenkins or GitHub Action checks.
- Add a new test case to check `The spill file should not exists if writer close fails`.

Closes apache#33460 from LuciferYang/external-sorter-spill-close.

Authored-by: yangjie01 <yangjie01@baidu.com>
Signed-off-by: yi.wu <yi.wu@databricks.com>
1 parent 3ff8c9f commit f61d599

File tree

2 files changed

+149
-3
lines changed

2 files changed

+149
-3
lines changed

core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -313,14 +313,13 @@ private[spark] class ExternalSorter[K, V, C](
313313
}
314314
if (objectsWritten > 0) {
315315
flush()
316+
writer.close()
316317
} else {
317318
writer.revertPartialWritesAndClose()
318319
}
319320
success = true
320321
} finally {
321-
if (success) {
322-
writer.close()
323-
} else {
322+
if (!success) {
324323
// This code path only happens if an exception was thrown above before we set success;
325324
// close our stuff and let the exception be thrown further
326325
writer.revertPartialWritesAndClose()
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.util.collection
19+
20+
import java.io.{File, IOException}
21+
import java.util.UUID
22+
23+
import scala.collection.mutable.ArrayBuffer
24+
25+
import org.mockito.ArgumentMatchers.{any, anyInt}
26+
import org.mockito.Mockito.{mock, when}
27+
import org.mockito.invocation.InvocationOnMock
28+
import org.scalatest.BeforeAndAfterEach
29+
30+
import org.apache.spark.{SparkConf, SparkEnv, SparkFunSuite, TaskContext}
31+
import org.apache.spark.executor.ShuffleWriteMetrics
32+
import org.apache.spark.internal.config
33+
import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager}
34+
import org.apache.spark.serializer.{KryoSerializer, SerializerInstance, SerializerManager}
35+
import org.apache.spark.storage.{BlockId, BlockManager, DiskBlockManager, DiskBlockObjectWriter, TempShuffleBlockId}
36+
import org.apache.spark.util.{Utils => UUtils}
37+
38+
/**
 * Regression suite for SPARK-36242: a spill file must be closed (or the write
 * reverted) before `success = true` is set in
 * `ExternalSorter.spillMemoryIteratorToDisk`, so that a failure in
 * `DiskBlockObjectWriter.close()` does not leave a stale spill file on disk.
 *
 * The suite mocks `SparkEnv`, `BlockManager` and `DiskBlockManager` so the
 * sorter runs against temp files under `tempDir`, and tracks every temp
 * shuffle file it hands out in `spillFilesCreated`.
 */
class ExternalSorterSpillSuite extends SparkFunSuite with BeforeAndAfterEach {

  // Every temp shuffle file created via the mocked DiskBlockManager, so tests
  // can check which spill files were left behind.
  private val spillFilesCreated = ArrayBuffer.empty[File]

  private var tempDir: File = _
  private var conf: SparkConf = _
  private var taskMemoryManager: TaskMemoryManager = _

  private var blockManager: BlockManager = _
  private var diskBlockManager: DiskBlockManager = _
  private var taskContext: TaskContext = _

  override protected def beforeEach(): Unit = {
    tempDir = UUtils.createTempDir(null, "test")
    spillFilesCreated.clear()

    // Install a mocked SparkEnv so ExternalSorter's SparkEnv.get lookups work.
    val env: SparkEnv = mock(classOf[SparkEnv])
    SparkEnv.set(env)

    conf = new SparkConf()
    when(SparkEnv.get.conf).thenReturn(conf)

    val serializer = new KryoSerializer(conf)
    when(SparkEnv.get.serializer).thenReturn(serializer)

    blockManager = mock(classOf[BlockManager])
    when(SparkEnv.get.blockManager).thenReturn(blockManager)

    // Real SerializerManager, so actual (de)serialization paths are exercised.
    val manager = new SerializerManager(serializer, conf)
    when(blockManager.serializerManager).thenReturn(manager)

    diskBlockManager = mock(classOf[DiskBlockManager])
    when(blockManager.diskBlockManager).thenReturn(diskBlockManager)

    taskContext = mock(classOf[TaskContext])
    val memoryManager = new TestMemoryManager(conf)
    taskMemoryManager = new TaskMemoryManager(memoryManager, 0)
    when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager)

    // Hand out real temp files for temp shuffle blocks and remember each one
    // so the test can later assert whether the spill file survived.
    when(diskBlockManager.createTempShuffleBlock())
      .thenAnswer((_: InvocationOnMock) => {
        val blockId = TempShuffleBlockId(UUID.randomUUID)
        val file = File.createTempFile("spillFile", ".spill", tempDir)
        spillFilesCreated += file
        (blockId, file)
      })
  }

  override protected def afterEach(): Unit = {
    UUtils.deleteRecursively(tempDir)
    SparkEnv.set(null)

    // Fail loudly if a test leaked task memory, to catch sorter cleanup bugs.
    val leakedMemory = taskMemoryManager.cleanUpAllAllocatedMemory
    if (leakedMemory != 0) {
      fail("Test leaked " + leakedMemory + " bytes of managed memory")
    }
  }

  test("SPARK-36242 Spill File should not exists if writer close fails") {
    // Prepare the data and ensure that the amount of data let the `spill()` method
    // to enter the `objectsWritten > 0` branch
    val writeSize = conf.get(config.SHUFFLE_SPILL_BATCH_SIZE) + 1
    val dataBuffer = new PartitionedPairBuffer[Int, Int]
    (0 until writeSize.toInt).foreach(i => dataBuffer.insert(0, 0, i))

    val externalSorter = new TestExternalSorter[Int, Int, Int](taskContext)

    // Mock the answer of `blockManager.getDiskWriter` and let the `close()` method of
    // `DiskBlockObjectWriter` throw IOException.
    val errorMessage = "Spill file close failed"
    when(blockManager.getDiskWriter(
      any(classOf[BlockId]),
      any(classOf[File]),
      any(classOf[SerializerInstance]),
      anyInt(),
      any(classOf[ShuffleWriteMetrics])
    )).thenAnswer((invocation: InvocationOnMock) => {
      val args = invocation.getArguments
      new DiskBlockObjectWriter(
        args(1).asInstanceOf[File],
        blockManager.serializerManager,
        args(2).asInstanceOf[SerializerInstance],
        args(3).asInstanceOf[Int],
        false,
        args(4).asInstanceOf[ShuffleWriteMetrics],
        args(0).asInstanceOf[BlockId]
      ) {
        override def close(): Unit = throw new IOException(errorMessage)
      }
    })

    val ioe = intercept[IOException] {
      externalSorter.spill(dataBuffer)
    }

    // BUGFIX: the original code computed `ioe.getMessage.equals(errorMessage)`
    // and discarded the boolean, so the message was never actually checked.
    assert(ioe.getMessage.equals(errorMessage))
    // The `TempShuffleBlock` create by diskBlockManager
    // will remain before SPARK-36242
    assert(!spillFilesCreated(0).exists())
  }
}
139+
140+
/**
 * Test-only subclass of `ExternalSorter` that widens the visibility of the
 * protected `spill` method, so that suites can force a spill directly instead
 * of having to exhaust memory.
 */
private[this] class TestExternalSorter[K, V, C](context: TaskContext)
  extends ExternalSorter[K, V, C](context) {
  override def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = {
    super.spill(collection)
  }
}

0 commit comments

Comments
 (0)