Skip to content

Commit ac6609d

Browse files
committed
set muxer session and secure session on connection
1 parent 8b62e1f commit ac6609d

File tree

1 file changed

+78
-43
lines changed

1 file changed

+78
-43
lines changed

libp2p/src/main/kotlin/io/libp2p/transport/quic/QuicTransport.kt

Lines changed: 78 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import io.libp2p.etc.STREAM
1919
import io.libp2p.etc.types.*
2020
import io.libp2p.etc.util.MultiaddrUtils
2121
import io.libp2p.etc.util.netty.nettyInitializer
22+
import io.libp2p.security.secio.SecIoCodec
2223
import io.libp2p.security.tls.*
2324
import io.libp2p.transport.implementation.ConnectionOverNetty
2425
import io.libp2p.transport.implementation.NettyTransport
@@ -33,6 +34,7 @@ import io.netty.channel.nio.NioEventLoopGroup
3334
import io.netty.channel.socket.nio.NioDatagramChannel
3435
import io.netty.handler.ssl.ClientAuth
3536
import io.netty.incubator.codec.quic.*
37+
import org.slf4j.LoggerFactory
3638
import java.net.*
3739
import java.time.Duration
3840
import java.util.*
@@ -44,6 +46,7 @@ class QuicTransport(
4446
private val certAlgorithm: String,
4547
private val protocols: List<ProtocolBinding<*>>
4648
) : NettyTransport {
49+
private val log = LoggerFactory.getLogger(QuicTransport::class.java)
4750

4851
private var closed = false
4952
var connectTimeout = Duration.ofSeconds(15)
@@ -162,7 +165,7 @@ class QuicTransport(
162165
listeners -= addr
163166
}
164167
}
165-
println("Quic server listening on " + addr)
168+
log.info("Quic server listening on {}", addr)
166169
res.complete(null)
167170
}
168171
}
@@ -220,29 +223,7 @@ class QuicTransport(
220223
connFuture.also {
221224
registerChannel(it.get())
222225
val connection = ConnectionOverNetty(it.get(), this, true)
223-
connection.setMuxerSession(object : StreamMuxer.Session {
224-
override fun <T> createStream(protocols: List<ProtocolBinding<T>>): StreamPromise<T> {
225-
var multistreamProtocol: MultistreamProtocol = MultistreamProtocolV1
226-
var streamMultistreamProtocol: MultistreamProtocol by lazyVar { multistreamProtocol }
227-
val multi = streamMultistreamProtocol.createMultistream(protocols)
228-
229-
val controller = CompletableFuture<T>()
230-
val streamFut = CompletableFuture<Stream>()
231-
it.get().createStream(
232-
QuicStreamType.BIDIRECTIONAL,
233-
object : ChannelInboundHandlerAdapter() {
234-
override fun handlerAdded(ctx: ChannelHandlerContext?) {
235-
val stream = createStream(ctx!!.channel(), connection, true)
236-
ctx.channel().attr(STREAM).set(stream)
237-
val streamHandler = multi.toStreamHandler()
238-
streamHandler.handleStream(stream).forward(controller)
239-
.apply { streamFut.complete(stream) }
240-
}
241-
}
242-
)
243-
return StreamPromise(streamFut, controller)
244-
}
245-
})
226+
connection.setMuxerSession(QuicMuxerSession(it.get(), connection))
246227
val pubHash = Multihash.of(addr.getPeerId()!!.bytes.toByteBuf())
247228
val remotePubKey = if (pubHash.desc.digest == Multihash.Digest.Identity) {
248229
unmarshalPublicKey(pubHash.bytes.toByteArray())
@@ -305,29 +286,14 @@ class QuicTransport(
305286
val javaPrivateKey = getJavaKey(connectionKeys.first)
306287
val isClient = expectedRemotePeerId != null
307288
val cert = buildCert(localKey, connectionKeys.first)
308-
println("Building " + certAlgorithm + " keys and cert for peerid " + PeerId.fromPubKey(localKey.publicKey()))
289+
log.info("Building {} keys and cert for peerid {}", certAlgorithm, PeerId.fromPubKey(localKey.publicKey()))
309290
return (
310291
if (isClient) {
311292
QuicSslContextBuilder.forClient().keyManager(javaPrivateKey, null, cert)
312293
} else {
313294
QuicSslContextBuilder.forServer(javaPrivateKey, null, cert).clientAuth(ClientAuth.REQUIRE)
314295
}
315296
)
316-
// .option(BoringSSLContextOption.GROUPS, arrayOf("x25519"))
317-
// .option(
318-
// BoringSSLContextOption.SIGNATURE_ALGORITHMS,
319-
// arrayOf(
320-
// // "ed25519",
321-
// "ecdsa_secp256r1_sha256",
322-
// "rsa_pkcs1_sha256",
323-
// "rsa_pss_rsae_sha256",
324-
// "ecdsa_secp384r1_sha384",
325-
// "rsa_pkcs1_sha384",
326-
// "rsa_pss_rsae_sha384",
327-
// "rsa_pss_rsae_sha512",
328-
// "rsa_pkcs1_sha512",
329-
// )
330-
// )
331297
.trustManager(trustManager)
332298
.applicationProtocols("libp2p")
333299
.build()
@@ -337,7 +303,8 @@ class QuicTransport(
337303
connHandler: ConnectionHandler,
338304
preHandler: ChannelVisitor<P2PChannel>?
339305
): ChannelHandler {
340-
val sslContext = quicSslContext(null, Libp2pTrustManager(Optional.empty()))
306+
val trustManager = Libp2pTrustManager(Optional.empty())
307+
val sslContext = quicSslContext(null, trustManager)
341308
return QuicServerCodecBuilder()
342309
.sslEngineProvider({ q -> sslContext.newEngine(q.alloc()) })
343310
.maxIdleTimeout(5000, TimeUnit.MILLISECONDS)
@@ -346,9 +313,49 @@ class QuicTransport(
346313
.handler(object : ChannelInitializer<Channel>() {
347314
override fun initChannel(ch: Channel) {
348315
val connection = ConnectionOverNetty(ch, this@QuicTransport, false)
316+
connection.setMuxerSession(QuicMuxerSession(ch as QuicChannel, connection))
349317
ch.attr(CONNECTION).set(connection)
350-
preHandler?.also { it.visit(connection) }
351-
connHandler.handleConnection(connection)
318+
319+
// Add a handler to wait for channel activation (handshake completion)
320+
ch.pipeline().addFirst("quic-handshake-waiter", object : ChannelInboundHandlerAdapter() {
321+
override fun channelActive(ctx: ChannelHandlerContext) {
322+
// Now the handshake is complete and remoteCert should be available
323+
val remoteCert = trustManager.remoteCert
324+
if (remoteCert != null) {
325+
val remotePeerId = verifyAndExtractPeerId(arrayOf(remoteCert))
326+
val remotePublicKey = getPublicKeyFromCert(arrayOf(remoteCert))
327+
328+
log.info("Handshake completed with remote peer id: {}", remotePeerId)
329+
330+
connection.setSecureSession(
331+
SecureChannel.Session(
332+
PeerId.fromPubKey(localKey.publicKey()),
333+
remotePeerId,
334+
remotePublicKey,
335+
null
336+
)
337+
)
338+
339+
// Remove this handler as it's no longer needed
340+
ctx.pipeline().remove(this)
341+
342+
// Now it's safe to call the connection handler
343+
preHandler?.also { it.visit(connection) }
344+
connHandler.handleConnection(connection)
345+
} else {
346+
// This should not happen if channelActive is called after handshake
347+
ctx.close()
348+
throw IllegalStateException("Remote certificate still not available after handshake")
349+
}
350+
351+
super.channelActive(ctx)
352+
}
353+
354+
@Deprecated("Deprecated in Java")
355+
override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) {
356+
ctx.close()
357+
}
358+
})
352359
}
353360
})
354361
.initialMaxData(1024)
@@ -359,6 +366,34 @@ class QuicTransport(
359366
.build()
360367
}
361368

369+
class QuicMuxerSession(
370+
val ch : QuicChannel,
371+
val connection: ConnectionOverNetty
372+
) : StreamMuxer.Session {
373+
override fun <T> createStream(protocols: List<ProtocolBinding<T>>): StreamPromise<T> {
374+
var multistreamProtocol: MultistreamProtocol = MultistreamProtocolV1
375+
var streamMultistreamProtocol: MultistreamProtocol by lazyVar { multistreamProtocol }
376+
val multi = streamMultistreamProtocol.createMultistream(protocols)
377+
378+
val controller = CompletableFuture<T>()
379+
val streamFut = CompletableFuture<Stream>()
380+
381+
ch.createStream(
382+
QuicStreamType.BIDIRECTIONAL,
383+
object : ChannelInboundHandlerAdapter() {
384+
override fun handlerAdded(ctx: ChannelHandlerContext?) {
385+
val stream = createStream(ctx!!.channel(), connection, true)
386+
ctx.channel().attr(STREAM).set(stream)
387+
val streamHandler = multi.toStreamHandler()
388+
streamHandler.handleStream(stream).forward(controller)
389+
.apply { streamFut.complete(stream) }
390+
}
391+
}
392+
)
393+
return StreamPromise(streamFut, controller)
394+
}
395+
}
396+
362397
class InboundStreamHandler(
363398
val handler: MultistreamProtocol,
364399
val protocols: List<ProtocolBinding<*>>

0 commit comments

Comments
 (0)