Skip to content

Commit 083f508

Browse files
committed
Eagerly cancel rpc request
1 parent 01f6e25 commit 083f508

File tree

6 files changed

+130
-76
lines changed

6 files changed

+130
-76
lines changed

client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,8 @@ private boolean sendShuffleDataAsync(
200200
stageAttemptNumber,
201201
retryMax,
202202
retryIntervalMax,
203-
shuffleIdToBlocks);
203+
shuffleIdToBlocks,
204+
needCancelRequest);
204205
long s = System.currentTimeMillis();
205206
RssSendShuffleDataResponse response =
206207
getShuffleServerClient(ssi).sendShuffleData(request);

integration-test/common/src/test/java/org/apache/uniffle/test/RpcClientRetryTest.java

Lines changed: 101 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,18 @@
2020
import java.io.File;
2121
import java.util.List;
2222
import java.util.Map;
23+
import java.util.concurrent.CompletableFuture;
24+
import java.util.concurrent.TimeUnit;
25+
import java.util.concurrent.atomic.AtomicBoolean;
26+
import java.util.function.Supplier;
2327
import java.util.stream.Stream;
2428

2529
import com.google.common.collect.Lists;
2630
import com.google.common.collect.Maps;
27-
import org.junit.jupiter.api.AfterAll;
28-
import org.junit.jupiter.api.BeforeAll;
31+
import org.awaitility.Awaitility;
32+
import org.junit.jupiter.api.AfterEach;
33+
import org.junit.jupiter.api.BeforeEach;
34+
import org.junit.jupiter.api.Test;
2935
import org.junit.jupiter.api.io.TempDir;
3036
import org.junit.jupiter.params.ParameterizedTest;
3137
import org.junit.jupiter.params.provider.Arguments;
@@ -44,21 +50,16 @@
4450
import org.apache.uniffle.common.ShuffleServerInfo;
4551
import org.apache.uniffle.common.rpc.ServerType;
4652
import org.apache.uniffle.coordinator.CoordinatorConf;
47-
import org.apache.uniffle.coordinator.CoordinatorServer;
4853
import org.apache.uniffle.server.MockedGrpcServer;
4954
import org.apache.uniffle.server.MockedShuffleServer;
50-
import org.apache.uniffle.server.ShuffleServer;
5155
import org.apache.uniffle.server.ShuffleServerConf;
5256
import org.apache.uniffle.storage.util.StorageType;
5357

5458
import static org.junit.jupiter.api.Assertions.assertEquals;
5559
import static org.junit.jupiter.api.Assertions.fail;
5660

5761
public class RpcClientRetryTest extends ShuffleReadWriteBase {
58-
59-
private static ShuffleServerInfo shuffleServerInfo0;
60-
private static ShuffleServerInfo shuffleServerInfo1;
61-
private static ShuffleServerInfo shuffleServerInfo2;
62+
private static List<ShuffleServerInfo> grpcShuffleServerInfoList = Lists.newArrayList();
6263
private static MockedShuffleWriteClientImpl shuffleWriteClientImpl;
6364

6465
private ShuffleClientFactory.ReadClientBuilder baseReadBuilder(StorageType storageType) {
@@ -73,9 +74,9 @@ private ShuffleClientFactory.ReadClientBuilder baseReadBuilder(StorageType stora
7374
.readBufferSize(1000);
7475
}
7576

76-
public static MockedShuffleServer createMockedShuffleServer(int id, File tmpDir)
77-
throws Exception {
78-
ShuffleServerConf shuffleServerConf = getShuffleServerConf(ServerType.GRPC);
77+
public static MockedShuffleServer createMockedShuffleServer(
78+
int id, File tmpDir, ServerType serverType) throws Exception {
79+
ShuffleServerConf shuffleServerConf = getShuffleServerConf(serverType);
7980
File dataDir1 = new File(tmpDir, id + "_1");
8081
File dataDir2 = new File(tmpDir, id + "_2");
8182
String basePath = dataDir1.getAbsolutePath() + "," + dataDir2.getAbsolutePath();
@@ -85,46 +86,70 @@ public static MockedShuffleServer createMockedShuffleServer(int id, File tmpDir)
8586
shuffleServerConf.set(ShuffleServerConf.SERVER_MEMORY_SHUFFLE_HIGHWATERMARK_PERCENTAGE, 15.0);
8687
shuffleServerConf.set(ShuffleServerConf.SERVER_BUFFER_CAPACITY, 600L);
8788
shuffleServerConf.set(ShuffleServerConf.SINGLE_BUFFER_FLUSH_BLOCKS_NUM_THRESHOLD, 1);
89+
shuffleServerConf.set(ShuffleServerConf.RPC_SERVER_PORT, 0);
8890
return new MockedShuffleServer(shuffleServerConf);
8991
}
9092

91-
@BeforeAll
92-
public static void initCluster(@TempDir File tmpDir) throws Exception {
93+
@BeforeEach
94+
public void initCluster(@TempDir File tmpDir) throws Exception {
9395
CoordinatorConf coordinatorConf = getCoordinatorConf();
9496
createCoordinatorServer(coordinatorConf);
97+
for (int i = 0; i < 1; i++) {
98+
grpcShuffleServers.add(createMockedShuffleServer(i, tmpDir, ServerType.GRPC));
99+
}
95100

96-
grpcShuffleServers.add(createMockedShuffleServer(0, tmpDir));
97-
grpcShuffleServers.add(createMockedShuffleServer(1, tmpDir));
98-
grpcShuffleServers.add(createMockedShuffleServer(2, tmpDir));
101+
startServers();
99102

100-
shuffleServerInfo0 =
101-
new ShuffleServerInfo(
102-
String.format("127.0.0.1-%s", grpcShuffleServers.get(0).getGrpcPort()),
103-
grpcShuffleServers.get(0).getIp(),
104-
grpcShuffleServers.get(0).getGrpcPort());
105-
shuffleServerInfo1 =
106-
new ShuffleServerInfo(
107-
String.format("127.0.0.1-%s", grpcShuffleServers.get(1).getGrpcPort()),
108-
grpcShuffleServers.get(1).getIp(),
109-
grpcShuffleServers.get(1).getGrpcPort());
110-
shuffleServerInfo2 =
111-
new ShuffleServerInfo(
112-
String.format("127.0.0.1-%s", grpcShuffleServers.get(2).getGrpcPort()),
113-
grpcShuffleServers.get(2).getIp(),
114-
grpcShuffleServers.get(2).getGrpcPort());
115-
for (CoordinatorServer coordinator : coordinators) {
116-
coordinator.start();
117-
}
118-
for (ShuffleServer shuffleServer : grpcShuffleServers) {
119-
shuffleServer.start();
103+
for (int i = 0; i < 1; i++) {
104+
grpcShuffleServerInfoList.add(
105+
new ShuffleServerInfo(
106+
String.format("127.0.0.1-%s", grpcShuffleServers.get(i).getGrpcPort()),
107+
grpcShuffleServers.get(i).getIp(),
108+
grpcShuffleServers.get(i).getGrpcPort()));
120109
}
121110
}
122111

123-
@AfterAll
124-
public static void cleanEnv() throws Exception {
112+
@Test
113+
public void testCancelGrpc() throws InterruptedException {
114+
String testAppId = "testCancelGrpc";
115+
registerShuffleServer(testAppId, 1, 1, 1, false, 3000);
116+
Map<Long, byte[]> expectedData = Maps.newHashMap();
117+
Roaring64NavigableMap blockIdBitmap = Roaring64NavigableMap.bitmapOf();
118+
List<ShuffleBlockInfo> blocks =
119+
createShuffleBlockList(
120+
0,
121+
0,
122+
0,
123+
2,
124+
25,
125+
blockIdBitmap,
126+
expectedData,
127+
Lists.newArrayList(grpcShuffleServerInfoList.get(0)));
128+
AtomicBoolean isCancel = new AtomicBoolean(false);
129+
Supplier<Boolean> needCancelRequest = () -> isCancel.get();
130+
SendShuffleDataResult result =
131+
shuffleWriteClientImpl.sendShuffleData(testAppId, blocks, needCancelRequest);
132+
assertEquals(2, result.getSuccessBlockIds().size());
133+
134+
enableFirstNSendDataRequestsToFail(2);
135+
System.out.println("1====================");
136+
CompletableFuture<SendShuffleDataResult> future =
137+
CompletableFuture.supplyAsync(
138+
() -> shuffleWriteClientImpl.sendShuffleData(testAppId, blocks, needCancelRequest));
139+
// this ensure isCancel takes effect in rpc retry
140+
TimeUnit.SECONDS.sleep(1);
141+
isCancel.set(true);
142+
Awaitility.await()
143+
.atMost(5, TimeUnit.SECONDS)
144+
.until(() -> future.isDone() && future.get().getSuccessBlockIds().size() == 0);
145+
}
146+
147+
@AfterEach
148+
public void cleanEnv() throws Exception {
125149
if (shuffleWriteClientImpl != null) {
126150
shuffleWriteClientImpl.close();
127151
}
152+
grpcShuffleServerInfoList.clear();
128153
shutdownServers();
129154
}
130155

@@ -140,22 +165,16 @@ private static Stream<Arguments> testRpcRetryLogicProvider() {
140165
@MethodSource("testRpcRetryLogicProvider")
141166
public void testRpcRetryLogic(StorageType storageType) {
142167
String testAppId = "testRpcRetryLogic";
143-
registerShuffleServer(testAppId, 3, 2, 2, true);
168+
registerShuffleServer(testAppId, 3, 2, 2, true, 1000);
144169
Map<Long, byte[]> expectedData = Maps.newHashMap();
145170
Roaring64NavigableMap blockIdBitmap = Roaring64NavigableMap.bitmapOf();
146171

147172
List<ShuffleBlockInfo> blocks =
148173
createShuffleBlockList(
149-
0,
150-
0,
151-
0,
152-
3,
153-
25,
154-
blockIdBitmap,
155-
expectedData,
156-
Lists.newArrayList(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2));
174+
0, 0, 0, 3, 25, blockIdBitmap, expectedData, grpcShuffleServerInfoList);
157175

158-
SendShuffleDataResult result = shuffleWriteClientImpl.sendShuffleData(testAppId, blocks);
176+
SendShuffleDataResult result =
177+
shuffleWriteClientImpl.sendShuffleData(testAppId, blocks, () -> false);
159178
Roaring64NavigableMap failedBlockIdBitmap = Roaring64NavigableMap.bitmapOf();
160179
Roaring64NavigableMap successfulBlockIdBitmap = Roaring64NavigableMap.bitmapOf();
161180
for (Long blockId : result.getSuccessBlockIds()) {
@@ -174,8 +193,7 @@ public void testRpcRetryLogic(StorageType storageType) {
174193
.appId(testAppId)
175194
.blockIdBitmap(blockIdBitmap)
176195
.taskIdBitmap(taskIdBitmap)
177-
.shuffleServerInfoList(
178-
Lists.newArrayList(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2))
196+
.shuffleServerInfoList(grpcShuffleServerInfoList)
179197
.retryMax(3)
180198
.retryIntervalMax(1)
181199
.build();
@@ -195,8 +213,7 @@ public void testRpcRetryLogic(StorageType storageType) {
195213
.appId(testAppId)
196214
.blockIdBitmap(blockIdBitmap)
197215
.taskIdBitmap(taskIdBitmap)
198-
.shuffleServerInfoList(
199-
Lists.newArrayList(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2))
216+
.shuffleServerInfoList(grpcShuffleServerInfoList)
200217
.retryMax(3)
201218
.retryIntervalMax(1)
202219
.build();
@@ -208,39 +225,55 @@ public void testRpcRetryLogic(StorageType storageType) {
208225
}
209226

210227
private static void enableFirstNReadRequestsToFail(int failedCount) {
211-
for (ShuffleServer server : grpcShuffleServers) {
212-
((MockedGrpcServer) server.getServer())
213-
.getService()
214-
.enableFirstNReadRequestToFail(failedCount);
215-
}
228+
Lists.newArrayList(grpcShuffleServers, nettyShuffleServers).stream()
229+
.flatMap(List::stream)
230+
.forEach(
231+
server ->
232+
((MockedGrpcServer) server.getServer())
233+
.getService()
234+
.enableFirstNReadRequestToFail(failedCount));
235+
}
236+
237+
private static void enableFirstNSendDataRequestsToFail(int failedCount) {
238+
Lists.newArrayList(grpcShuffleServers, nettyShuffleServers).stream()
239+
.flatMap(List::stream)
240+
.forEach(
241+
server ->
242+
((MockedGrpcServer) server.getServer())
243+
.getService()
244+
.enableFirstNSendDataRequestToFail(failedCount));
216245
}
217246

218247
private static void disableFirstNReadRequestsToFail() {
219-
for (ShuffleServer server : grpcShuffleServers) {
220-
((MockedGrpcServer) server.getServer()).getService().resetFirstNReadRequestToFail();
221-
}
248+
Lists.newArrayList(grpcShuffleServers, nettyShuffleServers).stream()
249+
.flatMap(List::stream)
250+
.forEach(
251+
server ->
252+
((MockedGrpcServer) server.getServer())
253+
.getService()
254+
.resetFirstNReadRequestToFail());
222255
}
223256

224257
static class MockedShuffleWriteClientImpl extends ShuffleWriteClientImpl {
225258
MockedShuffleWriteClientImpl(ShuffleClientFactory.WriteClientBuilder builder) {
226259
super(builder);
227260
}
228-
229-
public SendShuffleDataResult sendShuffleData(
230-
String appId, List<ShuffleBlockInfo> shuffleBlockInfoList) {
231-
return super.sendShuffleData(appId, shuffleBlockInfoList, () -> false);
232-
}
233261
}
234262

235263
private void registerShuffleServer(
236-
String testAppId, int replica, int replicaWrite, int replicaRead, boolean replicaSkip) {
264+
String testAppId,
265+
int replica,
266+
int replicaWrite,
267+
int replicaRead,
268+
boolean replicaSkip,
269+
long retryIntervalMs) {
237270

238271
shuffleWriteClientImpl =
239272
new MockedShuffleWriteClientImpl(
240273
ShuffleClientFactory.newWriteBuilder()
241274
.clientType(ClientType.GRPC.name())
242275
.retryMax(3)
243-
.retryIntervalMax(1000)
276+
.retryIntervalMax(retryIntervalMs)
244277
.heartBeatThreadNum(1)
245278
.replica(replica)
246279
.replicaWrite(replicaWrite)
@@ -252,12 +285,9 @@ private void registerShuffleServer(
252285
.unregisterTimeSec(10)
253286
.unregisterRequestTimeSec(10));
254287

255-
List<ShuffleServerInfo> allServers =
256-
Lists.newArrayList(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2);
257-
258288
for (int i = 0; i < replica; i++) {
259289
shuffleWriteClientImpl.registerShuffle(
260-
allServers.get(i),
290+
grpcShuffleServerInfoList.get(i),
261291
testAppId,
262292
0,
263293
Lists.newArrayList(new PartitionRange(0, 0)),

internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -647,7 +647,10 @@ public RssSendShuffleDataResponse sendShuffleData(RssSendShuffleDataRequest requ
647647
null,
648648
request.getRetryIntervalMax(),
649649
maxRetryAttempts,
650-
t -> !(t instanceof OutOfMemoryError) && !(t instanceof NotRetryException));
650+
t ->
651+
!request.needCancel()
652+
&& !(t instanceof OutOfMemoryError)
653+
&& !(t instanceof NotRetryException));
651654
} catch (Throwable throwable) {
652655
LOG.warn("Failed to send shuffle data due to ", throwable);
653656
isSuccessful = false;

internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,10 @@ public RssSendShuffleDataResponse sendShuffleData(RssSendShuffleDataRequest requ
235235
null,
236236
request.getRetryIntervalMax(),
237237
maxRetryAttempts,
238-
t -> !(t instanceof OutOfMemoryError) && !(t instanceof NotRetryException));
238+
t ->
239+
!request.needCancel()
240+
&& !(t instanceof OutOfMemoryError)
241+
&& !(t instanceof NotRetryException));
239242
} catch (Throwable throwable) {
240243
LOG.warn("Failed to send shuffle data due to ", throwable);
241244
isSuccessful = false;

internal-client/src/main/java/org/apache/uniffle/client/request/RssSendShuffleDataRequest.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import java.util.List;
2121
import java.util.Map;
22+
import java.util.function.Supplier;
2223

2324
import org.apache.uniffle.common.ShuffleBlockInfo;
2425

@@ -29,26 +30,29 @@ public class RssSendShuffleDataRequest {
2930
private int retryMax;
3031
private long retryIntervalMax;
3132
private Map<Integer, Map<Integer, List<ShuffleBlockInfo>>> shuffleIdToBlocks;
33+
private Supplier<Boolean> needCancel;
3234

3335
public RssSendShuffleDataRequest(
3436
String appId,
3537
int retryMax,
3638
long retryIntervalMax,
3739
Map<Integer, Map<Integer, List<ShuffleBlockInfo>>> shuffleIdToBlocks) {
38-
this(appId, 0, retryMax, retryIntervalMax, shuffleIdToBlocks);
40+
this(appId, 0, retryMax, retryIntervalMax, shuffleIdToBlocks, () -> false);
3941
}
4042

4143
public RssSendShuffleDataRequest(
4244
String appId,
4345
int stageAttemptNumber,
4446
int retryMax,
4547
long retryIntervalMax,
46-
Map<Integer, Map<Integer, List<ShuffleBlockInfo>>> shuffleIdToBlocks) {
48+
Map<Integer, Map<Integer, List<ShuffleBlockInfo>>> shuffleIdToBlocks,
49+
Supplier<Boolean> needCancel) {
4750
this.appId = appId;
4851
this.retryMax = retryMax;
4952
this.retryIntervalMax = retryIntervalMax;
5053
this.shuffleIdToBlocks = shuffleIdToBlocks;
5154
this.stageAttemptNumber = stageAttemptNumber;
55+
this.needCancel = needCancel;
5256
}
5357

5458
public String getAppId() {
@@ -70,4 +74,8 @@ public int getStageAttemptNumber() {
7074
public Map<Integer, Map<Integer, List<ShuffleBlockInfo>>> getShuffleIdToBlocks() {
7175
return shuffleIdToBlocks;
7276
}
77+
78+
public Boolean needCancel() {
79+
return needCancel.get();
80+
}
7381
}

server/src/test/java/org/apache/uniffle/server/MockedShuffleServerGrpcService.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ public class MockedShuffleServerGrpcService extends ShuffleServerGrpcService {
4545

4646
private boolean mockSendDataFailed = false;
4747
private int mockSendDataFailedStageNumber = -1;
48+
private AtomicInteger failedSendDataRequest = new AtomicInteger(0);
4849

4950
private boolean mockRequireBufferFailedWithNoBuffer = false;
5051
private boolean isMockRequireBufferFailedWithNoBufferForHugePartition = false;
@@ -86,6 +87,10 @@ public void enableFirstNReadRequestToFail(int n) {
8687
numOfFailedReadRequest = n;
8788
}
8889

90+
public void enableFirstNSendDataRequestToFail(int n) {
91+
failedSendDataRequest.set(n);
92+
}
93+
8994
public void resetFirstNReadRequestToFail() {
9095
numOfFailedReadRequest = 0;
9196
failedGetShuffleResultRequest.set(0);
@@ -146,6 +151,10 @@ public void sendShuffleData(
146151
mockSendDataFailedStageNumber);
147152
throw new RuntimeException("This write request is failed as mocked failure!");
148153
}
154+
if (failedSendDataRequest.getAndDecrement() > 0) {
155+
LOG.info("This request is failed as mocked failure");
156+
throw new RuntimeException("This write request is failed as mocked failure!");
157+
}
149158
if (mockedTimeout > 0) {
150159
LOG.info("Add a mocked timeout on sendShuffleData");
151160
Uninterruptibles.sleepUninterruptibly(mockedTimeout, TimeUnit.MILLISECONDS);

0 commit comments

Comments
 (0)