Skip to content

Commit 96f29a1

Browse files
committed
Add TlsChannel stream support
JAVA-3038
1 parent 07d2e80 commit 96f29a1

File tree

3 files changed

+601
-0
lines changed

3 files changed

+601
-0
lines changed

driver-core/src/main/com/mongodb/MongoSocketOpenException.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,15 @@ public class MongoSocketOpenException extends MongoSocketException {
3434
public MongoSocketOpenException(final String message, final ServerAddress address, final Throwable cause) {
3535
super(message, address, cause);
3636
}
37+
38+
/**
39+
* Construct an instance.
40+
*
41+
* @param message the message
42+
* @param address the server address
43+
* @since 3.10
44+
*/
45+
public MongoSocketOpenException(final String message, final ServerAddress address) {
46+
super(message, address);
47+
}
3748
}
Lines changed: 284 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,284 @@
1+
/*
2+
* Copyright 2008-present MongoDB, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.mongodb.connection;
18+
19+
import com.mongodb.MongoClientException;
20+
import com.mongodb.MongoSocketOpenException;
21+
import com.mongodb.ServerAddress;
22+
import com.mongodb.diagnostics.logging.Logger;
23+
import com.mongodb.diagnostics.logging.Loggers;
24+
import com.mongodb.internal.connection.AsynchronousChannelStream;
25+
import com.mongodb.internal.connection.ConcurrentLinkedDeque;
26+
import com.mongodb.internal.connection.PowerOfTwoBufferPool;
27+
import com.mongodb.internal.connection.tlschannel.BufferAllocator;
28+
import com.mongodb.internal.connection.tlschannel.ClientTlsChannel;
29+
import com.mongodb.internal.connection.tlschannel.TlsChannel;
30+
import com.mongodb.internal.connection.tlschannel.async.AsynchronousTlsChannel;
31+
import com.mongodb.internal.connection.tlschannel.async.AsynchronousTlsChannelGroup;
32+
import org.bson.ByteBuf;
33+
34+
import javax.net.ssl.SSLContext;
35+
import javax.net.ssl.SSLEngine;
36+
import javax.net.ssl.SSLParameters;
37+
import java.io.Closeable;
38+
import java.io.IOException;
39+
import java.net.StandardSocketOptions;
40+
import java.nio.channels.SelectionKey;
41+
import java.nio.channels.Selector;
42+
import java.nio.channels.SocketChannel;
43+
import java.security.NoSuchAlgorithmException;
44+
import java.util.Iterator;
45+
46+
import static com.mongodb.assertions.Assertions.isTrue;
47+
import static com.mongodb.internal.connection.SslHelper.enableHostNameVerification;
48+
import static com.mongodb.internal.connection.SslHelper.enableSni;
49+
50+
/**
51+
* A {@code StreamFactoryFactory} that supports TLS/SSL. The implementation supports asynchronous usage.
52+
* <p>
53+
* Requires Java 8
54+
* </p>
55+
*
56+
* @since 3.10
57+
*/
58+
public class TlsChannelStreamFactoryFactory implements StreamFactoryFactory, Closeable {
59+
60+
private static final Logger LOGGER = Loggers.getLogger("connection.tls");
61+
62+
private final SelectorMonitor selectorMonitor;
63+
private final AsynchronousTlsChannelGroup group;
64+
private final boolean ownsGroup;
65+
private final PowerOfTwoBufferPool bufferPool = new PowerOfTwoBufferPool();
66+
67+
/**
68+
* Construct a new instance
69+
*/
70+
public TlsChannelStreamFactoryFactory() {
71+
this(new AsynchronousTlsChannelGroup(), true);
72+
}
73+
74+
/**
75+
* Construct a new instance with the given {@code AsynchronousTlsChannelGroup}. Callers are required to close the provided group
76+
* in order to free up resources.
77+
*
78+
* @param group the group
79+
*/
80+
public TlsChannelStreamFactoryFactory(final AsynchronousTlsChannelGroup group) {
81+
this(group, false);
82+
}
83+
84+
private TlsChannelStreamFactoryFactory(final AsynchronousTlsChannelGroup group, final boolean ownsGroup) {
85+
this.group = group;
86+
this.ownsGroup = ownsGroup;
87+
selectorMonitor = new SelectorMonitor();
88+
selectorMonitor.start();
89+
}
90+
91+
@Override
92+
public StreamFactory create(final SocketSettings socketSettings, final SslSettings sslSettings) {
93+
return new StreamFactory() {
94+
@Override
95+
public Stream create(final ServerAddress serverAddress) {
96+
return new TlsChannelStream(serverAddress, socketSettings, sslSettings, bufferPool, group, selectorMonitor);
97+
}
98+
};
99+
}
100+
101+
@Override
102+
public void close() {
103+
selectorMonitor.close();
104+
if (ownsGroup) {
105+
group.shutdown();
106+
}
107+
}
108+
109+
private static class SelectorMonitor implements Closeable {
110+
111+
private static final class Pair {
112+
private final SocketChannel socketChannel;
113+
private final Runnable attachment;
114+
115+
private Pair(final SocketChannel socketChannel, final Runnable attachment) {
116+
this.socketChannel = socketChannel;
117+
this.attachment = attachment;
118+
}
119+
}
120+
121+
private final Selector selector;
122+
private volatile boolean isClosed;
123+
private final ConcurrentLinkedDeque<Pair> pendingRegistrations = new ConcurrentLinkedDeque<Pair>();
124+
125+
SelectorMonitor() {
126+
try {
127+
this.selector = Selector.open();
128+
} catch (IOException e) {
129+
throw new MongoClientException("Exception opening Selector", e);
130+
}
131+
}
132+
133+
void start() {
134+
Thread selectorThread = new Thread(new Runnable() {
135+
@Override
136+
public void run() {
137+
try {
138+
while (!isClosed) {
139+
try {
140+
selector.select();
141+
142+
for (SelectionKey selectionKey : selector.selectedKeys()) {
143+
selectionKey.cancel();
144+
Runnable runnable = (Runnable) selectionKey.attachment();
145+
runnable.run();
146+
}
147+
148+
for (Iterator<Pair> iter = pendingRegistrations.iterator(); iter.hasNext();) {
149+
Pair pendingRegistration = iter.next();
150+
pendingRegistration.socketChannel.register(selector, SelectionKey.OP_CONNECT,
151+
pendingRegistration.attachment);
152+
iter.remove();
153+
}
154+
} catch (IOException e) {
155+
LOGGER.warn("Exception in selector loop", e);
156+
} catch (RuntimeException e) {
157+
LOGGER.warn("Exception in selector loop", e);
158+
}
159+
}
160+
} finally {
161+
try {
162+
selector.close();
163+
} catch (IOException e) {
164+
// ignore
165+
}
166+
}
167+
}
168+
});
169+
selectorThread.setDaemon(true);
170+
selectorThread.start();
171+
}
172+
173+
void register(final SocketChannel channel, final Runnable attachment) {
174+
pendingRegistrations.add(new Pair(channel, attachment));
175+
selector.wakeup();
176+
}
177+
178+
@Override
179+
public void close() {
180+
isClosed = true;
181+
selector.wakeup();
182+
}
183+
}
184+
185+
private static class TlsChannelStream extends AsynchronousChannelStream implements Stream {
186+
187+
private final AsynchronousTlsChannelGroup group;
188+
private final SelectorMonitor selectorMonitor;
189+
private final SslSettings sslSettings;
190+
191+
TlsChannelStream(final ServerAddress serverAddress, final SocketSettings settings, final SslSettings sslSettings,
192+
final BufferProvider bufferProvider, final AsynchronousTlsChannelGroup group,
193+
final SelectorMonitor selectorMonitor) {
194+
super(serverAddress, settings, bufferProvider);
195+
this.sslSettings = sslSettings;
196+
this.group = group;
197+
this.selectorMonitor = selectorMonitor;
198+
}
199+
200+
@Override
201+
public void openAsync(final AsyncCompletionHandler<Void> handler) {
202+
isTrue("unopened", getChannel() == null);
203+
try {
204+
final SocketChannel socketChannel = SocketChannel.open();
205+
socketChannel.configureBlocking(false);
206+
207+
socketChannel.setOption(StandardSocketOptions.TCP_NODELAY, true);
208+
socketChannel.setOption(StandardSocketOptions.SO_KEEPALIVE, true);
209+
if (getSettings().getReceiveBufferSize() > 0) {
210+
socketChannel.setOption(StandardSocketOptions.SO_RCVBUF, getSettings().getReceiveBufferSize());
211+
}
212+
if (getSettings().getSendBufferSize() > 0) {
213+
socketChannel.setOption(StandardSocketOptions.SO_SNDBUF, getSettings().getSendBufferSize());
214+
}
215+
216+
socketChannel.connect(getServerAddress().getSocketAddress());
217+
218+
selectorMonitor.register(socketChannel, new Runnable() {
219+
@Override
220+
public void run() {
221+
try {
222+
if (!socketChannel.finishConnect()) {
223+
throw new MongoSocketOpenException("Failed to finish connect", getServerAddress());
224+
}
225+
226+
SSLEngine sslEngine = getSslContext().createSSLEngine(getServerAddress().getHost(),
227+
getServerAddress().getPort());
228+
sslEngine.setUseClientMode(true);
229+
230+
SSLParameters sslParameters = sslEngine.getSSLParameters();
231+
enableSni(getServerAddress().getHost(), sslParameters);
232+
233+
if (!sslSettings.isInvalidHostNameAllowed()) {
234+
enableHostNameVerification(sslParameters);
235+
}
236+
sslEngine.setSSLParameters(sslParameters);
237+
238+
BufferAllocator bufferAllocator = new BufferProviderAllocator();
239+
240+
TlsChannel tlsChannel = ClientTlsChannel.newBuilder(socketChannel, sslEngine)
241+
.withEncryptedBufferAllocator(bufferAllocator)
242+
.withPlainBufferAllocator(bufferAllocator)
243+
.build();
244+
245+
// build asynchronous channel, based in the TLS channel and associated with the global group.
246+
setChannel(new AsynchronousTlsChannel(group, tlsChannel, socketChannel));
247+
248+
handler.completed(null);
249+
} catch (IOException e) {
250+
handler.failed(new MongoSocketOpenException("Exception opening socket", getServerAddress(), e));
251+
} catch (Throwable t) {
252+
handler.failed(t);
253+
}
254+
}
255+
});
256+
} catch (IOException e) {
257+
handler.failed(new MongoSocketOpenException("Exception opening socket", getServerAddress(), e));
258+
} catch (Throwable t) {
259+
handler.failed(t);
260+
}
261+
}
262+
263+
private SSLContext getSslContext() {
264+
try {
265+
return (sslSettings.getContext() == null) ? SSLContext.getDefault() : sslSettings.getContext();
266+
} catch (NoSuchAlgorithmException e) {
267+
throw new MongoClientException("Unable to create default SSLContext", e);
268+
}
269+
}
270+
271+
private class BufferProviderAllocator implements BufferAllocator {
272+
@Override
273+
public ByteBuf allocate(final int size) {
274+
return getBufferProvider().getBuffer(size);
275+
}
276+
277+
@Override
278+
public void free(final ByteBuf buffer) {
279+
buffer.release();
280+
}
281+
}
282+
}
283+
}
284+

0 commit comments

Comments
 (0)