diff --git a/krpc/krpc-core/api/krpc-core.api b/krpc/krpc-core/api/krpc-core.api index da6690c9a..06470faf7 100644 --- a/krpc/krpc-core/api/krpc-core.api +++ b/krpc/krpc-core/api/krpc-core.api @@ -77,9 +77,15 @@ public final class kotlinx/rpc/krpc/RPCTransportMessage$StringMessage : kotlinx/ public final fun getValue ()Ljava/lang/String; } +public final class kotlinx/rpc/krpc/StreamScope : kotlinx/rpc/internal/utils/AutoCloseable { + public fun close ()V +} + public final class kotlinx/rpc/krpc/StreamScopeKt { + public static final fun StreamScope (Lkotlin/coroutines/CoroutineContext;)Lkotlinx/rpc/krpc/StreamScope; public static final fun invokeOnStreamScopeCompletion (ZLkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public static synthetic fun invokeOnStreamScopeCompletion$default (ZLkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; public static final fun streamScoped (Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static final fun withStreamScope (Lkotlinx/rpc/krpc/StreamScope;Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; } diff --git a/krpc/krpc-core/src/commonMain/kotlin/kotlinx/rpc/krpc/StreamScope.kt b/krpc/krpc-core/src/commonMain/kotlin/kotlinx/rpc/krpc/StreamScope.kt index 4ebe10bb2..fa8f79558 100644 --- a/krpc/krpc-core/src/commonMain/kotlin/kotlinx/rpc/krpc/StreamScope.kt +++ b/krpc/krpc-core/src/commonMain/kotlin/kotlinx/rpc/krpc/StreamScope.kt @@ -15,6 +15,7 @@ import kotlin.contracts.InvocationKind import kotlin.contracts.contract import kotlin.coroutines.CoroutineContext import kotlin.coroutines.coroutineContext +import kotlin.js.JsName /** * Stream scope handles all RPC streams that are launched inside it. @@ -26,33 +27,46 @@ import kotlin.coroutines.coroutineContext * Stream scope is a child of the [CoroutineContext] it was created in. * Failure of one request will not cancel all streams in the others. */ -@InternalRPCApi @OptIn(InternalCoroutinesApi::class) -public class StreamScope( +public class StreamScope internal constructor( parentContext: CoroutineContext, internal val role: Role, -) : CoroutineContext.Element, AutoCloseable { - internal companion object Key : CoroutineContext.Key +): AutoCloseable { + internal class Element(internal val scope: StreamScope) : CoroutineContext.Element { + override val key: CoroutineContext.Key = Key + + internal companion object Key : CoroutineContext.Key + } - override val key: CoroutineContext.Key = Key + internal val contextElement = Element(this) private val scopeJob = SupervisorJob(parentContext.job) private val requests = ConcurrentHashMap() + init { + scopeJob.invokeOnCompletion { + close() + } + } + + @InternalRPCApi public fun onScopeCompletion(handler: (Throwable?) -> Unit) { scopeJob.invokeOnCompletion(handler) } + @InternalRPCApi public fun onScopeCompletion(callId: String, handler: (Throwable?) -> Unit) { getRequestScope(callId).coroutineContext.job.invokeOnCompletion(onCancelling = true, handler = handler) } + @InternalRPCApi public fun cancelRequestScopeById(callId: String, message: String, cause: Throwable?): Job? { return requests.remove(callId)?.apply { cancel(message, cause) }?.coroutineContext?.job } // Group stream launches by callId. In case one fails, so do others + @InternalRPCApi public fun launch(callId: String, block: suspend CoroutineScope.() -> Unit): Job { return getRequestScope(callId).launch(block = block) } @@ -86,19 +100,19 @@ public fun CoroutineContext.withServerStreamScope(): CoroutineContext = withStre @OptIn(InternalCoroutinesApi::class) internal fun CoroutineContext.withStreamScope(role: StreamScope.Role): CoroutineContext { - return this + StreamScope(this, role).apply { - this@withStreamScope.job.invokeOnCompletion(onCancelling = true) { close() } + return this + StreamScope(this, role).contextElement.apply { + this@withStreamScope.job.invokeOnCompletion(onCancelling = true) { scope.close() } } } @InternalRPCApi public suspend fun streamScopeOrNull(): StreamScope? { - return currentCoroutineContext()[StreamScope.Key] + return currentCoroutineContext()[StreamScope.Element.Key]?.scope } @InternalRPCApi public fun streamScopeOrNull(scope: CoroutineScope): StreamScope? { - return scope.coroutineContext[StreamScope.Key] + return scope.coroutineContext[StreamScope.Element.Key]?.scope } internal fun noStreamScopeError(): Nothing { @@ -165,22 +179,53 @@ public suspend fun streamScoped(block: suspend CoroutineScope.() -> T): T { } val context = currentCoroutineContext() + .apply { + checkContextForStreamScope() + } + + val streamScope = StreamScope(context, StreamScope.Role.Client) + + return withContext(streamScope.contextElement) { + streamScope.use { + block() + } + } +} - if (context[StreamScope.Key] != null) { +private fun CoroutineContext.checkContextForStreamScope() { + if (this[StreamScope.Element] != null) { error( "One of the following caused a failure: \n" + - "- nested 'streamScoped' calls are not allowed.\n" + - "- 'streamScoped' calls are not allowed in server RPC services." + "- nested 'streamScoped' or `withStreamScope` calls are not allowed.\n" + + "- 'streamScoped' or `withStreamScope` calls are not allowed in server RPC services." ) } +} - val streamScope = StreamScope(context, StreamScope.Role.Client) +/** + * Creates a [StreamScope] entity for manual stream management. + */ +@JsName("StreamScope_fun") +@ExperimentalRPCApi +public fun StreamScope(parent: CoroutineContext): StreamScope { + parent.checkContextForStreamScope() - return withContext(streamScope) { - streamScope.use { - block() - } + return StreamScope(parent, StreamScope.Role.Client) +} + +/** + * Adds manually managed [StreamScope] to the current context. + */ +@OptIn(ExperimentalContracts::class) +@ExperimentalRPCApi +public suspend fun withStreamScope(scope: StreamScope, block: suspend CoroutineScope.() -> T): T { + contract { + callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + + currentCoroutineContext().checkContextForStreamScope() + + return withContext(scope.contextElement, block) } /** diff --git a/krpc/krpc-test/src/jvmTest/kotlin/kotlinx/rpc/krpc/test/cancellation/CancellationTest.kt b/krpc/krpc-test/src/jvmTest/kotlin/kotlinx/rpc/krpc/test/cancellation/CancellationTest.kt index ac349fbab..ca1d9b6ec 100644 --- a/krpc/krpc-test/src/jvmTest/kotlin/kotlinx/rpc/krpc/test/cancellation/CancellationTest.kt +++ b/krpc/krpc-test/src/jvmTest/kotlin/kotlinx/rpc/krpc/test/cancellation/CancellationTest.kt @@ -8,9 +8,11 @@ import kotlinx.coroutines.* import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.first import kotlinx.coroutines.flow.toList +import kotlinx.rpc.krpc.StreamScope import kotlinx.rpc.krpc.internal.STREAM_SCOPES_ENABLED import kotlinx.rpc.krpc.invokeOnStreamScopeCompletion import kotlinx.rpc.krpc.streamScoped +import kotlinx.rpc.krpc.withStreamScope import kotlinx.rpc.withService import kotlin.test.* @@ -586,6 +588,55 @@ class CancellationTest { stopAllAndJoin() } + @Test + fun manualStreamScopeNoCancel() = runCancellationTest { + val myJob = Job() + val streamScope = StreamScope(myJob) + + val unrelatedJob = Job() + + var first: Int = -1 + val deferredFlow = CoroutineScope(unrelatedJob).async { + withStreamScope(streamScope) { + service.incomingStream().apply { first = first() } + } + } + val flow= deferredFlow.await() + + serverInstance().fence.complete(Unit) + val consumed = flow.toList() + + assertEquals(0, first) + assertContentEquals(listOf(1), consumed) + + stopAllAndJoin() + } + + @Test + fun manualStreamScopeWithCancel() = runCancellationTest { + val myJob = Job() + val streamScope = StreamScope(myJob) + + val unrelatedJob = Job() + + var first: Int = -1 + val deferredFlow = CoroutineScope(unrelatedJob).async { + withStreamScope(streamScope) { + service.incomingStream().apply { first = first() } + } + } + val flow= deferredFlow.await() + + streamScope.close() + serverInstance().fence.complete(Unit) + val consumed = flow.toList() + + assertEquals(0, first) + assertContentEquals(emptyList(), consumed) + + stopAllAndJoin() + } + @Test fun testCancelledClientCancelsRequest() = runCancellationTest { launch {