Skip to content

Commit 3d4b05f

Browse files
authored
[QUIC] Set muxer session and secure session on connection (#413)
1 parent e4c25d5 commit 3d4b05f

File tree

1 file changed

+81
-43
lines changed

1 file changed

+81
-43
lines changed

libp2p/src/main/kotlin/io/libp2p/transport/quic/QuicTransport.kt

Lines changed: 81 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ import io.netty.channel.nio.NioEventLoopGroup
3333
import io.netty.channel.socket.nio.NioDatagramChannel
3434
import io.netty.handler.ssl.ClientAuth
3535
import io.netty.incubator.codec.quic.*
36+
import org.slf4j.LoggerFactory
3637
import java.net.*
3738
import java.time.Duration
3839
import java.util.*
@@ -44,6 +45,7 @@ class QuicTransport(
4445
private val certAlgorithm: String,
4546
private val protocols: List<ProtocolBinding<*>>
4647
) : NettyTransport {
48+
private val log = LoggerFactory.getLogger(QuicTransport::class.java)
4749

4850
private var closed = false
4951
var connectTimeout = Duration.ofSeconds(15)
@@ -162,7 +164,7 @@ class QuicTransport(
162164
listeners -= addr
163165
}
164166
}
165-
println("Quic server listening on " + addr)
167+
log.info("Quic server listening on {}", addr)
166168
res.complete(null)
167169
}
168170
}
@@ -220,29 +222,7 @@ class QuicTransport(
220222
connFuture.also {
221223
registerChannel(it.get())
222224
val connection = ConnectionOverNetty(it.get(), this, true)
223-
connection.setMuxerSession(object : StreamMuxer.Session {
224-
override fun <T> createStream(protocols: List<ProtocolBinding<T>>): StreamPromise<T> {
225-
var multistreamProtocol: MultistreamProtocol = MultistreamProtocolV1
226-
var streamMultistreamProtocol: MultistreamProtocol by lazyVar { multistreamProtocol }
227-
val multi = streamMultistreamProtocol.createMultistream(protocols)
228-
229-
val controller = CompletableFuture<T>()
230-
val streamFut = CompletableFuture<Stream>()
231-
it.get().createStream(
232-
QuicStreamType.BIDIRECTIONAL,
233-
object : ChannelInboundHandlerAdapter() {
234-
override fun handlerAdded(ctx: ChannelHandlerContext?) {
235-
val stream = createStream(ctx!!.channel(), connection, true)
236-
ctx.channel().attr(STREAM).set(stream)
237-
val streamHandler = multi.toStreamHandler()
238-
streamHandler.handleStream(stream).forward(controller)
239-
.apply { streamFut.complete(stream) }
240-
}
241-
}
242-
)
243-
return StreamPromise(streamFut, controller)
244-
}
245-
})
225+
connection.setMuxerSession(QuicMuxerSession(it.get(), connection))
246226
val pubHash = Multihash.of(addr.getPeerId()!!.bytes.toByteBuf())
247227
val remotePubKey = if (pubHash.desc.digest == Multihash.Digest.Identity) {
248228
unmarshalPublicKey(pubHash.bytes.toByteArray())
@@ -305,29 +285,14 @@ class QuicTransport(
305285
val javaPrivateKey = getJavaKey(connectionKeys.first)
306286
val isClient = expectedRemotePeerId != null
307287
val cert = buildCert(localKey, connectionKeys.first)
308-
println("Building " + certAlgorithm + " keys and cert for peerid " + PeerId.fromPubKey(localKey.publicKey()))
288+
log.info("Building {} keys and cert for peerid {}", certAlgorithm, PeerId.fromPubKey(localKey.publicKey()))
309289
return (
310290
if (isClient) {
311291
QuicSslContextBuilder.forClient().keyManager(javaPrivateKey, null, cert)
312292
} else {
313293
QuicSslContextBuilder.forServer(javaPrivateKey, null, cert).clientAuth(ClientAuth.REQUIRE)
314294
}
315295
)
316-
// .option(BoringSSLContextOption.GROUPS, arrayOf("x25519"))
317-
// .option(
318-
// BoringSSLContextOption.SIGNATURE_ALGORITHMS,
319-
// arrayOf(
320-
// // "ed25519",
321-
// "ecdsa_secp256r1_sha256",
322-
// "rsa_pkcs1_sha256",
323-
// "rsa_pss_rsae_sha256",
324-
// "ecdsa_secp384r1_sha384",
325-
// "rsa_pkcs1_sha384",
326-
// "rsa_pss_rsae_sha384",
327-
// "rsa_pss_rsae_sha512",
328-
// "rsa_pkcs1_sha512",
329-
// )
330-
// )
331296
.trustManager(trustManager)
332297
.applicationProtocols("libp2p")
333298
.build()
@@ -337,7 +302,8 @@ class QuicTransport(
337302
connHandler: ConnectionHandler,
338303
preHandler: ChannelVisitor<P2PChannel>?
339304
): ChannelHandler {
340-
val sslContext = quicSslContext(null, Libp2pTrustManager(Optional.empty()))
305+
val trustManager = Libp2pTrustManager(Optional.empty())
306+
val sslContext = quicSslContext(null, trustManager)
341307
return QuicServerCodecBuilder()
342308
.sslEngineProvider({ q -> sslContext.newEngine(q.alloc()) })
343309
.maxIdleTimeout(5000, TimeUnit.MILLISECONDS)
@@ -346,9 +312,53 @@ class QuicTransport(
346312
.handler(object : ChannelInitializer<Channel>() {
347313
override fun initChannel(ch: Channel) {
348314
val connection = ConnectionOverNetty(ch, this@QuicTransport, false)
315+
connection.setMuxerSession(QuicMuxerSession(ch as QuicChannel, connection))
349316
ch.attr(CONNECTION).set(connection)
350-
preHandler?.also { it.visit(connection) }
351-
connHandler.handleConnection(connection)
317+
318+
// Add a handler to wait for channel activation (handshake completion)
319+
ch.pipeline().addFirst(
320+
"quic-handshake-waiter",
321+
object : ChannelInboundHandlerAdapter() {
322+
override fun channelActive(ctx: ChannelHandlerContext) {
323+
// Now the handshake is complete and remoteCert should be available
324+
val remoteCert = trustManager.remoteCert
325+
if (remoteCert != null) {
326+
val remotePeerId = verifyAndExtractPeerId(arrayOf(remoteCert))
327+
val remotePublicKey = getPublicKeyFromCert(arrayOf(remoteCert))
328+
329+
log.info("Handshake completed with remote peer id: {}", remotePeerId)
330+
331+
connection.setSecureSession(
332+
SecureChannel.Session(
333+
PeerId.fromPubKey(localKey.publicKey()),
334+
remotePeerId,
335+
remotePublicKey,
336+
null
337+
)
338+
)
339+
340+
// Remove this handler as it's no longer needed
341+
ctx.pipeline().remove(this)
342+
343+
// Now it's safe to call the connection handler
344+
preHandler?.also { it.visit(connection) }
345+
connHandler.handleConnection(connection)
346+
} else {
347+
// This should not happen if channelActive is called after handshake
348+
ctx.close()
349+
throw IllegalStateException("Remote certificate still not available after handshake")
350+
}
351+
352+
super.channelActive(ctx)
353+
}
354+
355+
@Deprecated("Deprecated in Java")
356+
override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) {
357+
log.error("An error during handshake", cause)
358+
ctx.close()
359+
}
360+
}
361+
)
352362
}
353363
})
354364
.initialMaxData(1024)
@@ -359,6 +369,34 @@ class QuicTransport(
359369
.build()
360370
}
361371

372+
class QuicMuxerSession(
373+
val ch: QuicChannel,
374+
val connection: ConnectionOverNetty
375+
) : StreamMuxer.Session {
376+
override fun <T> createStream(protocols: List<ProtocolBinding<T>>): StreamPromise<T> {
377+
var multistreamProtocol: MultistreamProtocol = MultistreamProtocolV1
378+
var streamMultistreamProtocol: MultistreamProtocol by lazyVar { multistreamProtocol }
379+
val multi = streamMultistreamProtocol.createMultistream(protocols)
380+
381+
val controller = CompletableFuture<T>()
382+
val streamFut = CompletableFuture<Stream>()
383+
384+
ch.createStream(
385+
QuicStreamType.BIDIRECTIONAL,
386+
object : ChannelInboundHandlerAdapter() {
387+
override fun handlerAdded(ctx: ChannelHandlerContext?) {
388+
val stream = createStream(ctx!!.channel(), connection, true)
389+
ctx.channel().attr(STREAM).set(stream)
390+
val streamHandler = multi.toStreamHandler()
391+
streamHandler.handleStream(stream).forward(controller)
392+
.apply { streamFut.complete(stream) }
393+
}
394+
}
395+
)
396+
return StreamPromise(streamFut, controller)
397+
}
398+
}
399+
362400
class InboundStreamHandler(
363401
val handler: MultistreamProtocol,
364402
val protocols: List<ProtocolBinding<*>>

0 commit comments

Comments
 (0)