diff --git a/krpc/krpc-client/src/commonMain/kotlin/kotlinx/rpc/krpc/client/KrpcClient.kt b/krpc/krpc-client/src/commonMain/kotlin/kotlinx/rpc/krpc/client/KrpcClient.kt index 61b58282b..da798bc58 100644 --- a/krpc/krpc-client/src/commonMain/kotlin/kotlinx/rpc/krpc/client/KrpcClient.kt +++ b/krpc/krpc-client/src/commonMain/kotlin/kotlinx/rpc/krpc/client/KrpcClient.kt @@ -11,6 +11,8 @@ import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.FlowCollector import kotlinx.coroutines.flow.first import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock import kotlinx.rpc.RpcCall import kotlinx.rpc.RpcClient import kotlinx.rpc.annotations.Rpc @@ -47,7 +49,7 @@ import kotlin.properties.Delegates public abstract class InitializedKrpcClient( private val config: KrpcConfig.Client, private val transport: KrpcTransport, -): KrpcClient() { +) : KrpcClient() { final override suspend fun initializeTransport(): KrpcTransport { return transport } @@ -179,18 +181,28 @@ public abstract class KrpcClient : RpcClient, KrpcEndpoint { // callId to serviceTypeString private val cancellingRequests = RpcInternalConcurrentHashMap() + private val transportInitializationLock = Mutex() + /** * Starts the handshake process and awaits for completion. * If the handshake was completed before, nothing happens. */ private suspend fun initializeAndAwaitHandshakeCompletion() { - transport = initializeTransport() - isTransportReady = true + if (!isTransportReady) { + transportInitializationLock.withLock { + if (isTransportReady) { + return@withLock + } - connector.subscribeToGenericMessages(::handleGenericMessage) - connector.subscribeToProtocolMessages(::handleProtocolMessage) + transport = initializeTransport() + isTransportReady = true - connector.sendMessage(KrpcProtocolMessage.Handshake(KrpcPlugin.ALL)) + connector.subscribeToGenericMessages(::handleGenericMessage) + connector.subscribeToProtocolMessages(::handleProtocolMessage) + + connector.sendMessage(KrpcProtocolMessage.Handshake(KrpcPlugin.ALL)) + } + } serverSupportedPlugins.await() } diff --git a/krpc/krpc-test/src/commonTest/kotlin/kotlinx/rpc/krpc/test/TransportTest.kt b/krpc/krpc-test/src/commonTest/kotlin/kotlinx/rpc/krpc/test/TransportTest.kt index 5f4e6d487..6be7387f2 100644 --- a/krpc/krpc-test/src/commonTest/kotlin/kotlinx/rpc/krpc/test/TransportTest.kt +++ b/krpc/krpc-test/src/commonTest/kotlin/kotlinx/rpc/krpc/test/TransportTest.kt @@ -10,7 +10,11 @@ import kotlinx.coroutines.test.TestResult import kotlinx.coroutines.test.TestScope import kotlinx.rpc.* import kotlinx.rpc.annotations.Rpc +import kotlinx.rpc.krpc.KrpcConfig import kotlinx.rpc.krpc.KrpcConfigBuilder +import kotlinx.rpc.krpc.KrpcTransport +import kotlinx.rpc.krpc.client.KrpcClient +import kotlinx.rpc.krpc.internal.KrpcProtocolMessage import kotlinx.rpc.krpc.internal.logging.RpcInternalCommonLogger import kotlinx.rpc.krpc.internal.logging.RpcInternalDumpLogger import kotlinx.rpc.krpc.internal.logging.RpcInternalDumpLoggerContainer @@ -76,21 +80,24 @@ class TransportTest { return KrpcTestServer(serverConfig, localTransport.server) } - private fun runTest(block: suspend TestScope.() -> Unit): TestResult = + private fun runTest(block: suspend TestScope.(logs: List) -> Unit): TestResult = kotlinx.coroutines.test.runTest(timeout = 20.seconds) { debugCoroutines() val logger = RpcInternalCommonLogger.logger("TransportTest") + val logs = mutableListOf() RpcInternalDumpLoggerContainer.set(object : RpcInternalDumpLogger { override val isEnabled: Boolean = true override fun dump(vararg tags: String, message: () -> String) { - logger.info { "${tags.joinToString(" ") { "[$it]" }} ${message()}" } + val message = "${tags.joinToString(" ") { "[$it]" }} ${message()}" + logs.add(message) + logger.info { message } } }) - block() + block(logs) RpcInternalDumpLoggerContainer.set(null) } @@ -240,6 +247,38 @@ class TransportTest { transports.cancel() } + private val clientHandshake = ".*\\[Client] \\[Send] \\{\"type\":\"${KrpcProtocolMessage.Handshake.serializer().descriptor.serialName}\".*+".toRegex() + + @Test + fun transportInitializedOnlyOnce() = runTest { logs -> + val localTransport = LocalTransport() + var transportInitialized = 0 + var configInitialized = 0 + val client = object : KrpcClient() { + override suspend fun initializeTransport(): KrpcTransport { + transportInitialized++ + return localTransport.client + } + + override fun initializeConfig(): KrpcConfig.Client { + configInitialized++ + return clientConfig + } + } + + val server = serverOf(localTransport) + + server.registerServiceAndReturn { EchoImpl() } + server.registerServiceAndReturn { SecondServer() } + + client.withService().apply { echo("foo"); echo("bar") } + client.withService().apply{ second("bar"); second("baz") } + + assertEquals(1, transportInitialized) + assertEquals(1, configInitialized) + assertEquals(1, logs.count { it.matches(clientHandshake) }) + } + private inline fun <@Rpc reified Service : Any, reified Impl : Service> RpcServer.registerServiceAndReturn( crossinline body: () -> Impl, ): List {