Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions krpc/krpc-core/api/krpc-core.api
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,15 @@ public final class kotlinx/rpc/krpc/RPCTransportMessage$StringMessage : kotlinx/
public final fun getValue ()Ljava/lang/String;
}

public final class kotlinx/rpc/krpc/StreamScope : kotlinx/rpc/internal/utils/AutoCloseable {
public fun close ()V
}

public final class kotlinx/rpc/krpc/StreamScopeKt {
public static final fun StreamScope (Lkotlin/coroutines/CoroutineContext;)Lkotlinx/rpc/krpc/StreamScope;
public static final fun invokeOnStreamScopeCompletion (ZLkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public static synthetic fun invokeOnStreamScopeCompletion$default (ZLkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object;
public static final fun streamScoped (Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public static final fun withStreamScope (Lkotlinx/rpc/krpc/StreamScope;Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
}

Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.coroutineContext
import kotlin.js.JsName

/**
* Stream scope handles all RPC streams that are launched inside it.
Expand All @@ -26,33 +27,46 @@ import kotlin.coroutines.coroutineContext
* Stream scope is a child of the [CoroutineContext] it was created in.
* Failure of one request will not cancel all streams in the others.
*/
@InternalRPCApi
@OptIn(InternalCoroutinesApi::class)
public class StreamScope(
public class StreamScope internal constructor(
parentContext: CoroutineContext,
internal val role: Role,
) : CoroutineContext.Element, AutoCloseable {
internal companion object Key : CoroutineContext.Key<StreamScope>
): AutoCloseable {
internal class Element(internal val scope: StreamScope) : CoroutineContext.Element {
override val key: CoroutineContext.Key<Element> = Key

internal companion object Key : CoroutineContext.Key<Element>
}

override val key: CoroutineContext.Key<StreamScope> = Key
internal val contextElement = Element(this)

private val scopeJob = SupervisorJob(parentContext.job)

private val requests = ConcurrentHashMap<String, CoroutineScope>()

init {
scopeJob.invokeOnCompletion {
close()
}
}

@InternalRPCApi
public fun onScopeCompletion(handler: (Throwable?) -> Unit) {
scopeJob.invokeOnCompletion(handler)
}

@InternalRPCApi
public fun onScopeCompletion(callId: String, handler: (Throwable?) -> Unit) {
getRequestScope(callId).coroutineContext.job.invokeOnCompletion(onCancelling = true, handler = handler)
}

@InternalRPCApi
public fun cancelRequestScopeById(callId: String, message: String, cause: Throwable?): Job? {
return requests.remove(callId)?.apply { cancel(message, cause) }?.coroutineContext?.job
}

// Group stream launches by callId. In case one fails, so do others
@InternalRPCApi
public fun launch(callId: String, block: suspend CoroutineScope.() -> Unit): Job {
return getRequestScope(callId).launch(block = block)
}
Expand Down Expand Up @@ -86,19 +100,19 @@ public fun CoroutineContext.withServerStreamScope(): CoroutineContext = withStre

@OptIn(InternalCoroutinesApi::class)
internal fun CoroutineContext.withStreamScope(role: StreamScope.Role): CoroutineContext {
return this + StreamScope(this, role).apply {
[email protected](onCancelling = true) { close() }
return this + StreamScope(this, role).contextElement.apply {
[email protected](onCancelling = true) { scope.close() }
}
}

@InternalRPCApi
public suspend fun streamScopeOrNull(): StreamScope? {
return currentCoroutineContext()[StreamScope.Key]
return currentCoroutineContext()[StreamScope.Element.Key]?.scope
}

@InternalRPCApi
public fun streamScopeOrNull(scope: CoroutineScope): StreamScope? {
return scope.coroutineContext[StreamScope.Key]
return scope.coroutineContext[StreamScope.Element.Key]?.scope
}

internal fun noStreamScopeError(): Nothing {
Expand Down Expand Up @@ -165,22 +179,53 @@ public suspend fun <T> streamScoped(block: suspend CoroutineScope.() -> T): T {
}

val context = currentCoroutineContext()
.apply {
checkContextForStreamScope()
}

val streamScope = StreamScope(context, StreamScope.Role.Client)

return withContext(streamScope.contextElement) {
streamScope.use {
block()
}
}
}

if (context[StreamScope.Key] != null) {
private fun CoroutineContext.checkContextForStreamScope() {
if (this[StreamScope.Element] != null) {
error(
"One of the following caused a failure: \n" +
"- nested 'streamScoped' calls are not allowed.\n" +
"- 'streamScoped' calls are not allowed in server RPC services."
"- nested 'streamScoped' or `withStreamScope` calls are not allowed.\n" +
"- 'streamScoped' or `withStreamScope` calls are not allowed in server RPC services."
)
}
}

val streamScope = StreamScope(context, StreamScope.Role.Client)
/**
* Creates a [StreamScope] entity for manual stream management.
*/
@JsName("StreamScope_fun")
@ExperimentalRPCApi
public fun StreamScope(parent: CoroutineContext): StreamScope {
parent.checkContextForStreamScope()

return withContext(streamScope) {
streamScope.use {
block()
}
return StreamScope(parent, StreamScope.Role.Client)
}

/**
* Adds manually managed [StreamScope] to the current context.
*/
@OptIn(ExperimentalContracts::class)
@ExperimentalRPCApi
public suspend fun <T> withStreamScope(scope: StreamScope, block: suspend CoroutineScope.() -> T): T {
contract {
callsInPlace(block, InvocationKind.EXACTLY_ONCE)
}

currentCoroutineContext().checkContextForStreamScope()

return withContext(scope.contextElement, block)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@ import kotlinx.coroutines.*
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.toList
import kotlinx.rpc.krpc.StreamScope
import kotlinx.rpc.krpc.internal.STREAM_SCOPES_ENABLED
import kotlinx.rpc.krpc.invokeOnStreamScopeCompletion
import kotlinx.rpc.krpc.streamScoped
import kotlinx.rpc.krpc.withStreamScope
import kotlinx.rpc.withService
import kotlin.test.*

Expand Down Expand Up @@ -586,6 +588,55 @@ class CancellationTest {
stopAllAndJoin()
}

@Test
fun manualStreamScopeNoCancel() = runCancellationTest {
val myJob = Job()
val streamScope = StreamScope(myJob)

val unrelatedJob = Job()

var first: Int = -1
val deferredFlow = CoroutineScope(unrelatedJob).async {
withStreamScope(streamScope) {
service.incomingStream().apply { first = first() }
}
}
val flow= deferredFlow.await()

serverInstance().fence.complete(Unit)
val consumed = flow.toList()

assertEquals(0, first)
assertContentEquals(listOf(1), consumed)

stopAllAndJoin()
}

@Test
fun manualStreamScopeWithCancel() = runCancellationTest {
val myJob = Job()
val streamScope = StreamScope(myJob)

val unrelatedJob = Job()

var first: Int = -1
val deferredFlow = CoroutineScope(unrelatedJob).async {
withStreamScope(streamScope) {
service.incomingStream().apply { first = first() }
}
}
val flow= deferredFlow.await()

streamScope.close()
serverInstance().fence.complete(Unit)
val consumed = flow.toList()

assertEquals(0, first)
assertContentEquals(emptyList(), consumed)

stopAllAndJoin()
}

@Test
fun testCancelledClientCancelsRequest() = runCancellationTest {
launch {
Expand Down