Skip to content

Commit 6267ee3

Browse files
committed
Async updates to work with tls-channel
JAVA-3588
1 parent 7b2344e commit 6267ee3

File tree

5 files changed

+136
-81
lines changed

5 files changed

+136
-81
lines changed

driver-core/src/main/com/mongodb/connection/AsynchronousSocketChannelStreamFactory.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
* @since 3.0
3232
*/
3333
public class AsynchronousSocketChannelStreamFactory implements StreamFactory {
34-
private final BufferProvider bufferProvider = new PowerOfTwoBufferPool();
34+
private final PowerOfTwoBufferPool bufferProvider = new PowerOfTwoBufferPool();
3535
private final SocketSettings settings;
3636
private final AsynchronousChannelGroup group;
3737

driver-core/src/main/com/mongodb/connection/TlsChannelStreamFactoryFactory.java

Lines changed: 125 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -23,25 +23,29 @@
2323
import com.mongodb.diagnostics.logging.Loggers;
2424
import com.mongodb.internal.connection.AsynchronousChannelStream;
2525
import com.mongodb.internal.connection.ConcurrentLinkedDeque;
26+
import com.mongodb.internal.connection.ExtendedAsynchronousByteChannel;
2627
import com.mongodb.internal.connection.PowerOfTwoBufferPool;
2728
import com.mongodb.internal.connection.tlschannel.BufferAllocator;
2829
import com.mongodb.internal.connection.tlschannel.ClientTlsChannel;
2930
import com.mongodb.internal.connection.tlschannel.TlsChannel;
3031
import com.mongodb.internal.connection.tlschannel.async.AsynchronousTlsChannel;
3132
import com.mongodb.internal.connection.tlschannel.async.AsynchronousTlsChannelGroup;
32-
import org.bson.ByteBuf;
3333

3434
import javax.net.ssl.SSLContext;
3535
import javax.net.ssl.SSLEngine;
3636
import javax.net.ssl.SSLParameters;
3737
import java.io.Closeable;
3838
import java.io.IOException;
3939
import java.net.StandardSocketOptions;
40+
import java.nio.ByteBuffer;
41+
import java.nio.channels.CompletionHandler;
4042
import java.nio.channels.SelectionKey;
4143
import java.nio.channels.Selector;
4244
import java.nio.channels.SocketChannel;
4345
import java.security.NoSuchAlgorithmException;
4446
import java.util.Iterator;
47+
import java.util.concurrent.Future;
48+
import java.util.concurrent.TimeUnit;
4549

4650
import static com.mongodb.assertions.Assertions.isTrue;
4751
import static com.mongodb.internal.connection.SslHelper.enableHostNameVerification;
@@ -89,12 +93,7 @@ private TlsChannelStreamFactoryFactory(final AsynchronousTlsChannelGroup group,
8993

9094
@Override
9195
public StreamFactory create(final SocketSettings socketSettings, final SslSettings sslSettings) {
92-
return new StreamFactory() {
93-
@Override
94-
public Stream create(final ServerAddress serverAddress) {
95-
return new TlsChannelStream(serverAddress, socketSettings, sslSettings, bufferPool, group, selectorMonitor);
96-
}
97-
};
96+
return serverAddress -> new TlsChannelStream(serverAddress, socketSettings, sslSettings, bufferPool, group, selectorMonitor);
9897
}
9998

10099
@Override
@@ -119,7 +118,7 @@ private Pair(final SocketChannel socketChannel, final Runnable attachment) {
119118

120119
private final Selector selector;
121120
private volatile boolean isClosed;
122-
private final ConcurrentLinkedDeque<Pair> pendingRegistrations = new ConcurrentLinkedDeque<Pair>();
121+
private final ConcurrentLinkedDeque<Pair> pendingRegistrations = new ConcurrentLinkedDeque<>();
123122

124123
SelectorMonitor() {
125124
try {
@@ -130,39 +129,34 @@ private Pair(final SocketChannel socketChannel, final Runnable attachment) {
130129
}
131130

132131
void start() {
133-
Thread selectorThread = new Thread(new Runnable() {
134-
@Override
135-
public void run() {
136-
try {
137-
while (!isClosed) {
138-
try {
139-
selector.select();
140-
141-
for (SelectionKey selectionKey : selector.selectedKeys()) {
142-
selectionKey.cancel();
143-
Runnable runnable = (Runnable) selectionKey.attachment();
144-
runnable.run();
145-
}
146-
147-
for (Iterator<Pair> iter = pendingRegistrations.iterator(); iter.hasNext();) {
148-
Pair pendingRegistration = iter.next();
149-
pendingRegistration.socketChannel.register(selector, SelectionKey.OP_CONNECT,
150-
pendingRegistration.attachment);
151-
iter.remove();
152-
}
153-
} catch (IOException e) {
154-
LOGGER.warn("Exception in selector loop", e);
155-
} catch (RuntimeException e) {
156-
LOGGER.warn("Exception in selector loop", e);
157-
}
158-
}
159-
} finally {
132+
Thread selectorThread = new Thread(() -> {
133+
try {
134+
while (!isClosed) {
160135
try {
161-
selector.close();
162-
} catch (IOException e) {
163-
// ignore
136+
selector.select();
137+
138+
for (SelectionKey selectionKey : selector.selectedKeys()) {
139+
selectionKey.cancel();
140+
Runnable runnable = (Runnable) selectionKey.attachment();
141+
runnable.run();
142+
}
143+
144+
for (Iterator<Pair> iter = pendingRegistrations.iterator(); iter.hasNext();) {
145+
Pair pendingRegistration = iter.next();
146+
pendingRegistration.socketChannel.register(selector, SelectionKey.OP_CONNECT,
147+
pendingRegistration.attachment);
148+
iter.remove();
149+
}
150+
} catch (IOException | RuntimeException e) {
151+
LOGGER.warn("Exception in selector loop", e);
164152
}
165153
}
154+
} finally {
155+
try {
156+
selector.close();
157+
} catch (IOException e) {
158+
// ignore
159+
}
166160
}
167161
});
168162
selectorThread.setDaemon(true);
@@ -188,7 +182,7 @@ private static class TlsChannelStream extends AsynchronousChannelStream implemen
188182
private final SslSettings sslSettings;
189183

190184
TlsChannelStream(final ServerAddress serverAddress, final SocketSettings settings, final SslSettings sslSettings,
191-
final BufferProvider bufferProvider, final AsynchronousTlsChannelGroup group,
185+
final PowerOfTwoBufferPool bufferProvider, final AsynchronousTlsChannelGroup group,
192186
final SelectorMonitor selectorMonitor) {
193187
super(serverAddress, settings, bufferProvider);
194188
this.sslSettings = sslSettings;
@@ -219,42 +213,39 @@ public void openAsync(final AsyncCompletionHandler<Void> handler) {
219213

220214
socketChannel.connect(getServerAddress().getSocketAddress());
221215

222-
selectorMonitor.register(socketChannel, new Runnable() {
223-
@Override
224-
public void run() {
225-
try {
226-
if (!socketChannel.finishConnect()) {
227-
throw new MongoSocketOpenException("Failed to finish connect", getServerAddress());
228-
}
216+
selectorMonitor.register(socketChannel, () -> {
217+
try {
218+
if (!socketChannel.finishConnect()) {
219+
throw new MongoSocketOpenException("Failed to finish connect", getServerAddress());
220+
}
229221

230-
SSLEngine sslEngine = getSslContext().createSSLEngine(getServerAddress().getHost(),
231-
getServerAddress().getPort());
232-
sslEngine.setUseClientMode(true);
222+
SSLEngine sslEngine = getSslContext().createSSLEngine(getServerAddress().getHost(),
223+
getServerAddress().getPort());
224+
sslEngine.setUseClientMode(true);
233225

234-
SSLParameters sslParameters = sslEngine.getSSLParameters();
235-
enableSni(getServerAddress().getHost(), sslParameters);
226+
SSLParameters sslParameters = sslEngine.getSSLParameters();
227+
enableSni(getServerAddress().getHost(), sslParameters);
236228

237-
if (!sslSettings.isInvalidHostNameAllowed()) {
238-
enableHostNameVerification(sslParameters);
239-
}
240-
sslEngine.setSSLParameters(sslParameters);
229+
if (!sslSettings.isInvalidHostNameAllowed()) {
230+
enableHostNameVerification(sslParameters);
231+
}
232+
sslEngine.setSSLParameters(sslParameters);
241233

242-
BufferAllocator bufferAllocator = new BufferProviderAllocator();
234+
BufferAllocator bufferAllocator = new BufferProviderAllocator();
243235

244-
TlsChannel tlsChannel = ClientTlsChannel.newBuilder(socketChannel, sslEngine)
245-
.withEncryptedBufferAllocator(bufferAllocator)
246-
.withPlainBufferAllocator(bufferAllocator)
247-
.build();
236+
TlsChannel tlsChannel = ClientTlsChannel.newBuilder(socketChannel, sslEngine)
237+
.withEncryptedBufferAllocator(bufferAllocator)
238+
.withPlainBufferAllocator(bufferAllocator)
239+
.build();
248240

249-
// build asynchronous channel, based in the TLS channel and associated with the global group.
250-
setChannel(new AsynchronousTlsChannel(group, tlsChannel, socketChannel));
241+
// build asynchronous channel, based in the TLS channel and associated with the global group.
242+
setChannel(new AsynchronousTlsChannelAdapter(new AsynchronousTlsChannel(group, tlsChannel, socketChannel)));
251243

252-
handler.completed(null);
253-
} catch (IOException e) {
254-
handler.failed(new MongoSocketOpenException("Exception opening socket", getServerAddress(), e));
255-
} catch (Throwable t) {
256-
handler.failed(t);
257-
}
244+
handler.completed(null);
245+
} catch (IOException e) {
246+
handler.failed(new MongoSocketOpenException("Exception opening socket", getServerAddress(), e));
247+
} catch (Throwable t) {
248+
handler.failed(t);
258249
}
259250
});
260251
} catch (IOException e) {
@@ -274,13 +265,75 @@ private SSLContext getSslContext() {
274265

275266
private class BufferProviderAllocator implements BufferAllocator {
276267
@Override
277-
public ByteBuf allocate(final int size) {
278-
return getBufferProvider().getBuffer(size);
268+
public ByteBuffer allocate(final int size) {
269+
return getBufferProvider().getByteBuffer(size);
270+
}
271+
272+
@Override
273+
public void free(final ByteBuffer buffer) {
274+
getBufferProvider().release(buffer);
275+
}
276+
}
277+
278+
public static class AsynchronousTlsChannelAdapter implements ExtendedAsynchronousByteChannel {
279+
private final AsynchronousTlsChannel wrapped;
280+
281+
AsynchronousTlsChannelAdapter(final AsynchronousTlsChannel wrapped) {
282+
this.wrapped = wrapped;
283+
}
284+
285+
@Override
286+
public <A> void read(final ByteBuffer dst, final A attach, final CompletionHandler<Integer, ? super A> handler) {
287+
wrapped.read(dst, attach, handler);
288+
}
289+
290+
@Override
291+
public <A> void read(final ByteBuffer dst, final long timeout, final TimeUnit unit, final A attach,
292+
final CompletionHandler<Integer, ? super A> handler) {
293+
wrapped.read(dst, timeout, unit, attach, handler);
294+
}
295+
296+
@Override
297+
public <A> void read(final ByteBuffer[] dsts, final int offset, final int length, final long timeout, final TimeUnit unit,
298+
final A attach, final CompletionHandler<Long, ? super A> handler) {
299+
wrapped.read(dsts, offset, length, timeout, unit, attach, handler);
300+
}
301+
302+
@Override
303+
public Future<Integer> read(final ByteBuffer dst) {
304+
return wrapped.read(dst);
305+
}
306+
307+
@Override
308+
public <A> void write(final ByteBuffer src, final A attach, final CompletionHandler<Integer, ? super A> handler) {
309+
wrapped.write(src, attach, handler);
310+
}
311+
312+
@Override
313+
public <A> void write(final ByteBuffer src, final long timeout, final TimeUnit unit, final A attach,
314+
final CompletionHandler<Integer, ? super A> handler) {
315+
wrapped.write(src, timeout, unit, attach, handler);
316+
}
317+
318+
@Override
319+
public <A> void write(final ByteBuffer[] srcs, final int offset, final int length, final long timeout, final TimeUnit unit,
320+
final A attach, final CompletionHandler<Long, ? super A> handler) {
321+
wrapped.write(srcs, offset, length, timeout, unit, attach, handler);
322+
}
323+
324+
@Override
325+
public Future<Integer> write(final ByteBuffer src) {
326+
return wrapped.write(src);
327+
}
328+
329+
@Override
330+
public boolean isOpen() {
331+
return wrapped.isOpen();
279332
}
280333

281334
@Override
282-
public void free(final ByteBuf buffer) {
283-
buffer.release();
335+
public void close() throws IOException {
336+
wrapped.close();
284337
}
285338
}
286339
}

driver-core/src/main/com/mongodb/internal/connection/AsynchronousChannelStream.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import com.mongodb.MongoSocketReadTimeoutException;
2424
import com.mongodb.ServerAddress;
2525
import com.mongodb.connection.AsyncCompletionHandler;
26-
import com.mongodb.connection.BufferProvider;
2726
import com.mongodb.connection.SocketSettings;
2827
import com.mongodb.connection.Stream;
2928
import org.bson.ByteBuf;
@@ -46,12 +45,12 @@
4645
public abstract class AsynchronousChannelStream implements Stream {
4746
private final ServerAddress serverAddress;
4847
private final SocketSettings settings;
49-
private final BufferProvider bufferProvider;
48+
private final PowerOfTwoBufferPool bufferProvider;
5049
private volatile ExtendedAsynchronousByteChannel channel;
5150
private volatile boolean isClosed;
5251

5352
public AsynchronousChannelStream(final ServerAddress serverAddress, final SocketSettings settings,
54-
final BufferProvider bufferProvider) {
53+
final PowerOfTwoBufferPool bufferProvider) {
5554
this.serverAddress = serverAddress;
5655
this.settings = settings;
5756
this.bufferProvider = bufferProvider;
@@ -65,7 +64,7 @@ public SocketSettings getSettings() {
6564
return settings;
6665
}
6766

68-
public BufferProvider getBufferProvider() {
67+
public PowerOfTwoBufferPool getBufferProvider() {
6968
return bufferProvider;
7069
}
7170

driver-core/src/main/com/mongodb/internal/connection/AsynchronousSocketChannelStream.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import com.mongodb.MongoSocketOpenException;
2121
import com.mongodb.ServerAddress;
2222
import com.mongodb.connection.AsyncCompletionHandler;
23-
import com.mongodb.connection.BufferProvider;
2423
import com.mongodb.connection.SocketSettings;
2524
import com.mongodb.connection.Stream;
2625

@@ -45,7 +44,7 @@ public final class AsynchronousSocketChannelStream extends AsynchronousChannelSt
4544
private final AsynchronousChannelGroup group;
4645

4746
public AsynchronousSocketChannelStream(final ServerAddress serverAddress, final SocketSettings settings,
48-
final BufferProvider bufferProvider, final AsynchronousChannelGroup group) {
47+
final PowerOfTwoBufferPool bufferProvider, final AsynchronousChannelGroup group) {
4948
super(serverAddress, settings, bufferProvider);
5049
this.serverAddress = serverAddress;
5150
this.settings = settings;

driver-core/src/main/com/mongodb/internal/connection/PowerOfTwoBufferPool.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,16 @@ public Prune shouldPrune(final ByteBuffer byteBuffer) {
7474

7575
@Override
7676
public ByteBuf getBuffer(final int size) {
77+
return new PooledByteBufNIO(getByteBuffer(size));
78+
}
79+
80+
public ByteBuffer getByteBuffer(final int size) {
7781
ConcurrentPool<ByteBuffer> pool = powerOfTwoToPoolMap.get(log2(roundUpToNextHighestPowerOfTwo(size)));
7882
ByteBuffer byteBuffer = (pool == null) ? createNew(size) : pool.get();
7983

8084
((Buffer) byteBuffer).clear();
8185
((Buffer) byteBuffer).limit(size);
82-
return new PooledByteBufNIO(byteBuffer);
86+
return byteBuffer;
8387
}
8488

8589
private ByteBuffer createNew(final int size) {
@@ -88,7 +92,7 @@ private ByteBuffer createNew(final int size) {
8892
return buf;
8993
}
9094

91-
private void release(final ByteBuffer buffer) {
95+
public void release(final ByteBuffer buffer) {
9296
ConcurrentPool<ByteBuffer> pool = powerOfTwoToPoolMap.get(log2(roundUpToNextHighestPowerOfTwo(buffer.capacity())));
9397
if (pool != null) {
9498
pool.release(buffer);

0 commit comments

Comments
 (0)