23
23
import com .mongodb .diagnostics .logging .Loggers ;
24
24
import com .mongodb .internal .connection .AsynchronousChannelStream ;
25
25
import com .mongodb .internal .connection .ConcurrentLinkedDeque ;
26
+ import com .mongodb .internal .connection .ExtendedAsynchronousByteChannel ;
26
27
import com .mongodb .internal .connection .PowerOfTwoBufferPool ;
27
28
import com .mongodb .internal .connection .tlschannel .BufferAllocator ;
28
29
import com .mongodb .internal .connection .tlschannel .ClientTlsChannel ;
29
30
import com .mongodb .internal .connection .tlschannel .TlsChannel ;
30
31
import com .mongodb .internal .connection .tlschannel .async .AsynchronousTlsChannel ;
31
32
import com .mongodb .internal .connection .tlschannel .async .AsynchronousTlsChannelGroup ;
32
- import org .bson .ByteBuf ;
33
33
34
34
import javax .net .ssl .SSLContext ;
35
35
import javax .net .ssl .SSLEngine ;
36
36
import javax .net .ssl .SSLParameters ;
37
37
import java .io .Closeable ;
38
38
import java .io .IOException ;
39
39
import java .net .StandardSocketOptions ;
40
+ import java .nio .ByteBuffer ;
41
+ import java .nio .channels .CompletionHandler ;
40
42
import java .nio .channels .SelectionKey ;
41
43
import java .nio .channels .Selector ;
42
44
import java .nio .channels .SocketChannel ;
43
45
import java .security .NoSuchAlgorithmException ;
44
46
import java .util .Iterator ;
47
+ import java .util .concurrent .Future ;
48
+ import java .util .concurrent .TimeUnit ;
45
49
46
50
import static com .mongodb .assertions .Assertions .isTrue ;
47
51
import static com .mongodb .internal .connection .SslHelper .enableHostNameVerification ;
@@ -89,12 +93,7 @@ private TlsChannelStreamFactoryFactory(final AsynchronousTlsChannelGroup group,
89
93
90
94
@ Override
91
95
public StreamFactory create (final SocketSettings socketSettings , final SslSettings sslSettings ) {
92
- return new StreamFactory () {
93
- @ Override
94
- public Stream create (final ServerAddress serverAddress ) {
95
- return new TlsChannelStream (serverAddress , socketSettings , sslSettings , bufferPool , group , selectorMonitor );
96
- }
97
- };
96
+ return serverAddress -> new TlsChannelStream (serverAddress , socketSettings , sslSettings , bufferPool , group , selectorMonitor );
98
97
}
99
98
100
99
@ Override
@@ -119,7 +118,7 @@ private Pair(final SocketChannel socketChannel, final Runnable attachment) {
119
118
120
119
private final Selector selector ;
121
120
private volatile boolean isClosed ;
122
- private final ConcurrentLinkedDeque <Pair > pendingRegistrations = new ConcurrentLinkedDeque <Pair >();
121
+ private final ConcurrentLinkedDeque <Pair > pendingRegistrations = new ConcurrentLinkedDeque <>();
123
122
124
123
SelectorMonitor () {
125
124
try {
@@ -130,39 +129,34 @@ private Pair(final SocketChannel socketChannel, final Runnable attachment) {
130
129
}
131
130
132
131
void start () {
133
- Thread selectorThread = new Thread (new Runnable () {
134
- @ Override
135
- public void run () {
136
- try {
137
- while (!isClosed ) {
138
- try {
139
- selector .select ();
140
-
141
- for (SelectionKey selectionKey : selector .selectedKeys ()) {
142
- selectionKey .cancel ();
143
- Runnable runnable = (Runnable ) selectionKey .attachment ();
144
- runnable .run ();
145
- }
146
-
147
- for (Iterator <Pair > iter = pendingRegistrations .iterator (); iter .hasNext ();) {
148
- Pair pendingRegistration = iter .next ();
149
- pendingRegistration .socketChannel .register (selector , SelectionKey .OP_CONNECT ,
150
- pendingRegistration .attachment );
151
- iter .remove ();
152
- }
153
- } catch (IOException e ) {
154
- LOGGER .warn ("Exception in selector loop" , e );
155
- } catch (RuntimeException e ) {
156
- LOGGER .warn ("Exception in selector loop" , e );
157
- }
158
- }
159
- } finally {
132
+ Thread selectorThread = new Thread (() -> {
133
+ try {
134
+ while (!isClosed ) {
160
135
try {
161
- selector .close ();
162
- } catch (IOException e ) {
163
- // ignore
136
+ selector .select ();
137
+
138
+ for (SelectionKey selectionKey : selector .selectedKeys ()) {
139
+ selectionKey .cancel ();
140
+ Runnable runnable = (Runnable ) selectionKey .attachment ();
141
+ runnable .run ();
142
+ }
143
+
144
+ for (Iterator <Pair > iter = pendingRegistrations .iterator (); iter .hasNext ();) {
145
+ Pair pendingRegistration = iter .next ();
146
+ pendingRegistration .socketChannel .register (selector , SelectionKey .OP_CONNECT ,
147
+ pendingRegistration .attachment );
148
+ iter .remove ();
149
+ }
150
+ } catch (IOException | RuntimeException e ) {
151
+ LOGGER .warn ("Exception in selector loop" , e );
164
152
}
165
153
}
154
+ } finally {
155
+ try {
156
+ selector .close ();
157
+ } catch (IOException e ) {
158
+ // ignore
159
+ }
166
160
}
167
161
});
168
162
selectorThread .setDaemon (true );
@@ -188,7 +182,7 @@ private static class TlsChannelStream extends AsynchronousChannelStream implemen
188
182
private final SslSettings sslSettings ;
189
183
190
184
TlsChannelStream (final ServerAddress serverAddress , final SocketSettings settings , final SslSettings sslSettings ,
191
- final BufferProvider bufferProvider , final AsynchronousTlsChannelGroup group ,
185
+ final PowerOfTwoBufferPool bufferProvider , final AsynchronousTlsChannelGroup group ,
192
186
final SelectorMonitor selectorMonitor ) {
193
187
super (serverAddress , settings , bufferProvider );
194
188
this .sslSettings = sslSettings ;
@@ -219,42 +213,39 @@ public void openAsync(final AsyncCompletionHandler<Void> handler) {
219
213
220
214
socketChannel .connect (getServerAddress ().getSocketAddress ());
221
215
222
- selectorMonitor .register (socketChannel , new Runnable () {
223
- @ Override
224
- public void run () {
225
- try {
226
- if (!socketChannel .finishConnect ()) {
227
- throw new MongoSocketOpenException ("Failed to finish connect" , getServerAddress ());
228
- }
216
+ selectorMonitor .register (socketChannel , () -> {
217
+ try {
218
+ if (!socketChannel .finishConnect ()) {
219
+ throw new MongoSocketOpenException ("Failed to finish connect" , getServerAddress ());
220
+ }
229
221
230
- SSLEngine sslEngine = getSslContext ().createSSLEngine (getServerAddress ().getHost (),
231
- getServerAddress ().getPort ());
232
- sslEngine .setUseClientMode (true );
222
+ SSLEngine sslEngine = getSslContext ().createSSLEngine (getServerAddress ().getHost (),
223
+ getServerAddress ().getPort ());
224
+ sslEngine .setUseClientMode (true );
233
225
234
- SSLParameters sslParameters = sslEngine .getSSLParameters ();
235
- enableSni (getServerAddress ().getHost (), sslParameters );
226
+ SSLParameters sslParameters = sslEngine .getSSLParameters ();
227
+ enableSni (getServerAddress ().getHost (), sslParameters );
236
228
237
- if (!sslSettings .isInvalidHostNameAllowed ()) {
238
- enableHostNameVerification (sslParameters );
239
- }
240
- sslEngine .setSSLParameters (sslParameters );
229
+ if (!sslSettings .isInvalidHostNameAllowed ()) {
230
+ enableHostNameVerification (sslParameters );
231
+ }
232
+ sslEngine .setSSLParameters (sslParameters );
241
233
242
- BufferAllocator bufferAllocator = new BufferProviderAllocator ();
234
+ BufferAllocator bufferAllocator = new BufferProviderAllocator ();
243
235
244
- TlsChannel tlsChannel = ClientTlsChannel .newBuilder (socketChannel , sslEngine )
245
- .withEncryptedBufferAllocator (bufferAllocator )
246
- .withPlainBufferAllocator (bufferAllocator )
247
- .build ();
236
+ TlsChannel tlsChannel = ClientTlsChannel .newBuilder (socketChannel , sslEngine )
237
+ .withEncryptedBufferAllocator (bufferAllocator )
238
+ .withPlainBufferAllocator (bufferAllocator )
239
+ .build ();
248
240
249
- // build asynchronous channel, based in the TLS channel and associated with the global group.
250
- setChannel (new AsynchronousTlsChannel (group , tlsChannel , socketChannel ));
241
+ // build asynchronous channel, based in the TLS channel and associated with the global group.
242
+ setChannel (new AsynchronousTlsChannelAdapter ( new AsynchronousTlsChannel (group , tlsChannel , socketChannel ) ));
251
243
252
- handler .completed (null );
253
- } catch (IOException e ) {
254
- handler .failed (new MongoSocketOpenException ("Exception opening socket" , getServerAddress (), e ));
255
- } catch (Throwable t ) {
256
- handler .failed (t );
257
- }
244
+ handler .completed (null );
245
+ } catch (IOException e ) {
246
+ handler .failed (new MongoSocketOpenException ("Exception opening socket" , getServerAddress (), e ));
247
+ } catch (Throwable t ) {
248
+ handler .failed (t );
258
249
}
259
250
});
260
251
} catch (IOException e ) {
@@ -274,13 +265,75 @@ private SSLContext getSslContext() {
274
265
275
266
private class BufferProviderAllocator implements BufferAllocator {
276
267
@ Override
277
- public ByteBuf allocate (final int size ) {
278
- return getBufferProvider ().getBuffer (size );
268
+ public ByteBuffer allocate (final int size ) {
269
+ return getBufferProvider ().getByteBuffer (size );
270
+ }
271
+
272
+ @ Override
273
+ public void free (final ByteBuffer buffer ) {
274
+ getBufferProvider ().release (buffer );
275
+ }
276
+ }
277
+
278
+ public static class AsynchronousTlsChannelAdapter implements ExtendedAsynchronousByteChannel {
279
+ private final AsynchronousTlsChannel wrapped ;
280
+
281
+ AsynchronousTlsChannelAdapter (final AsynchronousTlsChannel wrapped ) {
282
+ this .wrapped = wrapped ;
283
+ }
284
+
285
+ @ Override
286
+ public <A > void read (final ByteBuffer dst , final A attach , final CompletionHandler <Integer , ? super A > handler ) {
287
+ wrapped .read (dst , attach , handler );
288
+ }
289
+
290
+ @ Override
291
+ public <A > void read (final ByteBuffer dst , final long timeout , final TimeUnit unit , final A attach ,
292
+ final CompletionHandler <Integer , ? super A > handler ) {
293
+ wrapped .read (dst , timeout , unit , attach , handler );
294
+ }
295
+
296
+ @ Override
297
+ public <A > void read (final ByteBuffer [] dsts , final int offset , final int length , final long timeout , final TimeUnit unit ,
298
+ final A attach , final CompletionHandler <Long , ? super A > handler ) {
299
+ wrapped .read (dsts , offset , length , timeout , unit , attach , handler );
300
+ }
301
+
302
+ @ Override
303
+ public Future <Integer > read (final ByteBuffer dst ) {
304
+ return wrapped .read (dst );
305
+ }
306
+
307
+ @ Override
308
+ public <A > void write (final ByteBuffer src , final A attach , final CompletionHandler <Integer , ? super A > handler ) {
309
+ wrapped .write (src , attach , handler );
310
+ }
311
+
312
+ @ Override
313
+ public <A > void write (final ByteBuffer src , final long timeout , final TimeUnit unit , final A attach ,
314
+ final CompletionHandler <Integer , ? super A > handler ) {
315
+ wrapped .write (src , timeout , unit , attach , handler );
316
+ }
317
+
318
+ @ Override
319
+ public <A > void write (final ByteBuffer [] srcs , final int offset , final int length , final long timeout , final TimeUnit unit ,
320
+ final A attach , final CompletionHandler <Long , ? super A > handler ) {
321
+ wrapped .write (srcs , offset , length , timeout , unit , attach , handler );
322
+ }
323
+
324
+ @ Override
325
+ public Future <Integer > write (final ByteBuffer src ) {
326
+ return wrapped .write (src );
327
+ }
328
+
329
+ @ Override
330
+ public boolean isOpen () {
331
+ return wrapped .isOpen ();
279
332
}
280
333
281
334
@ Override
282
- public void free ( final ByteBuf buffer ) {
283
- buffer . release ();
335
+ public void close () throws IOException {
336
+ wrapped . close ();
284
337
}
285
338
}
286
339
}
0 commit comments