Skip to content

Commit 4a9918b

Browse files
committed
Add async nonblocking ssl support in java client
1 parent b8f7e5b commit 4a9918b

File tree

5 files changed

+415
-10
lines changed

5 files changed

+415
-10
lines changed
Lines changed: 291 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,291 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.thrift.transport;
21+
22+
import java.io.IOException;
23+
import java.nio.ByteBuffer;
24+
import java.nio.channels.SelectionKey;
25+
import java.nio.channels.Selector;
26+
import javax.net.ssl.SSLContext;
27+
import javax.net.ssl.SSLEngine;
28+
import javax.net.ssl.SSLEngineResult;
29+
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
30+
import javax.net.ssl.SSLException;
31+
import org.slf4j.Logger;
32+
import org.slf4j.LoggerFactory;
33+
34+
/** Transport for use with async ssl client. */
35+
public class TNonblockingSSLSocket extends TNonblockingSocket implements SocketAddressProvider {
36+
37+
private static final Logger LOGGER =
38+
LoggerFactory.getLogger(TNonblockingSSLSocket.class.getName());
39+
40+
private final SSLEngine sslEngine_;
41+
42+
private final ByteBuffer appUnwrap;
43+
private final ByteBuffer netUnwrap;
44+
45+
private final ByteBuffer appWrap;
46+
private final ByteBuffer netWrap;
47+
48+
private ByteBuffer decodedBytes;
49+
50+
private boolean isHandshakeCompleted;
51+
52+
private SelectionKey selectionKey;
53+
54+
protected TNonblockingSSLSocket(String host, int port, int timeout, SSLContext sslContext)
55+
throws IOException, TTransportException {
56+
super(host, port, timeout);
57+
sslEngine_ = sslContext.createSSLEngine(host, port);
58+
sslEngine_.setUseClientMode(true);
59+
60+
int appBufferSize = sslEngine_.getSession().getApplicationBufferSize();
61+
int netBufferSize = sslEngine_.getSession().getPacketBufferSize();
62+
appUnwrap = ByteBuffer.allocate(appBufferSize);
63+
netUnwrap = ByteBuffer.allocate(netBufferSize);
64+
appWrap = ByteBuffer.allocate(appBufferSize);
65+
netWrap = ByteBuffer.allocate(netBufferSize);
66+
decodedBytes = ByteBuffer.allocate(appBufferSize);
67+
decodedBytes.flip();
68+
isHandshakeCompleted = false;
69+
}
70+
71+
/** {@inheritDoc} */
72+
@Override
73+
public SelectionKey registerSelector(Selector selector, int interests) throws IOException {
74+
selectionKey = super.registerSelector(selector, interests);
75+
return selectionKey;
76+
}
77+
78+
/** {@inheritDoc} */
79+
@Override
80+
public boolean isOpen() {
81+
// isConnected() does not return false after close(), but isOpen() does
82+
return super.isOpen() && isHandshakeCompleted;
83+
}
84+
85+
/** {@inheritDoc} */
86+
@Override
87+
public void open() throws TTransportException {
88+
throw new RuntimeException("open() is not implemented for TNonblockingSSLSocket");
89+
}
90+
91+
/** {@inheritDoc} */
92+
@Override
93+
public synchronized int read(ByteBuffer buffer) throws TTransportException {
94+
int numBytes = buffer.limit();
95+
while (decodedBytes.remaining() < numBytes) {
96+
try {
97+
if (doUnwrap() == -1) {
98+
throw new IOException("Unable to read " + numBytes + " bytes");
99+
}
100+
} catch (IOException exc) {
101+
throw new TTransportException(TTransportException.UNKNOWN, exc.getMessage());
102+
}
103+
if (appUnwrap.position() > 0) {
104+
int t;
105+
appUnwrap.flip();
106+
if (decodedBytes.position() > 0) decodedBytes.flip();
107+
t = appUnwrap.limit() + decodedBytes.limit();
108+
byte[] tmpBuffer = new byte[t];
109+
decodedBytes.get(tmpBuffer, 0, decodedBytes.remaining());
110+
appUnwrap.get(tmpBuffer, 0, appUnwrap.remaining());
111+
if (appUnwrap.position() > 0) {
112+
appUnwrap.clear();
113+
appUnwrap.flip();
114+
appUnwrap.compact();
115+
}
116+
decodedBytes = ByteBuffer.wrap(tmpBuffer);
117+
}
118+
}
119+
byte[] b = new byte[numBytes];
120+
decodedBytes.get(b, 0, numBytes);
121+
if (decodedBytes.position() > 0) {
122+
decodedBytes.compact();
123+
decodedBytes.flip();
124+
}
125+
buffer.put(b);
126+
selectionKey.interestOps(SelectionKey.OP_WRITE);
127+
return numBytes;
128+
}
129+
130+
/** {@inheritDoc} */
131+
@Override
132+
public synchronized int write(ByteBuffer buffer) throws TTransportException {
133+
int numBytes = 0;
134+
135+
if (buffer.position() > 0) buffer.flip();
136+
137+
int nTransfer;
138+
int num;
139+
while (buffer.remaining() != 0) {
140+
nTransfer = Math.min(appWrap.remaining(), buffer.remaining());
141+
if (nTransfer > 0) {
142+
appWrap.put(buffer.array(), buffer.arrayOffset() + buffer.position(), nTransfer);
143+
buffer.position(buffer.position() + nTransfer);
144+
}
145+
146+
try {
147+
num = doWrap();
148+
} catch (IOException iox) {
149+
throw new TTransportException(TTransportException.UNKNOWN, iox);
150+
}
151+
if (num < 0) {
152+
LOGGER.error("Failed while writing. Probably server is down");
153+
return -1;
154+
}
155+
numBytes += num;
156+
}
157+
return numBytes;
158+
}
159+
160+
/** {@inheritDoc} */
161+
@Override
162+
public void close() {
163+
sslEngine_.closeOutbound();
164+
super.close();
165+
}
166+
167+
/** {@inheritDoc} */
168+
@Override
169+
public boolean startConnect() throws IOException {
170+
if (this.isOpen()) {
171+
return true;
172+
}
173+
sslEngine_.beginHandshake();
174+
return super.startConnect() && doHandShake();
175+
}
176+
177+
/** {@inheritDoc} */
178+
@Override
179+
public boolean finishConnect() throws IOException {
180+
return super.finishConnect() && doHandShake();
181+
}
182+
183+
private synchronized boolean doHandShake() throws IOException {
184+
LOGGER.debug("Handshake is started");
185+
while (true) {
186+
HandshakeStatus hs = sslEngine_.getHandshakeStatus();
187+
switch (hs) {
188+
case NEED_UNWRAP:
189+
if (doUnwrap() == -1) {
190+
LOGGER.error("Unexpected. Handshake failed abruptly during unwrap");
191+
return false;
192+
}
193+
break;
194+
case NEED_WRAP:
195+
if (doWrap() == -1) {
196+
LOGGER.error("Unexpected. Handshake failed abruptly during wrap");
197+
return false;
198+
}
199+
break;
200+
case NEED_TASK:
201+
if (!doTask()) {
202+
LOGGER.error("Unexpected. Handshake failed abruptly during task");
203+
return false;
204+
}
205+
break;
206+
case FINISHED:
207+
case NOT_HANDSHAKING:
208+
isHandshakeCompleted = true;
209+
return true;
210+
default:
211+
LOGGER.error("Unknown handshake status. Handshake failed");
212+
return false;
213+
}
214+
}
215+
}
216+
217+
private synchronized boolean doTask() {
218+
Runnable runnable;
219+
while ((runnable = sslEngine_.getDelegatedTask()) != null) {
220+
runnable.run();
221+
}
222+
HandshakeStatus hs = sslEngine_.getHandshakeStatus();
223+
return hs != HandshakeStatus.NEED_TASK;
224+
}
225+
226+
private synchronized int doUnwrap() throws IOException {
227+
int num = getSocketChannel().read(netUnwrap);
228+
if (num < 0) {
229+
LOGGER.error("Failed during read operation. Probably server is down");
230+
return -1;
231+
}
232+
SSLEngineResult unwrapResult;
233+
234+
try {
235+
netUnwrap.flip();
236+
unwrapResult = sslEngine_.unwrap(netUnwrap, appUnwrap);
237+
netUnwrap.compact();
238+
} catch (SSLException ex) {
239+
LOGGER.error(ex.getMessage());
240+
throw ex;
241+
}
242+
243+
switch (unwrapResult.getStatus()) {
244+
case OK:
245+
if (appUnwrap.position() > 0) {
246+
appUnwrap.flip();
247+
appUnwrap.compact();
248+
}
249+
break;
250+
case CLOSED:
251+
return -1;
252+
case BUFFER_OVERFLOW:
253+
throw new IllegalStateException("Failed to unwrap");
254+
case BUFFER_UNDERFLOW:
255+
break;
256+
}
257+
return num;
258+
}
259+
260+
private synchronized int doWrap() throws IOException {
261+
int num = 0;
262+
SSLEngineResult wrapResult;
263+
try {
264+
appWrap.flip();
265+
wrapResult = sslEngine_.wrap(appWrap, netWrap);
266+
appWrap.compact();
267+
} catch (SSLException exc) {
268+
LOGGER.error(exc.getMessage());
269+
throw exc;
270+
}
271+
272+
switch (wrapResult.getStatus()) {
273+
case OK:
274+
if (netWrap.position() > 0) {
275+
netWrap.flip();
276+
num = getSocketChannel().write(netWrap);
277+
netWrap.compact();
278+
}
279+
break;
280+
case BUFFER_UNDERFLOW:
281+
// try again later
282+
break;
283+
case BUFFER_OVERFLOW:
284+
throw new IllegalStateException("Failed to wrap");
285+
case CLOSED:
286+
LOGGER.error("SSL session is closed");
287+
return -1;
288+
}
289+
return num;
290+
}
291+
}

lib/java/src/main/java/org/apache/thrift/transport/TNonblockingSocket.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -148,11 +148,7 @@ public int read(byte[] buf, int off, int len) throws TTransportException {
148148
throw new TTransportException(
149149
TTransportException.NOT_OPEN, "Cannot read from write-only socket channel");
150150
}
151-
try {
152-
return socketChannel_.read(ByteBuffer.wrap(buf, off, len));
153-
} catch (IOException iox) {
154-
throw new TTransportException(TTransportException.UNKNOWN, iox);
155-
}
151+
return read(ByteBuffer.wrap(buf, off, len));
156152
}
157153

158154
/** Perform a nonblocking write of the data in buffer; */

lib/java/src/main/java/org/apache/thrift/transport/TSSLTransportFactory.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,28 @@ public static TSocket getClientSocket(
192192
return createClient(ctx.getSocketFactory(), host, port, timeout);
193193
}
194194

195+
/**
196+
* Get a custom configured TNonblockingTransport. The SSL settings are obtained from the passed in
197+
* TSSLTransportParameters.
198+
*
199+
* @param host
200+
* @param port
201+
* @param timeout
202+
* @param params
203+
* @return TNonblockingTransport
204+
* @throws TTransportException
205+
*/
206+
public static TNonblockingTransport getNonblockingClientSocket(
207+
String host, int port, int timeout, TSSLTransportParameters params)
208+
throws TTransportException, IOException {
209+
if (params == null || !(params.isKeyStoreSet || params.isTrustStoreSet)) {
210+
throw new TTransportException(
211+
"Either one of the KeyStore or TrustStore must be set for SSLTransportParameters");
212+
}
213+
SSLContext ctx = createSSLContext(params);
214+
return new TNonblockingSSLSocket(host, port, timeout, ctx);
215+
}
216+
195217
private static SSLContext createSSLContext(TSSLTransportParameters params)
196218
throws TTransportException {
197219
SSLContext ctx;

lib/java/src/test/java/org/apache/thrift/async/TestTAsyncClientManager.java

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,10 @@
4040
import org.apache.thrift.server.ServerTestBase;
4141
import org.apache.thrift.server.THsHaServer;
4242
import org.apache.thrift.server.THsHaServer.Args;
43+
import org.apache.thrift.server.TServer;
4344
import org.apache.thrift.transport.TNonblockingServerSocket;
4445
import org.apache.thrift.transport.TNonblockingSocket;
46+
import org.apache.thrift.transport.TNonblockingTransport;
4547
import org.apache.thrift.transport.TTransportException;
4648
import org.junit.jupiter.api.AfterEach;
4749
import org.junit.jupiter.api.Assertions;
@@ -54,9 +56,9 @@
5456

5557
public class TestTAsyncClientManager {
5658

57-
private THsHaServer server_;
58-
private Thread serverThread_;
59-
private TAsyncClientManager clientManager_;
59+
protected TServer server_;
60+
protected Thread serverThread_;
61+
protected TAsyncClientManager clientManager_;
6062

6163
@BeforeEach
6264
public void setUp() throws Exception {
@@ -261,11 +263,14 @@ public void testParallelCalls() throws Exception {
261263
}
262264

263265
private Srv.AsyncClient getClient() throws IOException, TTransportException {
264-
TNonblockingSocket clientSocket =
265-
new TNonblockingSocket(ServerTestBase.HOST, ServerTestBase.PORT);
266+
TNonblockingTransport clientSocket = getClientTransport();
266267
return new Srv.AsyncClient(new TBinaryProtocol.Factory(), clientManager_, clientSocket);
267268
}
268269

270+
protected TNonblockingTransport getClientTransport() throws TTransportException, IOException {
271+
return new TNonblockingSocket(ServerTestBase.HOST, ServerTestBase.PORT);
272+
}
273+
269274
private void basicCall(Srv.AsyncClient client) throws Exception {
270275
final CountDownLatch latch = new CountDownLatch(1);
271276
final AtomicBoolean returned = new AtomicBoolean(false);

0 commit comments

Comments
 (0)