@@ -33,7 +33,7 @@ import com.google.common.io.ByteStreams
3333import io .netty .util .internal .OutOfDirectMemoryError
3434import org .apache .logging .log4j .Level
3535import org .mockito .ArgumentMatchers .{any , eq => meq }
36- import org .mockito .Mockito .{doThrow , mock , times , verify , when }
36+ import org .mockito .Mockito .{doThrow , mock , never , times , verify , when }
3737import org .mockito .invocation .InvocationOnMock
3838import org .mockito .stubbing .Answer
3939import org .roaringbitmap .RoaringBitmap
@@ -298,11 +298,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
298298 }
299299 }
300300
301- test(" successful 3 local + 4 host local + 2 remote reads" ) {
301+ test(" successful 3 local + 4 host local + 2 remote + 2 fallback storage reads" ) {
302302 val blockManager = createMockBlockManager()
303- val localBmId = blockManager.blockManagerId
304303
305304 // Make sure blockManager.getBlockData would return the blocks
305+ val localBmId = blockManager.blockManagerId
306306 val localBlocks = Map [BlockId , ManagedBuffer ](
307307 ShuffleBlockId (0 , 0 , 0 ) -> createMockManagedBuffer(),
308308 ShuffleBlockId (0 , 1 , 0 ) -> createMockManagedBuffer(),
@@ -332,19 +332,37 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
332332 // returning local dir for hostLocalBmId
333333 initHostLocalDirManager(blockManager, hostLocalDirs)
334334
335+ // Make sure fallback storage blocks would return
336+ val fallbackBmId = FallbackStorage .FALLBACK_BLOCK_MANAGER_ID
337+ val fallbackBlocks = Map [BlockId , ManagedBuffer ](
338+ ShuffleBlockId (0 , 9 , 0 ) -> createMockManagedBuffer(),
339+ ShuffleBlockId (0 , 10 , 0 ) -> createMockManagedBuffer())
340+ fallbackBlocks.foreach { case (blockId, buf) =>
341+ doReturn(buf).when(blockManager).getLocalBlockData(meq(blockId))
342+ }
343+
335344 val iterator = createShuffleBlockIteratorWithDefaults(
336345 Map (
337346 localBmId -> toBlockList(localBlocks.keys, 1L , 0 ),
338347 remoteBmId -> toBlockList(remoteBlocks.keys, 1L , 1 ),
339- hostLocalBmId -> toBlockList(hostLocalBlocks.keys, 1L , 1 )
348+ hostLocalBmId -> toBlockList(hostLocalBlocks.keys, 1L , 1 ),
349+ fallbackBmId -> toBlockList(fallbackBlocks.keys, 1L , 1 )
340350 ),
341351 blockManager = Some (blockManager)
342352 )
343353
344- // 3 local blocks fetched in initialization
345- verify(blockManager, times(3 )).getLocalBlockData(any())
354+ // 3 local blocks and 2 fallback blocks fetched in initialization
355+ verify(blockManager, times(3 + 2 )).getLocalBlockData(any())
346356
347- val allBlocks = localBlocks ++ remoteBlocks ++ hostLocalBlocks
357+ // SPARK-55469: but buffer data have never been materialized
358+ fallbackBlocks.values.foreach { mockBuf =>
359+ verify(mockBuf, never()).nioByteBuffer()
360+ verify(mockBuf, never()).createInputStream()
361+ verify(mockBuf, never()).convertToNetty()
362+ verify(mockBuf, never()).convertToNettyForSsl()
363+ }
364+
365+ val allBlocks = localBlocks ++ remoteBlocks ++ hostLocalBlocks ++ fallbackBlocks
348366 for (i <- 0 until allBlocks.size) {
349367 assert(iterator.hasNext,
350368 s " iterator should have ${allBlocks.size} elements but actually has $i elements " )
@@ -354,14 +372,23 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
354372 val mockBuf = allBlocks(blockId)
355373 verifyBufferRelease(mockBuf, inputStream)
356374 }
375+ assert(! iterator.hasNext)
357376
358378 // 4 host-local blocks fetched
359378 verify(blockManager, times(4 ))
360379 .getHostLocalShuffleData(any(), meq(Array (" local-dir" )))
361380
362- // 2 remote blocks are read from the same block manager
381+ // 2 remote blocks are read from the same block manager in one fetch
363382 verifyFetchBlocksInvocationCount(1 )
364383 assert(blockManager.hostLocalDirManager.get.getCachedHostLocalDirs.size === 1 )
384+
385+ // SPARK-55469: fallback buffer data have been materialized once
386+ fallbackBlocks.values.foreach { mockBuf =>
387+ verify(mockBuf, never()).nioByteBuffer()
388+ verify(mockBuf, times(1 )).createInputStream()
389+ verify(mockBuf, never()).convertToNetty()
390+ verify(mockBuf, never()).convertToNettyForSsl()
391+ }
365392 }
366393
367394 test(" error during accessing host local dirs for executors" ) {
@@ -447,10 +474,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
447474 assert(! iterator.hasNext)
448475 }
449476
450- test(" fetch continuous blocks in batch successful 3 local + 4 host local + 2 remote reads" ) {
477+ test(" fetch continuous blocks in batch successful 3 local + 4 host local + 2 remote + " +
478+ " 2 fallback storage reads" ) {
451479 val blockManager = createMockBlockManager()
452- val localBmId = blockManager.blockManagerId
480+
453481 // Make sure blockManager.getBlockData would return the merged block
482+ val localBmId = blockManager.blockManagerId
454483 val localBlocks = Seq [BlockId ](
455484 ShuffleBlockId (0 , 0 , 0 ),
456485 ShuffleBlockId (0 , 0 , 1 ),
@@ -461,6 +490,17 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
461490 doReturn(buf).when(blockManager).getLocalBlockData(meq(blockId))
462491 }
463492
493+ // Make sure fallback storage would return the merged block
494+ val fallbackBmId = FallbackStorage .FALLBACK_BLOCK_MANAGER_ID
495+ val fallbackBlocks = Seq [BlockId ](
496+ ShuffleBlockId (0 , 1 , 0 ),
497+ ShuffleBlockId (0 , 1 , 1 ))
498+ val mergedFallbackBlocks = Map [BlockId , ManagedBuffer ](
499+ ShuffleBlockBatchId (0 , 1 , 0 , 2 ) -> createMockManagedBuffer())
500+ mergedFallbackBlocks.foreach { case (blockId, buf) =>
501+ doReturn(buf).when(blockManager).getLocalBlockData(meq(blockId))
502+ }
503+
464504 // Make sure remote blocks would return the merged block
465505 val remoteBmId = BlockManagerId (" test-client-1" , " test-client-1" , 2 )
466506 val remoteBlocks = Seq [BlockId ](
@@ -492,30 +532,49 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
492532 val iterator = createShuffleBlockIteratorWithDefaults(
493533 Map (
494534 localBmId -> toBlockList(localBlocks, 1L , 0 ),
535+ fallbackBmId -> toBlockList(fallbackBlocks, 1L , 1 ),
495536 remoteBmId -> toBlockList(remoteBlocks, 1L , 1 ),
496537 hostLocalBmId -> toBlockList(hostLocalBlocks.keys, 1L , 1 )
497538 ),
498539 blockManager = Some (blockManager),
499540 doBatchFetch = true
500541 )
501542
502- // 3 local blocks batch fetched in initialization
503- verify(blockManager, times(1 )).getLocalBlockData(any())
543+ // 1 local merge block and 1 fallback merge block fetched in initialization
544+ verify(blockManager, times(1 + 1 )).getLocalBlockData(any())
545+
546+ // SPARK-55469: but buffer data have never been materialized
547+ mergedFallbackBlocks.values.foreach { mockBuf =>
548+ verify(mockBuf, never()).nioByteBuffer()
549+ verify(mockBuf, never()).createInputStream()
550+ verify(mockBuf, never()).convertToNetty()
551+ verify(mockBuf, never()).convertToNettyForSsl()
552+ }
504553
505- val allBlocks = mergedLocalBlocks ++ mergedRemoteBlocks ++ mergedHostLocalBlocks
506- for (i <- 0 until 3 ) {
507- assert(iterator.hasNext, s " iterator should have 3 elements but actually has $i elements " )
554+ val allBlocks = mergedLocalBlocks ++ mergedFallbackBlocks ++ mergedRemoteBlocks ++
555+ mergedHostLocalBlocks
556+ for (i <- 0 until 4 ) {
557+ assert(iterator.hasNext, s " iterator should have 4 elements but actually has $i elements " )
508558 val (blockId, inputStream) = iterator.next()
509559 verifyFetchBlocksInvocationCount(1 )
510560 // Make sure we release buffers when a wrapped input stream is closed.
511561 val mockBuf = allBlocks(blockId)
512562 verifyBufferRelease(mockBuf, inputStream)
513563 }
564+ assert(! iterator.hasNext)
514565
515- // 4 host-local locks fetched
566+ // 1 merged host-local block fetched
516567 verify(blockManager, times(1 ))
517568 .getHostLocalShuffleData(any(), meq(Array (" local-dir" )))
518569
570+ // SPARK-55469: merged fallback buffer data have been materialized once
571+ mergedFallbackBlocks.values.foreach { mockBuf =>
572+ verify(mockBuf, never()).nioByteBuffer()
573+ verify(mockBuf, times(1 )).createInputStream()
574+ verify(mockBuf, never()).convertToNetty()
575+ verify(mockBuf, never()).convertToNettyForSsl()
576+ }
577+
519578 assert(blockManager.hostLocalDirManager.get.getCachedHostLocalDirs.size === 1 )
520579 }
521580
@@ -1046,6 +1105,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
10461105 val mockBuf = remoteBlocks(blockId)
10471106 verifyBufferRelease(mockBuf, inputStream)
10481107 }
1108+ assert(! iterator.hasNext)
10491109
10501110 // 1st fetch request (contains 1 block) would fail due to Netty OOM
10511111 // 2nd fetch request retry the block of the 1st fetch request
@@ -1086,6 +1146,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
10861146 val mockBuf = remoteBlocks(blockId)
10871147 verifyBufferRelease(mockBuf, inputStream)
10881148 }
1149+ assert(! iterator.hasNext)
10891150
10901151 // 1st fetch request (contains 3 blocks) would fail on one of the blocks due to Netty OOM
10911152 // but succeed for the remaining blocks
0 commit comments