@@ -33,6 +33,7 @@ import io.netty.channel.nio.NioEventLoopGroup
33
33
import io.netty.channel.socket.nio.NioDatagramChannel
34
34
import io.netty.handler.ssl.ClientAuth
35
35
import io.netty.incubator.codec.quic.*
36
+ import org.slf4j.LoggerFactory
36
37
import java.net.*
37
38
import java.time.Duration
38
39
import java.util.*
@@ -44,6 +45,7 @@ class QuicTransport(
44
45
private val certAlgorithm : String ,
45
46
private val protocols : List <ProtocolBinding <* >>
46
47
) : NettyTransport {
48
+ private val log = LoggerFactory .getLogger(QuicTransport ::class .java)
47
49
48
50
private var closed = false
49
51
var connectTimeout = Duration .ofSeconds(15 )
@@ -162,7 +164,7 @@ class QuicTransport(
162
164
listeners - = addr
163
165
}
164
166
}
165
- println (" Quic server listening on " + addr)
167
+ log.info (" Quic server listening on {} " , addr)
166
168
res.complete(null )
167
169
}
168
170
}
@@ -220,29 +222,7 @@ class QuicTransport(
220
222
connFuture.also {
221
223
registerChannel(it.get())
222
224
val connection = ConnectionOverNetty (it.get(), this , true )
223
- connection.setMuxerSession(object : StreamMuxer .Session {
224
- override fun <T > createStream (protocols : List <ProtocolBinding <T >>): StreamPromise <T > {
225
- var multistreamProtocol: MultistreamProtocol = MultistreamProtocolV1
226
- var streamMultistreamProtocol: MultistreamProtocol by lazyVar { multistreamProtocol }
227
- val multi = streamMultistreamProtocol.createMultistream(protocols)
228
-
229
- val controller = CompletableFuture <T >()
230
- val streamFut = CompletableFuture <Stream >()
231
- it.get().createStream(
232
- QuicStreamType .BIDIRECTIONAL ,
233
- object : ChannelInboundHandlerAdapter () {
234
- override fun handlerAdded (ctx : ChannelHandlerContext ? ) {
235
- val stream = createStream(ctx!! .channel(), connection, true )
236
- ctx.channel().attr(STREAM ).set(stream)
237
- val streamHandler = multi.toStreamHandler()
238
- streamHandler.handleStream(stream).forward(controller)
239
- .apply { streamFut.complete(stream) }
240
- }
241
- }
242
- )
243
- return StreamPromise (streamFut, controller)
244
- }
245
- })
225
+ connection.setMuxerSession(QuicMuxerSession (it.get(), connection))
246
226
val pubHash = Multihash .of(addr.getPeerId()!! .bytes.toByteBuf())
247
227
val remotePubKey = if (pubHash.desc.digest == Multihash .Digest .Identity ) {
248
228
unmarshalPublicKey(pubHash.bytes.toByteArray())
@@ -305,29 +285,14 @@ class QuicTransport(
305
285
val javaPrivateKey = getJavaKey(connectionKeys.first)
306
286
val isClient = expectedRemotePeerId != null
307
287
val cert = buildCert(localKey, connectionKeys.first)
308
- println (" Building " + certAlgorithm + " keys and cert for peerid " + PeerId .fromPubKey(localKey.publicKey()))
288
+ log.info (" Building {} keys and cert for peerid {} " , certAlgorithm, PeerId .fromPubKey(localKey.publicKey()))
309
289
return (
310
290
if (isClient) {
311
291
QuicSslContextBuilder .forClient().keyManager(javaPrivateKey, null , cert)
312
292
} else {
313
293
QuicSslContextBuilder .forServer(javaPrivateKey, null , cert).clientAuth(ClientAuth .REQUIRE )
314
294
}
315
295
)
316
- // .option(BoringSSLContextOption.GROUPS, arrayOf("x25519"))
317
- // .option(
318
- // BoringSSLContextOption.SIGNATURE_ALGORITHMS,
319
- // arrayOf(
320
- // // "ed25519",
321
- // "ecdsa_secp256r1_sha256",
322
- // "rsa_pkcs1_sha256",
323
- // "rsa_pss_rsae_sha256",
324
- // "ecdsa_secp384r1_sha384",
325
- // "rsa_pkcs1_sha384",
326
- // "rsa_pss_rsae_sha384",
327
- // "rsa_pss_rsae_sha512",
328
- // "rsa_pkcs1_sha512",
329
- // )
330
- // )
331
296
.trustManager(trustManager)
332
297
.applicationProtocols(" libp2p" )
333
298
.build()
@@ -337,7 +302,8 @@ class QuicTransport(
337
302
connHandler : ConnectionHandler ,
338
303
preHandler : ChannelVisitor <P2PChannel >?
339
304
): ChannelHandler {
340
- val sslContext = quicSslContext(null , Libp2pTrustManager (Optional .empty()))
305
+ val trustManager = Libp2pTrustManager (Optional .empty())
306
+ val sslContext = quicSslContext(null , trustManager)
341
307
return QuicServerCodecBuilder ()
342
308
.sslEngineProvider({ q -> sslContext.newEngine(q.alloc()) })
343
309
.maxIdleTimeout(5000 , TimeUnit .MILLISECONDS )
@@ -346,9 +312,53 @@ class QuicTransport(
346
312
.handler(object : ChannelInitializer <Channel >() {
347
313
override fun initChannel (ch : Channel ) {
348
314
val connection = ConnectionOverNetty (ch, this @QuicTransport, false )
315
+ connection.setMuxerSession(QuicMuxerSession (ch as QuicChannel , connection))
349
316
ch.attr(CONNECTION ).set(connection)
350
- preHandler?.also { it.visit(connection) }
351
- connHandler.handleConnection(connection)
317
+
318
+ // Add a handler to wait for channel activation (handshake completion)
319
+ ch.pipeline().addFirst(
320
+ " quic-handshake-waiter" ,
321
+ object : ChannelInboundHandlerAdapter () {
322
+ override fun channelActive (ctx : ChannelHandlerContext ) {
323
+ // Now the handshake is complete and remoteCert should be available
324
+ val remoteCert = trustManager.remoteCert
325
+ if (remoteCert != null ) {
326
+ val remotePeerId = verifyAndExtractPeerId(arrayOf(remoteCert))
327
+ val remotePublicKey = getPublicKeyFromCert(arrayOf(remoteCert))
328
+
329
+ log.info(" Handshake completed with remote peer id: {}" , remotePeerId)
330
+
331
+ connection.setSecureSession(
332
+ SecureChannel .Session (
333
+ PeerId .fromPubKey(localKey.publicKey()),
334
+ remotePeerId,
335
+ remotePublicKey,
336
+ null
337
+ )
338
+ )
339
+
340
+ // Remove this handler as it's no longer needed
341
+ ctx.pipeline().remove(this )
342
+
343
+ // Now it's safe to call the connection handler
344
+ preHandler?.also { it.visit(connection) }
345
+ connHandler.handleConnection(connection)
346
+ } else {
347
+ // This should not happen if channelActive is called after handshake
348
+ ctx.close()
349
+ throw IllegalStateException (" Remote certificate still not available after handshake" )
350
+ }
351
+
352
+ super .channelActive(ctx)
353
+ }
354
+
355
+ @Deprecated(" Deprecated in Java" )
356
+ override fun exceptionCaught (ctx : ChannelHandlerContext , cause : Throwable ) {
357
+ log.error(" An error during handshake" , cause)
358
+ ctx.close()
359
+ }
360
+ }
361
+ )
352
362
}
353
363
})
354
364
.initialMaxData(1024 )
@@ -359,6 +369,34 @@ class QuicTransport(
359
369
.build()
360
370
}
361
371
372
+ class QuicMuxerSession (
373
+ val ch : QuicChannel ,
374
+ val connection : ConnectionOverNetty
375
+ ) : StreamMuxer.Session {
376
+ override fun <T > createStream (protocols : List <ProtocolBinding <T >>): StreamPromise <T > {
377
+ var multistreamProtocol: MultistreamProtocol = MultistreamProtocolV1
378
+ var streamMultistreamProtocol: MultistreamProtocol by lazyVar { multistreamProtocol }
379
+ val multi = streamMultistreamProtocol.createMultistream(protocols)
380
+
381
+ val controller = CompletableFuture <T >()
382
+ val streamFut = CompletableFuture <Stream >()
383
+
384
+ ch.createStream(
385
+ QuicStreamType .BIDIRECTIONAL ,
386
+ object : ChannelInboundHandlerAdapter () {
387
+ override fun handlerAdded (ctx : ChannelHandlerContext ? ) {
388
+ val stream = createStream(ctx!! .channel(), connection, true )
389
+ ctx.channel().attr(STREAM ).set(stream)
390
+ val streamHandler = multi.toStreamHandler()
391
+ streamHandler.handleStream(stream).forward(controller)
392
+ .apply { streamFut.complete(stream) }
393
+ }
394
+ }
395
+ )
396
+ return StreamPromise (streamFut, controller)
397
+ }
398
+ }
399
+
362
400
class InboundStreamHandler (
363
401
val handler : MultistreamProtocol ,
364
402
val protocols : List <ProtocolBinding <* >>
0 commit comments