Skip to content

Commit 851362e

Browse files
committed
Add unit tests
1 parent af689fc commit 851362e

File tree

2 files changed

+167
-20
lines changed

2 files changed

+167
-20
lines changed

core/src/test/scala/org/apache/spark/storage/FallbackStorageSuite.scala

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,11 @@ import java.nio.file.Files
2222
import scala.concurrent.duration._
2323
import scala.util.Random
2424

25+
import io.netty.buffer.ByteBuf
2526
import org.apache.hadoop.conf.Configuration
26-
import org.apache.hadoop.fs.{FSDataInputStream, LocalFileSystem, Path, PositionedReadable, Seekable}
27+
import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, LocalFileSystem, Path, PositionedReadable, Seekable}
2728
import org.mockito.{ArgumentMatchers => mc}
28-
import org.mockito.Mockito.{mock, never, verify, when}
29+
import org.mockito.Mockito.{mock, never, spy, times, verify, when}
2930
import org.scalatest.concurrent.Eventually.{eventually, interval, timeout}
3031

3132
import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite, TestUtils}
@@ -110,7 +111,9 @@ class FallbackStorageSuite extends SparkFunSuite with LocalSparkContext {
110111
intercept[java.io.EOFException] {
111112
FallbackStorage.read(conf, ShuffleBlockId(1, 1L, 0))
112113
}
113-
FallbackStorage.read(conf, ShuffleBlockId(1, 2L, 0))
114+
val readResult = FallbackStorage.read(conf, ShuffleBlockId(1, 2L, 0))
115+
assert(readResult.isInstanceOf[FileSystemSegmentManagedBuffer])
116+
readResult.createInputStream().close()
114117
}
115118

116119
test("SPARK-39200: fallback storage APIs - readFully") {
@@ -155,9 +158,49 @@ class FallbackStorageSuite extends SparkFunSuite with LocalSparkContext {
155158
assert(fallbackStorage.exists(1, ShuffleDataBlockId(1, 2L, NOOP_REDUCE_ID).name))
156159

157160
val readResult = FallbackStorage.read(conf, ShuffleBlockId(1, 2L, 0))
161+
assert(readResult.isInstanceOf[FileSystemSegmentManagedBuffer])
158162
assert(readResult.nioByteBuffer().array().sameElements(content))
159163
}
160164

165+
test("SPARK-55469: FileSystemSegmentManagedBuffer reads block data lazily") {
166+
withTempDir { dir =>
167+
val fs = FileSystem.getLocal(new Configuration())
168+
val file = new Path(dir.getAbsolutePath, "file")
169+
val data = Array[Byte](1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
170+
tryWithResource(fs.create(file)) { os => os.write(data) }
171+
172+
Seq((0, 4), (1, 2), (4, 4), (7, 2), (8, 0)).foreach { case (offset, length) =>
173+
val clue = s"offset: $offset, length: $length"
174+
175+
// creating the managed buffer does not open the file
176+
val mfs = spy(fs)
177+
val buf = new FileSystemSegmentManagedBuffer(mfs, file, offset, length)
178+
verify(mfs, never()).open(mc.any[Path]())
179+
assert(buf.size() === length, clue)
180+
181+
// creating the input stream opens the file
182+
{
183+
val bytes = buf.createInputStream().readAllBytes()
184+
verify(mfs, times(1)).open(mc.any[Path]())
185+
assert(bytes.mkString(",") === data.slice(offset, offset + length).mkString(","), clue)
186+
}
187+
188+
// getting a NIO ByteBuffer opens the file again
189+
{
190+
val bytes = buf.nioByteBuffer().array()
191+
verify(mfs, times(2)).open(mc.any[Path]())
192+
assert(bytes.mkString(",") === data.slice(offset, offset + length).mkString(","), clue)
193+
}
194+
195+
// getting Netty ByteBufs opens the file again and again
196+
assert(buf.convertToNetty().asInstanceOf[ByteBuf].release() === length > 0, clue)
197+
verify(mfs, times(3)).open(mc.any[Path]())
198+
assert(buf.convertToNettyForSsl().asInstanceOf[ByteBuf].release() === length > 0, clue)
199+
verify(mfs, times(4)).open(mc.any[Path]())
200+
}
201+
}
202+
}
203+
161204
test("SPARK-34142: fallback storage API - cleanUp app") {
162205
withTempDir { dir =>
163206
Seq(true, false).foreach { cleanUp =>
@@ -372,6 +415,7 @@ class FallbackStorageSuite extends SparkFunSuite with LocalSparkContext {
372415
}
373416
}
374417
}
418+
375419
class ReadPartialInputStream(val in: FSDataInputStream) extends InputStream
376420
with Seekable with PositionedReadable {
377421
override def read: Int = in.read

core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala

Lines changed: 120 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ import scala.concurrent.Future
3232
import io.netty.util.internal.OutOfDirectMemoryError
3333
import org.apache.logging.log4j.Level
3434
import org.mockito.ArgumentMatchers.{any, eq => meq}
35-
import org.mockito.Mockito.{doThrow, mock, times, verify, when}
35+
import org.mockito.Mockito.{doThrow, mock, never, times, verify, when}
3636
import org.mockito.invocation.InvocationOnMock
3737
import org.mockito.stubbing.Answer
3838
import org.roaringbitmap.RoaringBitmap
@@ -300,11 +300,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
300300
}
301301
}
302302

303-
test("successful 3 local + 4 host local + 2 remote reads") {
303+
test("successful 3 local + 4 host local + 2 remote + 2 fallback storage reads") {
304304
val blockManager = createMockBlockManager()
305-
val localBmId = blockManager.blockManagerId
306305

307306
// Make sure blockManager.getBlockData would return the blocks
307+
val localBmId = blockManager.blockManagerId
308308
val localBlocks = Map[BlockId, ManagedBuffer](
309309
ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(),
310310
ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(),
@@ -334,19 +334,37 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
334334
// returning local dir for hostLocalBmId
335335
initHostLocalDirManager(blockManager, hostLocalDirs)
336336

337+
// Make sure fallback storage blocks would return
338+
val fallbackBmId = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID
339+
val fallbackBlocks = Map[BlockId, ManagedBuffer](
340+
ShuffleBlockId(0, 9, 0) -> createMockManagedBuffer(),
341+
ShuffleBlockId(0, 10, 0) -> createMockManagedBuffer())
342+
fallbackBlocks.foreach { case (blockId, buf) =>
343+
doReturn(buf).when(blockManager).getLocalBlockData(meq(blockId))
344+
}
345+
337346
val iterator = createShuffleBlockIteratorWithDefaults(
338347
Map(
339348
localBmId -> toBlockList(localBlocks.keys, 1L, 0),
340349
remoteBmId -> toBlockList(remoteBlocks.keys, 1L, 1),
341-
hostLocalBmId -> toBlockList(hostLocalBlocks.keys, 1L, 1)
350+
hostLocalBmId -> toBlockList(hostLocalBlocks.keys, 1L, 1),
351+
fallbackBmId -> toBlockList(fallbackBlocks.keys, 1L, 1)
342352
),
343353
blockManager = Some(blockManager)
344354
)
345355

346-
// 3 local blocks fetched in initialization
347-
verify(blockManager, times(3)).getLocalBlockData(any())
356+
// 3 local blocks and 2 fallback blocks fetched in initialization
357+
verify(blockManager, times(3 + 2)).getLocalBlockData(any())
358+
359+
// SPARK-55469: but buffer data have never been materialized
360+
fallbackBlocks.values.foreach { mockBuf =>
361+
verify(mockBuf, never()).nioByteBuffer()
362+
verify(mockBuf, never()).createInputStream()
363+
verify(mockBuf, never()).convertToNetty()
364+
verify(mockBuf, never()).convertToNettyForSsl()
365+
}
348366

349-
val allBlocks = localBlocks ++ remoteBlocks ++ hostLocalBlocks
367+
val allBlocks = localBlocks ++ remoteBlocks ++ hostLocalBlocks ++ fallbackBlocks
350368
for (i <- 0 until allBlocks.size) {
351369
assert(iterator.hasNext,
352370
s"iterator should have ${allBlocks.size} elements but actually has $i elements")
@@ -356,14 +374,23 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
356374
val mockBuf = allBlocks(blockId)
357375
verifyBufferRelease(mockBuf, inputStream)
358376
}
377+
assert(!iterator.hasNext)
359378

360379
// 4 host-local locks fetched
361380
verify(blockManager, times(4))
362381
.getHostLocalShuffleData(any(), meq(Array("local-dir")))
363382

364-
// 2 remote blocks are read from the same block manager
383+
// 2 remote blocks are read from the same block manager in one fetch
365384
verifyFetchBlocksInvocationCount(1)
366385
assert(blockManager.hostLocalDirManager.get.getCachedHostLocalDirs.size === 1)
386+
387+
// SPARK-55469: fallback buffer data have been materialized once
388+
fallbackBlocks.values.foreach { mockBuf =>
389+
verify(mockBuf, never()).nioByteBuffer()
390+
verify(mockBuf, times(1)).createInputStream()
391+
verify(mockBuf, never()).convertToNetty()
392+
verify(mockBuf, never()).convertToNettyForSsl()
393+
}
367394
}
368395

369396
test("error during accessing host local dirs for executors") {
@@ -451,10 +478,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
451478
assert(!iterator.hasNext)
452479
}
453480

454-
test("fetch continuous blocks in batch successful 3 local + 4 host local + 2 remote reads") {
481+
test("fetch continuous blocks in batch successful 3 local + 4 host local + 2 remote + " +
482+
"2 fallback storage reads") {
455483
val blockManager = createMockBlockManager()
456-
val localBmId = blockManager.blockManagerId
484+
457485
// Make sure blockManager.getBlockData would return the merged block
486+
val localBmId = blockManager.blockManagerId
458487
val localBlocks = Seq[BlockId](
459488
ShuffleBlockId(0, 0, 0),
460489
ShuffleBlockId(0, 0, 1),
@@ -465,6 +494,17 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
465494
doReturn(buf).when(blockManager).getLocalBlockData(meq(blockId))
466495
}
467496

497+
// Make sure fallback storage would return the merged block
498+
val fallbackBmId = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID
499+
val fallbackBlocks = Seq[BlockId](
500+
ShuffleBlockId(0, 1, 0),
501+
ShuffleBlockId(0, 1, 1))
502+
val mergedFallbackBlocks = Map[BlockId, ManagedBuffer](
503+
ShuffleBlockBatchId(0, 1, 0, 2) -> createMockManagedBuffer())
504+
mergedFallbackBlocks.foreach { case (blockId, buf) =>
505+
doReturn(buf).when(blockManager).getLocalBlockData(meq(blockId))
506+
}
507+
468508
// Make sure remote blocks would return the merged block
469509
val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
470510
val remoteBlocks = Seq[BlockId](
@@ -496,30 +536,49 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
496536
val iterator = createShuffleBlockIteratorWithDefaults(
497537
Map(
498538
localBmId -> toBlockList(localBlocks, 1L, 0),
539+
fallbackBmId -> toBlockList(fallbackBlocks, 1L, 1),
499540
remoteBmId -> toBlockList(remoteBlocks, 1L, 1),
500541
hostLocalBmId -> toBlockList(hostLocalBlocks.keys, 1L, 1)
501542
),
502543
blockManager = Some(blockManager),
503544
doBatchFetch = true
504545
)
505546

506-
// 3 local blocks batch fetched in initialization
507-
verify(blockManager, times(1)).getLocalBlockData(any())
547+
// 1 local merge block and 1 fallback merge block fetched in initialization
548+
verify(blockManager, times(1 + 1)).getLocalBlockData(any())
508549

509-
val allBlocks = mergedLocalBlocks ++ mergedRemoteBlocks ++ mergedHostLocalBlocks
510-
for (i <- 0 until 3) {
511-
assert(iterator.hasNext, s"iterator should have 3 elements but actually has $i elements")
550+
// SPARK-55469: but buffer data have never been materialized
551+
mergedFallbackBlocks.values.foreach { mockBuf =>
552+
verify(mockBuf, never()).nioByteBuffer()
553+
verify(mockBuf, never()).createInputStream()
554+
verify(mockBuf, never()).convertToNetty()
555+
verify(mockBuf, never()).convertToNettyForSsl()
556+
}
557+
558+
val allBlocks = mergedLocalBlocks ++ mergedFallbackBlocks ++ mergedRemoteBlocks ++
559+
mergedHostLocalBlocks
560+
for (i <- 0 until 4) {
561+
assert(iterator.hasNext, s"iterator should have 4 elements but actually has $i elements")
512562
val (blockId, inputStream) = iterator.next()
513563
verifyFetchBlocksInvocationCount(1)
514564
// Make sure we release buffers when a wrapped input stream is closed.
515565
val mockBuf = allBlocks(blockId)
516566
verifyBufferRelease(mockBuf, inputStream)
517567
}
568+
assert(!iterator.hasNext)
518569

519-
// 4 host-local locks fetched
570+
// 1 merged host-local block fetched
520571
verify(blockManager, times(1))
521572
.getHostLocalShuffleData(any(), meq(Array("local-dir")))
522573

574+
// SPARK-55469: merged fallback buffer data have been materialized once
575+
mergedFallbackBlocks.values.foreach { mockBuf =>
576+
verify(mockBuf, never()).nioByteBuffer()
577+
verify(mockBuf, times(1)).createInputStream()
578+
verify(mockBuf, never()).convertToNetty()
579+
verify(mockBuf, never()).convertToNettyForSsl()
580+
}
581+
523582
assert(blockManager.hostLocalDirManager.get.getCachedHostLocalDirs.size === 1)
524583
}
525584

@@ -1051,6 +1110,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
10511110
val mockBuf = remoteBlocks(blockId)
10521111
verifyBufferRelease(mockBuf, inputStream)
10531112
}
1113+
assert(!iterator.hasNext)
10541114

10551115
// 1st fetch request (contains 1 block) would fail due to Netty OOM
10561116
// 2nd fetch request retry the block of the 1st fetch request
@@ -1091,6 +1151,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
10911151
val mockBuf = remoteBlocks(blockId)
10921152
verifyBufferRelease(mockBuf, inputStream)
10931153
}
1154+
assert(!iterator.hasNext)
10941155

10951156
// 1st fetch request (contains 3 blocks) would fail on the someone block due to Netty OOM
10961157
// but succeed for the remaining blocks
@@ -2037,9 +2098,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
20372098

20382099
test("SPARK-52395: Fast fail when failed to get host local dirs") {
20392100
val blockManager = createMockBlockManager()
2040-
val localBmId = blockManager.blockManagerId
20412101

20422102
// Make sure blockManager.getBlockData would return the blocks
2103+
val localBmId = blockManager.blockManagerId
20432104
val localBlocks = Map[BlockId, ManagedBuffer](
20442105
ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(),
20452106
ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer())
@@ -2076,4 +2137,46 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
20762137
assert(iterator.next()._1 === ShuffleBlockId(0, 1, 0))
20772138
assert(!iterator.hasNext)
20782139
}
2140+
2141+
test("SPARK-52395: Fast fail when failed to get fallback storage blocks") {
2142+
val blockManager = createMockBlockManager()
2143+
2144+
// Make sure blockManager.getBlockData would return the blocks
2145+
val localBmId = blockManager.blockManagerId
2146+
val localBlocks = Map[BlockId, ManagedBuffer](
2147+
ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(),
2148+
ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer())
2149+
localBlocks.foreach { case (blockId, buf) =>
2150+
doReturn(buf).when(blockManager).getLocalBlockData(meq(blockId))
2151+
}
2152+
2153+
// Make sure fallback storage would return the blocks
2154+
val fallbackBmId = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID
2155+
val fallbackBlocks = Map[BlockId, ManagedBuffer](
2156+
ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer(),
2157+
ShuffleBlockId(0, 3, 0) -> createMockManagedBuffer())
2158+
fallbackBlocks.take(1).foreach { case (blockId, buf) =>
2159+
doReturn(buf).when(blockManager).getLocalBlockData(meq(blockId))
2160+
}
2161+
fallbackBlocks.takeRight(1).foreach { case (blockId, _) =>
2162+
doThrow(new RuntimeException("Cannot read from fallback storage"))
2163+
.when(blockManager).getLocalBlockData(meq(blockId))
2164+
}
2165+
2166+
val iterator = createShuffleBlockIteratorWithDefaults(
2167+
Map(
2168+
localBmId -> toBlockList(localBlocks.keys, 1L, 0),
2169+
fallbackBmId -> toBlockList(fallbackBlocks.keys, 1L, 1)
2170+
),
2171+
blockManager = Some(blockManager)
2172+
)
2173+
2174+
// Fetch failure should be placed in the head of results, exception should be thrown for the
2175+
// 1st instance.
2176+
intercept[FetchFailedException] { iterator.next() }
2177+
assert(iterator.next()._1 === ShuffleBlockId(0, 0, 0))
2178+
assert(iterator.next()._1 === ShuffleBlockId(0, 1, 0))
2179+
assert(iterator.next()._1 === ShuffleBlockId(0, 2, 0))
2180+
assert(!iterator.hasNext)
2181+
}
20792182
}

0 commit comments

Comments
 (0)