diff --git a/libp2p/src/main/kotlin/io/libp2p/transport/quic/QuicTransport.kt b/libp2p/src/main/kotlin/io/libp2p/transport/quic/QuicTransport.kt index 2574e4fc9..63d448de3 100644 --- a/libp2p/src/main/kotlin/io/libp2p/transport/quic/QuicTransport.kt +++ b/libp2p/src/main/kotlin/io/libp2p/transport/quic/QuicTransport.kt @@ -33,6 +33,7 @@ import io.netty.channel.nio.NioEventLoopGroup import io.netty.channel.socket.nio.NioDatagramChannel import io.netty.handler.ssl.ClientAuth import io.netty.incubator.codec.quic.* +import org.slf4j.LoggerFactory import java.net.* import java.time.Duration import java.util.* @@ -44,6 +45,7 @@ class QuicTransport( private val certAlgorithm: String, private val protocols: List> ) : NettyTransport { + private val log = LoggerFactory.getLogger(QuicTransport::class.java) private var closed = false var connectTimeout = Duration.ofSeconds(15) @@ -162,7 +164,7 @@ class QuicTransport( listeners -= addr } } - println("Quic server listening on " + addr) + log.info("Quic server listening on {}", addr) res.complete(null) } } @@ -220,29 +222,7 @@ class QuicTransport( connFuture.also { registerChannel(it.get()) val connection = ConnectionOverNetty(it.get(), this, true) - connection.setMuxerSession(object : StreamMuxer.Session { - override fun createStream(protocols: List>): StreamPromise { - var multistreamProtocol: MultistreamProtocol = MultistreamProtocolV1 - var streamMultistreamProtocol: MultistreamProtocol by lazyVar { multistreamProtocol } - val multi = streamMultistreamProtocol.createMultistream(protocols) - - val controller = CompletableFuture() - val streamFut = CompletableFuture() - it.get().createStream( - QuicStreamType.BIDIRECTIONAL, - object : ChannelInboundHandlerAdapter() { - override fun handlerAdded(ctx: ChannelHandlerContext?) { - val stream = createStream(ctx!!.channel(), connection, true) - ctx.channel().attr(STREAM).set(stream) - val streamHandler = multi.toStreamHandler() - streamHandler.handleStream(stream).forward(controller) - .apply { streamFut.complete(stream) } - } - } - ) - return StreamPromise(streamFut, controller) - } - }) + connection.setMuxerSession(QuicMuxerSession(it.get(), connection)) val pubHash = Multihash.of(addr.getPeerId()!!.bytes.toByteBuf()) val remotePubKey = if (pubHash.desc.digest == Multihash.Digest.Identity) { unmarshalPublicKey(pubHash.bytes.toByteArray()) @@ -305,7 +285,7 @@ class QuicTransport( val javaPrivateKey = getJavaKey(connectionKeys.first) val isClient = expectedRemotePeerId != null val cert = buildCert(localKey, connectionKeys.first) - println("Building " + certAlgorithm + " keys and cert for peerid " + PeerId.fromPubKey(localKey.publicKey())) + log.info("Building {} keys and cert for peerid {}", certAlgorithm, PeerId.fromPubKey(localKey.publicKey())) return ( if (isClient) { QuicSslContextBuilder.forClient().keyManager(javaPrivateKey, null, cert) @@ -313,21 +293,6 @@ class QuicTransport( QuicSslContextBuilder.forServer(javaPrivateKey, null, cert).clientAuth(ClientAuth.REQUIRE) } ) -// .option(BoringSSLContextOption.GROUPS, arrayOf("x25519")) -// .option( -// BoringSSLContextOption.SIGNATURE_ALGORITHMS, -// arrayOf( -// // "ed25519", -// "ecdsa_secp256r1_sha256", -// "rsa_pkcs1_sha256", -// "rsa_pss_rsae_sha256", -// "ecdsa_secp384r1_sha384", -// "rsa_pkcs1_sha384", -// "rsa_pss_rsae_sha384", -// "rsa_pss_rsae_sha512", -// "rsa_pkcs1_sha512", -// ) -// ) .trustManager(trustManager) .applicationProtocols("libp2p") .build() @@ -337,7 +302,8 @@ class QuicTransport( connHandler: ConnectionHandler, preHandler: ChannelVisitor? ): ChannelHandler { - val sslContext = quicSslContext(null, Libp2pTrustManager(Optional.empty())) + val trustManager = Libp2pTrustManager(Optional.empty()) + val sslContext = quicSslContext(null, trustManager) return QuicServerCodecBuilder() .sslEngineProvider({ q -> sslContext.newEngine(q.alloc()) }) .maxIdleTimeout(5000, TimeUnit.MILLISECONDS) @@ -346,9 +312,53 @@ class QuicTransport( .handler(object : ChannelInitializer() { override fun initChannel(ch: Channel) { val connection = ConnectionOverNetty(ch, this@QuicTransport, false) + connection.setMuxerSession(QuicMuxerSession(ch as QuicChannel, connection)) ch.attr(CONNECTION).set(connection) - preHandler?.also { it.visit(connection) } - connHandler.handleConnection(connection) + + // Add a handler to wait for channel activation (handshake completion) + ch.pipeline().addFirst( + "quic-handshake-waiter", + object : ChannelInboundHandlerAdapter() { + override fun channelActive(ctx: ChannelHandlerContext) { + // Now the handshake is complete and remoteCert should be available + val remoteCert = trustManager.remoteCert + if (remoteCert != null) { + val remotePeerId = verifyAndExtractPeerId(arrayOf(remoteCert)) + val remotePublicKey = getPublicKeyFromCert(arrayOf(remoteCert)) + + log.info("Handshake completed with remote peer id: {}", remotePeerId) + + connection.setSecureSession( + SecureChannel.Session( + PeerId.fromPubKey(localKey.publicKey()), + remotePeerId, + remotePublicKey, + null + ) + ) + + // Remove this handler as it's no longer needed + ctx.pipeline().remove(this) + + // Now it's safe to call the connection handler + preHandler?.also { it.visit(connection) } + connHandler.handleConnection(connection) + } else { + // This should not happen if channelActive is called after handshake + ctx.close() + throw IllegalStateException("Remote certificate still not available after handshake") + } + + super.channelActive(ctx) + } + + @Deprecated("Deprecated in Java") + override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) { + log.error("An error during handshake", cause) + ctx.close() + } + } + ) } }) .initialMaxData(1024) @@ -359,6 +369,34 @@ class QuicTransport( .build() } + class QuicMuxerSession( + val ch: QuicChannel, + val connection: ConnectionOverNetty + ) : StreamMuxer.Session { + override fun createStream(protocols: List>): StreamPromise { + var multistreamProtocol: MultistreamProtocol = MultistreamProtocolV1 + var streamMultistreamProtocol: MultistreamProtocol by lazyVar { multistreamProtocol } + val multi = streamMultistreamProtocol.createMultistream(protocols) + + val controller = CompletableFuture() + val streamFut = CompletableFuture() + + ch.createStream( + QuicStreamType.BIDIRECTIONAL, + object : ChannelInboundHandlerAdapter() { + override fun handlerAdded(ctx: ChannelHandlerContext?) { + val stream = createStream(ctx!!.channel(), connection, true) + ctx.channel().attr(STREAM).set(stream) + val streamHandler = multi.toStreamHandler() + streamHandler.handleStream(stream).forward(controller) + .apply { streamFut.complete(stream) } + } + } + ) + return StreamPromise(streamFut, controller) + } + } + class InboundStreamHandler( val handler: MultistreamProtocol, val protocols: List>