@@ -19,6 +19,7 @@ package org.apache.spark.storage
1919
2020import java .io ._
2121import java .nio .ByteBuffer
22+ import java .util
2223import java .util .UUID
2324import java .util .concurrent .{CompletableFuture , Semaphore }
2425import java .util .zip .CheckedInputStream
@@ -153,6 +154,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
153154 val in = mock(classOf [InputStream ])
154155 when(in.read(any())).thenReturn(1 )
155156 when(in.read(any(), any(), any())).thenReturn(1 )
157+ val buf = ByteBuffer .allocate(size)
158+ util.Arrays .fill(buf.array(), 1 .byteValue)
159+ when(mockManagedBuffer.nioByteBuffer()).thenReturn(buf)
156160 when(mockManagedBuffer.createInputStream()).thenReturn(in)
157161 when(mockManagedBuffer.size()).thenReturn(size)
158162 mockManagedBuffer
@@ -342,7 +346,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
342346 ShuffleBlockId (0 , 9 , 0 ) -> createMockManagedBuffer(),
343347 ShuffleBlockId (0 , 10 , 0 ) -> createMockManagedBuffer())
344348 fallbackBlocks.foreach { case (blockId, buf) =>
345- doReturn(buf).when(blockManager).getLocalBlockData (meq(blockId))
349+ doReturn(buf).when(blockManager).getFallbackStorageBlockData (meq(blockId))
346350 }
347351
348352 val iterator = createShuffleBlockIteratorWithDefaults(
@@ -355,9 +359,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
355359 blockManager = Some (blockManager)
356360 )
357361
358- // 3 local blocks and 2 fallback blocks fetched in initialization
359- verify(blockManager, times(3 + 2 )).getLocalBlockData(any())
362+ // 3 local blocks fetched in initialization
363+ verify(blockManager, times(3 )).getLocalBlockData(any())
360364
365+ // 2 fallback storage blocks fetched in initialization
366+ verify(blockManager, times(2 )).getFallbackStorageBlockData(any())
361367 // SPARK-55469: but buffer data have never been materialized
362368 fallbackBlocks.values.foreach { mockBuf =>
363369 verify(mockBuf, never()).nioByteBuffer()
@@ -374,7 +380,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
374380
375381 // Make sure we release buffers when a wrapped input stream is closed.
376382 val mockBuf = allBlocks(blockId)
377- verifyBufferRelease(mockBuf, inputStream)
383+ if (! fallbackBlocks.contains(blockId)) {
384+ verifyBufferRelease(mockBuf, inputStream)
385+ }
378386 }
379387 assert(! iterator.hasNext)
380388
@@ -388,8 +396,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
388396
389397 // SPARK-55469: fallback buffer data have been materialized once
390398 fallbackBlocks.values.foreach { mockBuf =>
391- verify(mockBuf, never( )).nioByteBuffer()
392- verify(mockBuf, times( 1 )).createInputStream()
399+ verify(mockBuf, times( 1 )).nioByteBuffer()
400+ verify(mockBuf, never( )).createInputStream()
393401 verify(mockBuf, never()).convertToNetty()
394402 verify(mockBuf, never()).convertToNettyForSsl()
395403 }
@@ -504,7 +512,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
504512 val mergedFallbackBlocks = Map [BlockId , ManagedBuffer ](
505513 ShuffleBlockBatchId (0 , 1 , 0 , 2 ) -> createMockManagedBuffer())
506514 mergedFallbackBlocks.foreach { case (blockId, buf) =>
507- doReturn(buf).when(blockManager).getLocalBlockData (meq(blockId))
515+ doReturn(buf).when(blockManager).getFallbackStorageBlockData (meq(blockId))
508516 }
509517
510518 // Make sure remote blocks would return the merged block
@@ -546,9 +554,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
546554 doBatchFetch = true
547555 )
548556
549- // 1 local merge block and 1 fallback merge block fetched in initialization
550- verify(blockManager, times(1 + 1 )).getLocalBlockData(any())
557+ // 1 local merge block fetched in initialization
558+ verify(blockManager, times(1 )).getLocalBlockData(any())
551559
560+ // 1 fallback merge block fetched in initialization
561+ verify(blockManager, times(1 )).getFallbackStorageBlockData(any())
552562 // SPARK-55469: but buffer data have never been materialized
553563 mergedFallbackBlocks.values.foreach { mockBuf =>
554564 verify(mockBuf, never()).nioByteBuffer()
@@ -565,23 +575,27 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
565575 verifyFetchBlocksInvocationCount(1 )
566576 // Make sure we release buffers when a wrapped input stream is closed.
567577 val mockBuf = allBlocks(blockId)
568- verifyBufferRelease(mockBuf, inputStream)
578+ if (! mergedFallbackBlocks.contains(blockId)) {
579+ verifyBufferRelease(mockBuf, inputStream)
580+ }
569581 }
570582 assert(! iterator.hasNext)
571583
572584 // 1 merged host-local locks fetched
573585 verify(blockManager, times(1 ))
574586 .getHostLocalShuffleData(any(), meq(Array (" local-dir" )))
575587
588+ // 1 merged remote block is read from the same block manager
589+ verifyFetchBlocksInvocationCount(1 )
590+ assert(blockManager.hostLocalDirManager.get.getCachedHostLocalDirs.size === 1 )
591+
576592 // SPARK-55469: merged fallback buffer data have been materialized once
577593 mergedFallbackBlocks.values.foreach { mockBuf =>
578- verify(mockBuf, never( )).nioByteBuffer()
579- verify(mockBuf, times( 1 )).createInputStream()
594+ verify(mockBuf, times( 1 )).nioByteBuffer()
595+ verify(mockBuf, never( )).createInputStream()
580596 verify(mockBuf, never()).convertToNetty()
581597 verify(mockBuf, never()).convertToNettyForSsl()
582598 }
583-
584- assert(blockManager.hostLocalDirManager.get.getCachedHostLocalDirs.size === 1 )
585599 }
586600
587601 test(" fetch continuous blocks in batch should respect maxBytesInFlight" ) {
@@ -2139,46 +2153,4 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
21392153 assert(iterator.next()._1 === ShuffleBlockId (0 , 1 , 0 ))
21402154 assert(! iterator.hasNext)
21412155 }
2142-
2143- test(" Fast fail when failed to get fallback storage blocks" ) {
2144- val blockManager = createMockBlockManager()
2145-
2146- // Make sure blockManager.getBlockData would return the blocks
2147- val localBmId = blockManager.blockManagerId
2148- val localBlocks = Map [BlockId , ManagedBuffer ](
2149- ShuffleBlockId (0 , 0 , 0 ) -> createMockManagedBuffer(),
2150- ShuffleBlockId (0 , 1 , 0 ) -> createMockManagedBuffer())
2151- localBlocks.foreach { case (blockId, buf) =>
2152- doReturn(buf).when(blockManager).getLocalBlockData(meq(blockId))
2153- }
2154-
2155- // Make sure fallback storage would return the blocks
2156- val fallbackBmId = FallbackStorage .FALLBACK_BLOCK_MANAGER_ID
2157- val fallbackBlocks = Map [BlockId , ManagedBuffer ](
2158- ShuffleBlockId (0 , 2 , 0 ) -> createMockManagedBuffer(),
2159- ShuffleBlockId (0 , 3 , 0 ) -> createMockManagedBuffer())
2160- fallbackBlocks.take(1 ).foreach { case (blockId, buf) =>
2161- doReturn(buf).when(blockManager).getLocalBlockData(meq(blockId))
2162- }
2163- fallbackBlocks.takeRight(1 ).foreach { case (blockId, _) =>
2164- doThrow(new RuntimeException (" Cannot read from fallback storage" ))
2165- .when(blockManager).getLocalBlockData(meq(blockId))
2166- }
2167-
2168- val iterator = createShuffleBlockIteratorWithDefaults(
2169- Map (
2170- localBmId -> toBlockList(localBlocks.keys, 1L , 0 ),
2171- fallbackBmId -> toBlockList(fallbackBlocks.keys, 1L , 1 )
2172- ),
2173- blockManager = Some (blockManager)
2174- )
2175-
2176- // Fetch failure should be placed in the head of results, exception should be thrown for the
2177- // 1st instance.
2178- intercept[FetchFailedException ] { iterator.next() }
2179- assert(iterator.next()._1 === ShuffleBlockId (0 , 0 , 0 ))
2180- assert(iterator.next()._1 === ShuffleBlockId (0 , 1 , 0 ))
2181- assert(iterator.next()._1 === ShuffleBlockId (0 , 2 , 0 ))
2182- assert(! iterator.hasNext)
2183- }
21842156}
0 commit comments