Skip to content

Commit 660fe53

Browse files
authored
RATIS-2325. Create GrpcStubPool for GrpcServerProtocolClient (#1306)
1 parent 81c714d commit 660fe53

File tree

4 files changed

+181
-4
lines changed

4 files changed

+181
-4
lines changed

ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcConfigKeys.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,15 @@ static GrpcTlsConfig tlsConf(Parameters parameters) {
282282
static void setTlsConf(Parameters parameters, GrpcTlsConfig conf) {
283283
parameters.put(TLS_CONF_PARAMETER, conf, TLS_CONF_CLASS);
284284
}
285+
286+
String STUB_POOL_SIZE_KEY = PREFIX + ".stub.pool.size";
287+
int STUB_POOL_SIZE_DEFAULT = 1;
288+
static int stubPoolSize(RaftProperties properties) {
289+
return get(properties::getInt, STUB_POOL_SIZE_KEY, STUB_POOL_SIZE_DEFAULT, getDefaultLog());
290+
}
291+
static void setStubPoolSize(RaftProperties properties, int size) {
292+
setInt(properties::setInt, STUB_POOL_SIZE_KEY, size);
293+
}
285294
}
286295

287296
String MESSAGE_SIZE_MAX_KEY = PREFIX + ".message.size.max";

ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolClient.java

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
class GrpcServerProtocolClient implements Closeable {
4646
// Common channel
4747
private final ManagedChannel channel;
48+
private final GrpcStubPool<RaftServerProtocolServiceStub> pool;
4849
// Channel and stub for heartbeat
4950
private ManagedChannel hbChannel;
5051
private RaftServerProtocolServiceStub hbAsyncStub;
@@ -57,7 +58,7 @@ class GrpcServerProtocolClient implements Closeable {
5758
//visible for using in log / error messages AND to use in instrumented tests
5859
private final RaftPeerId raftPeerId;
5960

60-
GrpcServerProtocolClient(RaftPeer target, int flowControlWindow,
61+
GrpcServerProtocolClient(RaftPeer target, int connections, int flowControlWindow,
6162
TimeDuration requestTimeout, SslContext sslContext, boolean separateHBChannel) {
6263
raftPeerId = target.getId();
6364
LOG.info("Build channel for {}", target);
@@ -70,6 +71,11 @@ class GrpcServerProtocolClient implements Closeable {
7071
hbAsyncStub = RaftServerProtocolServiceGrpc.newStub(hbChannel);
7172
}
7273
requestTimeoutDuration = requestTimeout;
74+
this.pool = connections == 1? null : newGrpcStubPool(target.getAddress(), sslContext, connections);
75+
}
76+
77+
GrpcStubPool<RaftServerProtocolServiceStub> newGrpcStubPool(String address, SslContext sslContext, int connections) {
78+
return new GrpcStubPool<>(connections, address, sslContext, RaftServerProtocolServiceGrpc::newStub, 16);
7379
}
7480

7581
private ManagedChannel buildChannel(RaftPeer target, int flowControlWindow, SslContext sslContext) {
@@ -94,6 +100,9 @@ public void close() {
94100
GrpcUtil.shutdownManagedChannel(hbChannel);
95101
}
96102
GrpcUtil.shutdownManagedChannel(channel);
103+
if (pool != null) {
104+
pool.close();
105+
}
97106
}
98107

99108
public RequestVoteReplyProto requestVote(RequestVoteRequestProto request) {
@@ -112,8 +121,44 @@ public StartLeaderElectionReplyProto startLeaderElection(StartLeaderElectionRequ
112121
}
113122

114123
void readIndex(ReadIndexRequestProto request, StreamObserver<ReadIndexReplyProto> s) {
115-
asyncStub.withDeadlineAfter(requestTimeoutDuration.getDuration(), requestTimeoutDuration.getUnit())
116-
.readIndex(request, s);
124+
if (pool == null) {
125+
asyncStub.withDeadlineAfter(requestTimeoutDuration.getDuration(), requestTimeoutDuration.getUnit())
126+
.readIndex(request, s);
127+
} else {
128+
GrpcStubPool.Stub<RaftServerProtocolServiceStub> p;
129+
try {
130+
p = pool.acquire();
131+
} catch (InterruptedException e) {
132+
Thread.currentThread().interrupt();
133+
s.onError(e);
134+
return;
135+
}
136+
p.getStub().withDeadlineAfter(requestTimeoutDuration.getDuration(), requestTimeoutDuration.getUnit())
137+
.readIndex(request, new StreamObserver<ReadIndexReplyProto>() {
138+
@Override
139+
public void onNext(ReadIndexReplyProto v) {
140+
s.onNext(v);
141+
}
142+
143+
@Override
144+
public void onError(Throwable t) {
145+
try {
146+
s.onError(t);
147+
} finally {
148+
p.release();
149+
}
150+
}
151+
152+
@Override
153+
public void onCompleted() {
154+
try {
155+
s.onCompleted();
156+
} finally {
157+
p.release();
158+
}
159+
}
160+
});
161+
}
117162
}
118163

119164
CallStreamObserver<AppendEntriesRequestProto> appendEntries(

ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServicesImpl.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ public static final class Builder {
108108
private int serverPort;
109109
private SslContext serverSslContextForServer;
110110
private SslContext serverSslContextForClient;
111+
private int serverStubPoolSize;
111112

112113
private SizeInBytes messageSizeMax;
113114
private SizeInBytes flowControlWindow;
@@ -130,6 +131,7 @@ public Builder setServer(RaftServer raftServer) {
130131
this.flowControlWindow = GrpcConfigKeys.flowControlWindow(properties, LOG::info);
131132
this.requestTimeoutDuration = RaftServerConfigKeys.Rpc.requestTimeout(properties);
132133
this.separateHeartbeatChannel = GrpcConfigKeys.Server.heartbeatChannel(properties);
134+
this.serverStubPoolSize = GrpcConfigKeys.Server.stubPoolSize(properties);
133135

134136
final SizeInBytes appenderBufferSize = RaftServerConfigKeys.Log.Appender.bufferByteLimit(properties);
135137
final SizeInBytes gap = SizeInBytes.ONE_MB;
@@ -150,7 +152,7 @@ public Builder setCustomizer(Customizer customizer) {
150152
}
151153

152154
private GrpcServerProtocolClient newGrpcServerProtocolClient(RaftPeer target) {
153-
return new GrpcServerProtocolClient(target, flowControlWindow.getSizeInt(),
155+
return new GrpcServerProtocolClient(target, serverStubPoolSize, flowControlWindow.getSizeInt(),
154156
requestTimeoutDuration, serverSslContextForClient, separateHeartbeatChannel);
155157
}
156158

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
package org.apache.ratis.grpc.server;
19+
20+
import org.apache.ratis.thirdparty.io.grpc.ManagedChannel;
21+
import org.apache.ratis.thirdparty.io.grpc.netty.NegotiationType;
22+
import org.apache.ratis.thirdparty.io.grpc.netty.NettyChannelBuilder;
23+
import org.apache.ratis.thirdparty.io.grpc.stub.AbstractStub;
24+
import org.apache.ratis.thirdparty.io.netty.channel.ChannelOption;
25+
import org.apache.ratis.thirdparty.io.netty.channel.WriteBufferWaterMark;
26+
import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContext;
27+
import org.apache.ratis.util.MemoizedSupplier;
28+
import org.apache.ratis.util.Preconditions;
29+
import org.slf4j.Logger;
30+
import org.slf4j.LoggerFactory;
31+
32+
import java.util.ArrayList;
33+
import java.util.Collections;
34+
import java.util.List;
35+
import java.util.concurrent.Semaphore;
36+
import java.util.concurrent.ThreadLocalRandom;
37+
import java.util.concurrent.TimeUnit;
38+
import java.util.function.Function;
39+
40+
final class GrpcStubPool<S extends AbstractStub<S>> {
41+
public static final Logger LOG = LoggerFactory.getLogger(GrpcStubPool.class);
42+
43+
static ManagedChannel buildManagedChannel(String address, SslContext sslContext) {
44+
NettyChannelBuilder channelBuilder = NettyChannelBuilder.forTarget(address)
45+
.keepAliveTime(10, TimeUnit.MINUTES)
46+
.keepAliveWithoutCalls(false)
47+
.idleTimeout(30, TimeUnit.MINUTES)
48+
.withOption(ChannelOption.WRITE_BUFFER_WATER_MARK, new WriteBufferWaterMark(64 << 10, 128 << 10));
49+
if (sslContext != null) {
50+
LOG.debug("Setting TLS for {}", address);
51+
channelBuilder.useTransportSecurity().sslContext(sslContext);
52+
} else {
53+
channelBuilder.negotiationType(NegotiationType.PLAINTEXT);
54+
}
55+
ManagedChannel ch = channelBuilder.build();
56+
ch.getState(true);
57+
return ch;
58+
}
59+
60+
static final class Stub<S extends AbstractStub<S>> {
61+
private final ManagedChannel ch;
62+
private final S stub;
63+
private final Semaphore permits;
64+
65+
Stub(String address, SslContext sslContext, Function<ManagedChannel, S> stubFactory, int maxInflight) {
66+
this.ch = buildManagedChannel(address, sslContext);
67+
this.stub = stubFactory.apply(ch);
68+
this.permits = new Semaphore(maxInflight);
69+
}
70+
71+
S getStub() {
72+
return stub;
73+
}
74+
75+
void release() {
76+
permits.release();
77+
}
78+
79+
void shutdown() {
80+
ch.shutdown();
81+
}
82+
}
83+
84+
private final List<MemoizedSupplier<Stub<S>>> pool;
85+
86+
GrpcStubPool(int connections, String address, SslContext sslContext, Function<ManagedChannel, S> stubFactory,
87+
int maxInflightPerConn) {
88+
Preconditions.assertTrue(connections > 1, "connections must be > 1");
89+
final List<MemoizedSupplier<Stub<S>>> tmpPool = new ArrayList<>(connections);
90+
for (int i = 0; i < connections; i++) {
91+
tmpPool.add(MemoizedSupplier.valueOf(() -> new Stub<>(address, sslContext, stubFactory, maxInflightPerConn)));
92+
}
93+
this.pool = Collections.unmodifiableList(tmpPool);
94+
}
95+
96+
Stub<S> getStub(int i) {
97+
return pool.get(i).get();
98+
}
99+
100+
Stub<S> acquire() throws InterruptedException {
101+
final int size = pool.size();
102+
final int start = ThreadLocalRandom.current().nextInt(size);
103+
for (int k = 0; k < size; k++) {
104+
Stub<S> p = getStub((start + k) % size);
105+
if (p.permits.tryAcquire()) {
106+
return p;
107+
}
108+
}
109+
final Stub<S> p = getStub(start);
110+
p.permits.acquire();
111+
return p;
112+
}
113+
114+
public void close() {
115+
for (MemoizedSupplier<Stub<S>> p : pool) {
116+
if (p.isInitialized()) {
117+
p.get().shutdown();
118+
}
119+
}
120+
}
121+
}

0 commit comments

Comments
 (0)