Skip to content

set muxer session and secure session on connection #413

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 81 additions & 43 deletions libp2p/src/main/kotlin/io/libp2p/transport/quic/QuicTransport.kt
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.nio.NioDatagramChannel
import io.netty.handler.ssl.ClientAuth
import io.netty.incubator.codec.quic.*
import org.slf4j.LoggerFactory
import java.net.*
import java.time.Duration
import java.util.*
Expand All @@ -44,6 +45,7 @@ class QuicTransport(
private val certAlgorithm: String,
private val protocols: List<ProtocolBinding<*>>
) : NettyTransport {
private val log = LoggerFactory.getLogger(QuicTransport::class.java)

private var closed = false
var connectTimeout = Duration.ofSeconds(15)
Expand Down Expand Up @@ -162,7 +164,7 @@ class QuicTransport(
listeners -= addr
}
}
println("Quic server listening on " + addr)
log.info("Quic server listening on {}", addr)
res.complete(null)
}
}
Expand Down Expand Up @@ -220,29 +222,7 @@ class QuicTransport(
connFuture.also {
registerChannel(it.get())
val connection = ConnectionOverNetty(it.get(), this, true)
connection.setMuxerSession(object : StreamMuxer.Session {
override fun <T> createStream(protocols: List<ProtocolBinding<T>>): StreamPromise<T> {
var multistreamProtocol: MultistreamProtocol = MultistreamProtocolV1
var streamMultistreamProtocol: MultistreamProtocol by lazyVar { multistreamProtocol }
val multi = streamMultistreamProtocol.createMultistream(protocols)

val controller = CompletableFuture<T>()
val streamFut = CompletableFuture<Stream>()
it.get().createStream(
QuicStreamType.BIDIRECTIONAL,
object : ChannelInboundHandlerAdapter() {
override fun handlerAdded(ctx: ChannelHandlerContext?) {
val stream = createStream(ctx!!.channel(), connection, true)
ctx.channel().attr(STREAM).set(stream)
val streamHandler = multi.toStreamHandler()
streamHandler.handleStream(stream).forward(controller)
.apply { streamFut.complete(stream) }
}
}
)
return StreamPromise(streamFut, controller)
}
})
connection.setMuxerSession(QuicMuxerSession(it.get(), connection))
val pubHash = Multihash.of(addr.getPeerId()!!.bytes.toByteBuf())
val remotePubKey = if (pubHash.desc.digest == Multihash.Digest.Identity) {
unmarshalPublicKey(pubHash.bytes.toByteArray())
Expand Down Expand Up @@ -305,29 +285,14 @@ class QuicTransport(
val javaPrivateKey = getJavaKey(connectionKeys.first)
val isClient = expectedRemotePeerId != null
val cert = buildCert(localKey, connectionKeys.first)
println("Building " + certAlgorithm + " keys and cert for peerid " + PeerId.fromPubKey(localKey.publicKey()))
log.info("Building {} keys and cert for peerid {}", certAlgorithm, PeerId.fromPubKey(localKey.publicKey()))
return (
if (isClient) {
QuicSslContextBuilder.forClient().keyManager(javaPrivateKey, null, cert)
} else {
QuicSslContextBuilder.forServer(javaPrivateKey, null, cert).clientAuth(ClientAuth.REQUIRE)
}
)
// .option(BoringSSLContextOption.GROUPS, arrayOf("x25519"))
// .option(
// BoringSSLContextOption.SIGNATURE_ALGORITHMS,
// arrayOf(
// // "ed25519",
// "ecdsa_secp256r1_sha256",
// "rsa_pkcs1_sha256",
// "rsa_pss_rsae_sha256",
// "ecdsa_secp384r1_sha384",
// "rsa_pkcs1_sha384",
// "rsa_pss_rsae_sha384",
// "rsa_pss_rsae_sha512",
// "rsa_pkcs1_sha512",
// )
// )
.trustManager(trustManager)
.applicationProtocols("libp2p")
.build()
Expand All @@ -337,7 +302,8 @@ class QuicTransport(
connHandler: ConnectionHandler,
preHandler: ChannelVisitor<P2PChannel>?
): ChannelHandler {
val sslContext = quicSslContext(null, Libp2pTrustManager(Optional.empty()))
val trustManager = Libp2pTrustManager(Optional.empty())
val sslContext = quicSslContext(null, trustManager)
return QuicServerCodecBuilder()
.sslEngineProvider({ q -> sslContext.newEngine(q.alloc()) })
.maxIdleTimeout(5000, TimeUnit.MILLISECONDS)
Expand All @@ -346,9 +312,53 @@ class QuicTransport(
.handler(object : ChannelInitializer<Channel>() {
override fun initChannel(ch: Channel) {
val connection = ConnectionOverNetty(ch, this@QuicTransport, false)
connection.setMuxerSession(QuicMuxerSession(ch as QuicChannel, connection))
ch.attr(CONNECTION).set(connection)
preHandler?.also { it.visit(connection) }
connHandler.handleConnection(connection)

// Add a handler to wait for channel activation (handshake completion)
ch.pipeline().addFirst(
"quic-handshake-waiter",
object : ChannelInboundHandlerAdapter() {
override fun channelActive(ctx: ChannelHandlerContext) {
// Now the handshake is complete and remoteCert should be available
val remoteCert = trustManager.remoteCert
if (remoteCert != null) {
val remotePeerId = verifyAndExtractPeerId(arrayOf(remoteCert))
val remotePublicKey = getPublicKeyFromCert(arrayOf(remoteCert))

log.info("Handshake completed with remote peer id: {}", remotePeerId)

connection.setSecureSession(
SecureChannel.Session(
PeerId.fromPubKey(localKey.publicKey()),
remotePeerId,
remotePublicKey,
null
)
)

// Remove this handler as it's no longer needed
ctx.pipeline().remove(this)

// Now it's safe to call the connection handler
preHandler?.also { it.visit(connection) }
connHandler.handleConnection(connection)
} else {
// This should not happen if channelActive is called after handshake
ctx.close()
throw IllegalStateException("Remote certificate still not available after handshake")
}

super.channelActive(ctx)
}

@Deprecated("Deprecated in Java")
override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) {
log.error("An error during handshake", cause)
ctx.close()
}
}
)
}
})
.initialMaxData(1024)
Expand All @@ -359,6 +369,34 @@ class QuicTransport(
.build()
}

class QuicMuxerSession(
val ch: QuicChannel,
val connection: ConnectionOverNetty
) : StreamMuxer.Session {
override fun <T> createStream(protocols: List<ProtocolBinding<T>>): StreamPromise<T> {
var multistreamProtocol: MultistreamProtocol = MultistreamProtocolV1
var streamMultistreamProtocol: MultistreamProtocol by lazyVar { multistreamProtocol }
val multi = streamMultistreamProtocol.createMultistream(protocols)

val controller = CompletableFuture<T>()
val streamFut = CompletableFuture<Stream>()

ch.createStream(
QuicStreamType.BIDIRECTIONAL,
object : ChannelInboundHandlerAdapter() {
override fun handlerAdded(ctx: ChannelHandlerContext?) {
val stream = createStream(ctx!!.channel(), connection, true)
ctx.channel().attr(STREAM).set(stream)
val streamHandler = multi.toStreamHandler()
streamHandler.handleStream(stream).forward(controller)
.apply { streamFut.complete(stream) }
}
}
)
return StreamPromise(streamFut, controller)
}
}

class InboundStreamHandler(
val handler: MultistreamProtocol,
val protocols: List<ProtocolBinding<*>>
Expand Down
Loading