2020import java .io .File ;
2121import java .util .List ;
2222import java .util .Map ;
23+ import java .util .concurrent .CompletableFuture ;
24+ import java .util .concurrent .TimeUnit ;
25+ import java .util .concurrent .atomic .AtomicBoolean ;
26+ import java .util .function .Supplier ;
2327import java .util .stream .Stream ;
2428
2529import com .google .common .collect .Lists ;
2630import com .google .common .collect .Maps ;
27- import org .junit .jupiter .api .AfterAll ;
28- import org .junit .jupiter .api .BeforeAll ;
31+ import org .awaitility .Awaitility ;
32+ import org .junit .jupiter .api .AfterEach ;
33+ import org .junit .jupiter .api .BeforeEach ;
34+ import org .junit .jupiter .api .Test ;
2935import org .junit .jupiter .api .io .TempDir ;
3036import org .junit .jupiter .params .ParameterizedTest ;
3137import org .junit .jupiter .params .provider .Arguments ;
4450import org .apache .uniffle .common .ShuffleServerInfo ;
4551import org .apache .uniffle .common .rpc .ServerType ;
4652import org .apache .uniffle .coordinator .CoordinatorConf ;
47- import org .apache .uniffle .coordinator .CoordinatorServer ;
4853import org .apache .uniffle .server .MockedGrpcServer ;
4954import org .apache .uniffle .server .MockedShuffleServer ;
50- import org .apache .uniffle .server .ShuffleServer ;
5155import org .apache .uniffle .server .ShuffleServerConf ;
5256import org .apache .uniffle .storage .util .StorageType ;
5357
5458import static org .junit .jupiter .api .Assertions .assertEquals ;
5559import static org .junit .jupiter .api .Assertions .fail ;
5660
5761public class RpcClientRetryTest extends ShuffleReadWriteBase {
58-
59- private static ShuffleServerInfo shuffleServerInfo0 ;
60- private static ShuffleServerInfo shuffleServerInfo1 ;
61- private static ShuffleServerInfo shuffleServerInfo2 ;
62+ private static List <ShuffleServerInfo > grpcShuffleServerInfoList = Lists .newArrayList ();
6263 private static MockedShuffleWriteClientImpl shuffleWriteClientImpl ;
6364
6465 private ShuffleClientFactory .ReadClientBuilder baseReadBuilder (StorageType storageType ) {
@@ -73,9 +74,9 @@ private ShuffleClientFactory.ReadClientBuilder baseReadBuilder(StorageType stora
7374 .readBufferSize (1000 );
7475 }
7576
76- public static MockedShuffleServer createMockedShuffleServer (int id , File tmpDir )
77- throws Exception {
78- ShuffleServerConf shuffleServerConf = getShuffleServerConf (ServerType . GRPC );
77+ public static MockedShuffleServer createMockedShuffleServer (
78+ int id , File tmpDir , ServerType serverType ) throws Exception {
79+ ShuffleServerConf shuffleServerConf = getShuffleServerConf (serverType );
7980 File dataDir1 = new File (tmpDir , id + "_1" );
8081 File dataDir2 = new File (tmpDir , id + "_2" );
8182 String basePath = dataDir1 .getAbsolutePath () + "," + dataDir2 .getAbsolutePath ();
@@ -85,46 +86,70 @@ public static MockedShuffleServer createMockedShuffleServer(int id, File tmpDir)
8586 shuffleServerConf .set (ShuffleServerConf .SERVER_MEMORY_SHUFFLE_HIGHWATERMARK_PERCENTAGE , 15.0 );
8687 shuffleServerConf .set (ShuffleServerConf .SERVER_BUFFER_CAPACITY , 600L );
8788 shuffleServerConf .set (ShuffleServerConf .SINGLE_BUFFER_FLUSH_BLOCKS_NUM_THRESHOLD , 1 );
89+ shuffleServerConf .set (ShuffleServerConf .RPC_SERVER_PORT , 0 );
8890 return new MockedShuffleServer (shuffleServerConf );
8991 }
9092
91- @ BeforeAll
92- public static void initCluster (@ TempDir File tmpDir ) throws Exception {
93+ @ BeforeEach
94+ public void initCluster (@ TempDir File tmpDir ) throws Exception {
9395 CoordinatorConf coordinatorConf = getCoordinatorConf ();
9496 createCoordinatorServer (coordinatorConf );
97+ for (int i = 0 ; i < 1 ; i ++) {
98+ grpcShuffleServers .add (createMockedShuffleServer (i , tmpDir , ServerType .GRPC ));
99+ }
95100
96- grpcShuffleServers .add (createMockedShuffleServer (0 , tmpDir ));
97- grpcShuffleServers .add (createMockedShuffleServer (1 , tmpDir ));
98- grpcShuffleServers .add (createMockedShuffleServer (2 , tmpDir ));
101+ startServers ();
99102
100- shuffleServerInfo0 =
101- new ShuffleServerInfo (
102- String .format ("127.0.0.1-%s" , grpcShuffleServers .get (0 ).getGrpcPort ()),
103- grpcShuffleServers .get (0 ).getIp (),
104- grpcShuffleServers .get (0 ).getGrpcPort ());
105- shuffleServerInfo1 =
106- new ShuffleServerInfo (
107- String .format ("127.0.0.1-%s" , grpcShuffleServers .get (1 ).getGrpcPort ()),
108- grpcShuffleServers .get (1 ).getIp (),
109- grpcShuffleServers .get (1 ).getGrpcPort ());
110- shuffleServerInfo2 =
111- new ShuffleServerInfo (
112- String .format ("127.0.0.1-%s" , grpcShuffleServers .get (2 ).getGrpcPort ()),
113- grpcShuffleServers .get (2 ).getIp (),
114- grpcShuffleServers .get (2 ).getGrpcPort ());
115- for (CoordinatorServer coordinator : coordinators ) {
116- coordinator .start ();
117- }
118- for (ShuffleServer shuffleServer : grpcShuffleServers ) {
119- shuffleServer .start ();
103+ for (int i = 0 ; i < 1 ; i ++) {
104+ grpcShuffleServerInfoList .add (
105+ new ShuffleServerInfo (
106+ String .format ("127.0.0.1-%s" , grpcShuffleServers .get (i ).getGrpcPort ()),
107+ grpcShuffleServers .get (i ).getIp (),
108+ grpcShuffleServers .get (i ).getGrpcPort ()));
120109 }
121110 }
122111
123- @ AfterAll
124- public static void cleanEnv () throws Exception {
112+ @ Test
113+ public void testCancelGrpc () throws InterruptedException {
114+ String testAppId = "testCancelGrpc" ;
115+ registerShuffleServer (testAppId , 1 , 1 , 1 , false , 3000 );
116+ Map <Long , byte []> expectedData = Maps .newHashMap ();
117+ Roaring64NavigableMap blockIdBitmap = Roaring64NavigableMap .bitmapOf ();
118+ List <ShuffleBlockInfo > blocks =
119+ createShuffleBlockList (
120+ 0 ,
121+ 0 ,
122+ 0 ,
123+ 2 ,
124+ 25 ,
125+ blockIdBitmap ,
126+ expectedData ,
127+ Lists .newArrayList (grpcShuffleServerInfoList .get (0 )));
128+ AtomicBoolean isCancel = new AtomicBoolean (false );
129+ Supplier <Boolean > needCancelRequest = () -> isCancel .get ();
130+ SendShuffleDataResult result =
131+ shuffleWriteClientImpl .sendShuffleData (testAppId , blocks , needCancelRequest );
132+ assertEquals (2 , result .getSuccessBlockIds ().size ());
133+
134+ enableFirstNSendDataRequestsToFail (2 );
135+ System .out .println ("1====================" );
136+ CompletableFuture <SendShuffleDataResult > future =
137+ CompletableFuture .supplyAsync (
138+ () -> shuffleWriteClientImpl .sendShuffleData (testAppId , blocks , needCancelRequest ));
139+ // this ensure isCancel takes effect in rpc retry
140+ TimeUnit .SECONDS .sleep (1 );
141+ isCancel .set (true );
142+ Awaitility .await ()
143+ .atMost (5 , TimeUnit .SECONDS )
144+ .until (() -> future .isDone () && future .get ().getSuccessBlockIds ().size () == 0 );
145+ }
146+
147+ @ AfterEach
148+ public void cleanEnv () throws Exception {
125149 if (shuffleWriteClientImpl != null ) {
126150 shuffleWriteClientImpl .close ();
127151 }
152+ grpcShuffleServerInfoList .clear ();
128153 shutdownServers ();
129154 }
130155
@@ -140,22 +165,16 @@ private static Stream<Arguments> testRpcRetryLogicProvider() {
140165 @ MethodSource ("testRpcRetryLogicProvider" )
141166 public void testRpcRetryLogic (StorageType storageType ) {
142167 String testAppId = "testRpcRetryLogic" ;
143- registerShuffleServer (testAppId , 3 , 2 , 2 , true );
168+ registerShuffleServer (testAppId , 3 , 2 , 2 , true , 1000 );
144169 Map <Long , byte []> expectedData = Maps .newHashMap ();
145170 Roaring64NavigableMap blockIdBitmap = Roaring64NavigableMap .bitmapOf ();
146171
147172 List <ShuffleBlockInfo > blocks =
148173 createShuffleBlockList (
149- 0 ,
150- 0 ,
151- 0 ,
152- 3 ,
153- 25 ,
154- blockIdBitmap ,
155- expectedData ,
156- Lists .newArrayList (shuffleServerInfo0 , shuffleServerInfo1 , shuffleServerInfo2 ));
174+ 0 , 0 , 0 , 3 , 25 , blockIdBitmap , expectedData , grpcShuffleServerInfoList );
157175
158- SendShuffleDataResult result = shuffleWriteClientImpl .sendShuffleData (testAppId , blocks );
176+ SendShuffleDataResult result =
177+ shuffleWriteClientImpl .sendShuffleData (testAppId , blocks , () -> false );
159178 Roaring64NavigableMap failedBlockIdBitmap = Roaring64NavigableMap .bitmapOf ();
160179 Roaring64NavigableMap successfulBlockIdBitmap = Roaring64NavigableMap .bitmapOf ();
161180 for (Long blockId : result .getSuccessBlockIds ()) {
@@ -174,8 +193,7 @@ public void testRpcRetryLogic(StorageType storageType) {
174193 .appId (testAppId )
175194 .blockIdBitmap (blockIdBitmap )
176195 .taskIdBitmap (taskIdBitmap )
177- .shuffleServerInfoList (
178- Lists .newArrayList (shuffleServerInfo0 , shuffleServerInfo1 , shuffleServerInfo2 ))
196+ .shuffleServerInfoList (grpcShuffleServerInfoList )
179197 .retryMax (3 )
180198 .retryIntervalMax (1 )
181199 .build ();
@@ -195,8 +213,7 @@ public void testRpcRetryLogic(StorageType storageType) {
195213 .appId (testAppId )
196214 .blockIdBitmap (blockIdBitmap )
197215 .taskIdBitmap (taskIdBitmap )
198- .shuffleServerInfoList (
199- Lists .newArrayList (shuffleServerInfo0 , shuffleServerInfo1 , shuffleServerInfo2 ))
216+ .shuffleServerInfoList (grpcShuffleServerInfoList )
200217 .retryMax (3 )
201218 .retryIntervalMax (1 )
202219 .build ();
@@ -208,39 +225,55 @@ public void testRpcRetryLogic(StorageType storageType) {
208225 }
209226
210227 private static void enableFirstNReadRequestsToFail (int failedCount ) {
211- for (ShuffleServer server : grpcShuffleServers ) {
212- ((MockedGrpcServer ) server .getServer ())
213- .getService ()
214- .enableFirstNReadRequestToFail (failedCount );
215- }
228+ Lists .newArrayList (grpcShuffleServers , nettyShuffleServers ).stream ()
229+ .flatMap (List ::stream )
230+ .forEach (
231+ server ->
232+ ((MockedGrpcServer ) server .getServer ())
233+ .getService ()
234+ .enableFirstNReadRequestToFail (failedCount ));
235+ }
236+
237+ private static void enableFirstNSendDataRequestsToFail (int failedCount ) {
238+ Lists .newArrayList (grpcShuffleServers , nettyShuffleServers ).stream ()
239+ .flatMap (List ::stream )
240+ .forEach (
241+ server ->
242+ ((MockedGrpcServer ) server .getServer ())
243+ .getService ()
244+ .enableFirstNSendDataRequestToFail (failedCount ));
216245 }
217246
218247 private static void disableFirstNReadRequestsToFail () {
219- for (ShuffleServer server : grpcShuffleServers ) {
220- ((MockedGrpcServer ) server .getServer ()).getService ().resetFirstNReadRequestToFail ();
221- }
248+ Lists .newArrayList (grpcShuffleServers , nettyShuffleServers ).stream ()
249+ .flatMap (List ::stream )
250+ .forEach (
251+ server ->
252+ ((MockedGrpcServer ) server .getServer ())
253+ .getService ()
254+ .resetFirstNReadRequestToFail ());
222255 }
223256
224257 static class MockedShuffleWriteClientImpl extends ShuffleWriteClientImpl {
225258 MockedShuffleWriteClientImpl (ShuffleClientFactory .WriteClientBuilder builder ) {
226259 super (builder );
227260 }
228-
229- public SendShuffleDataResult sendShuffleData (
230- String appId , List <ShuffleBlockInfo > shuffleBlockInfoList ) {
231- return super .sendShuffleData (appId , shuffleBlockInfoList , () -> false );
232- }
233261 }
234262
235263 private void registerShuffleServer (
236- String testAppId , int replica , int replicaWrite , int replicaRead , boolean replicaSkip ) {
264+ String testAppId ,
265+ int replica ,
266+ int replicaWrite ,
267+ int replicaRead ,
268+ boolean replicaSkip ,
269+ long retryIntervalMs ) {
237270
238271 shuffleWriteClientImpl =
239272 new MockedShuffleWriteClientImpl (
240273 ShuffleClientFactory .newWriteBuilder ()
241274 .clientType (ClientType .GRPC .name ())
242275 .retryMax (3 )
243- .retryIntervalMax (1000 )
276+ .retryIntervalMax (retryIntervalMs )
244277 .heartBeatThreadNum (1 )
245278 .replica (replica )
246279 .replicaWrite (replicaWrite )
@@ -252,12 +285,9 @@ private void registerShuffleServer(
252285 .unregisterTimeSec (10 )
253286 .unregisterRequestTimeSec (10 ));
254287
255- List <ShuffleServerInfo > allServers =
256- Lists .newArrayList (shuffleServerInfo0 , shuffleServerInfo1 , shuffleServerInfo2 );
257-
258288 for (int i = 0 ; i < replica ; i ++) {
259289 shuffleWriteClientImpl .registerShuffle (
260- allServers .get (i ),
290+ grpcShuffleServerInfoList .get (i ),
261291 testAppId ,
262292 0 ,
263293 Lists .newArrayList (new PartitionRange (0 , 0 )),
0 commit comments