Skip to content

Commit 1a46f54

Browse files
committed
Fixed cancellation handing for client and server kRPC
1 parent d36ae98 commit 1a46f54

File tree

6 files changed

+104
-75
lines changed

6 files changed

+104
-75
lines changed

krpc/krpc-client/src/commonMain/kotlin/kotlinx/rpc/krpc/client/KrpcClient.kt

Lines changed: 38 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ import kotlinx.coroutines.cancel
1717
import kotlinx.coroutines.cancelAndJoin
1818
import kotlinx.coroutines.channels.Channel
1919
import kotlinx.coroutines.coroutineScope
20-
import kotlinx.coroutines.currentCoroutineContext
2120
import kotlinx.coroutines.ensureActive
2221
import kotlinx.coroutines.flow.Flow
2322
import kotlinx.coroutines.flow.FlowCollector
@@ -137,6 +136,9 @@ public abstract class KrpcClient : RpcClient, KrpcEndpoint {
137136
@Volatile
138137
private var clientCancelled = false
139138

139+
@Volatile
140+
private var clientCancelledByServer = false
141+
140142
private fun checkTransportReadiness() {
141143
if (!isTransportReady) {
142144
error(
@@ -153,9 +155,15 @@ public abstract class KrpcClient : RpcClient, KrpcEndpoint {
153155
val context = SupervisorJob(transport.coroutineContext.job)
154156

155157
context.job.invokeOnCompletion(onCancelling = true) {
158+
if (clientCancelled) {
159+
return@invokeOnCompletion
160+
}
161+
156162
clientCancelled = true
157163

158-
sendCancellation(CancellationType.ENDPOINT, null, null, closeTransportAfterSending = true)
164+
if (!clientCancelledByServer) {
165+
sendCancellation(CancellationType.ENDPOINT, null, null, closeTransportAfterSending = true)
166+
}
159167

160168
@OptIn(DelicateCoroutinesApi::class)
161169
@Suppress("detekt.GlobalCoroutineUsage")
@@ -255,7 +263,7 @@ public abstract class KrpcClient : RpcClient, KrpcEndpoint {
255263
final override fun <T> callServerStreaming(call: RpcCall): Flow<T> {
256264
return flow {
257265
if (clientCancelled) {
258-
error("Client cancelled")
266+
error("RpcClient was cancelled")
259267
}
260268

261269
initializeAndAwaitHandshakeCompletion()
@@ -271,6 +279,10 @@ public abstract class KrpcClient : RpcClient, KrpcEndpoint {
271279
try {
272280
@Suppress("UNCHECKED_CAST")
273281
requestChannels[callId] = channel as Channel<Result<Any?>>
282+
if (clientCancelled) {
283+
requestChannels.remove(callId)
284+
error("RpcClient was cancelled")
285+
}
274286

275287
val request = serializeRequest(
276288
callId = callId,
@@ -362,6 +374,7 @@ public abstract class KrpcClient : RpcClient, KrpcEndpoint {
362374
is KrpcCallMessage.CallException -> {
363375
val cause = message.cause.deserialize()
364376
channel.close(cause)
377+
channel.cancel(CancellationException("Call failed", cause))
365378
}
366379

367380
is KrpcCallMessage.CallSuccess, is KrpcCallMessage.StreamMessage -> {
@@ -383,6 +396,7 @@ public abstract class KrpcClient : RpcClient, KrpcEndpoint {
383396
is KrpcCallMessage.StreamCancel -> {
384397
val cause = message.cause.deserialize()
385398
channel.close(cause)
399+
channel.cancel(CancellationException("Stream cancelled", cause))
386400
}
387401
}
388402
}
@@ -391,6 +405,7 @@ public abstract class KrpcClient : RpcClient, KrpcEndpoint {
391405
final override suspend fun handleCancellation(message: KrpcGenericMessage) {
392406
when (val type = message.cancellationType()) {
393407
CancellationType.ENDPOINT -> {
408+
clientCancelledByServer = true
394409
internalScope.cancel("Closing client after server cancellation") // we cancel this client
395410
}
396411

@@ -458,6 +473,7 @@ public abstract class KrpcClient : RpcClient, KrpcEndpoint {
458473
serialFormat: SerialFormat,
459474
serviceTypeString: String,
460475
) {
476+
var failure: Throwable? = null
461477
try {
462478
collectAndSendOutgoingStream(
463479
serialFormat = serialFormat,
@@ -466,39 +482,31 @@ public abstract class KrpcClient : RpcClient, KrpcEndpoint {
466482
serviceTypeString = serviceTypeString,
467483
)
468484
} catch (e: CancellationException) {
469-
currentCoroutineContext().ensureActive()
470-
471-
val wrapped = ManualCancellationException(e)
472-
val serializedReason = serializeException(wrapped)
473-
val message = KrpcCallMessage.StreamCancel(
474-
callId = outgoingStream.callId,
475-
serviceType = serviceTypeString,
476-
streamId = outgoingStream.streamId,
477-
cause = serializedReason,
478-
connectionId = outgoingStream.connectionId,
479-
serviceId = outgoingStream.serviceId,
480-
)
481-
connector.sendMessageChecked(message) {
482-
// ignore, we are already cancelled and have a cause
483-
}
485+
internalScope.ensureActive()
486+
487+
failure = ManualCancellationException(e)
484488

485489
// stop the flow and its coroutine, other flows are not affected
486490
throw e
487491
} catch (cause: Throwable) {
488-
val serializedReason = serializeException(cause)
489-
val message = KrpcCallMessage.StreamCancel(
490-
callId = outgoingStream.callId,
491-
serviceType = serviceTypeString,
492-
streamId = outgoingStream.streamId,
493-
cause = serializedReason,
494-
connectionId = outgoingStream.connectionId,
495-
serviceId = outgoingStream.serviceId,
496-
)
497-
connector.sendMessageChecked(message) {
498-
// ignore, we are already cancelled and have a cause
499-
}
492+
failure = cause
500493

501494
throw cause
495+
} finally {
496+
if (failure != null) {
497+
val serializedReason = serializeException(failure)
498+
val message = KrpcCallMessage.StreamCancel(
499+
callId = outgoingStream.callId,
500+
serviceType = serviceTypeString,
501+
streamId = outgoingStream.streamId,
502+
cause = serializedReason,
503+
connectionId = outgoingStream.connectionId,
504+
serviceId = outgoingStream.serviceId,
505+
)
506+
connector.sendMessageChecked(message) {
507+
// ignore, we are already cancelled and have a cause
508+
}
509+
}
502510
}
503511

504512
val message = KrpcCallMessage.StreamFinished(

krpc/krpc-server/src/commonMain/kotlin/kotlinx/rpc/krpc/server/internal/KrpcServerService.kt

Lines changed: 55 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ internal class KrpcServerService<@Rpc T : Any>(
141141
var failure: Throwable? = null
142142

143143
val requestJob = serverScope.launch(start = CoroutineStart.LAZY) {
144+
var startedCollecting = false
144145
try {
145146
val markedNonSuspending = callData.pluginParams.orEmpty()
146147
.contains(KrpcPluginKey.NON_SUSPENDING_SERVER_FLOW_MARKER)
@@ -180,31 +181,39 @@ internal class KrpcServerService<@Rpc T : Any>(
180181
)
181182
}
182183

184+
startedCollecting = true
183185
sendFlowMessages(serialFormat, returnSerializer, value, callData)
184186
} else {
185187
sendMessageValue(serialFormat, returnSerializer, value, callData)
186188
}
187189
} catch (cause: CancellationException) {
188-
currentCoroutineContext().ensureActive()
190+
serverScope.ensureActive()
191+
val request = requestMap[callId]
192+
if (request == null || request.serviceCancelled) {
193+
throw cause
194+
}
189195

190-
val wrapped = ManualCancellationException(cause)
196+
failure = ManualCancellationException(cause)
191197

192-
failure = wrapped
198+
throw cause
193199
} catch (cause: Throwable) {
194200
failure = cause
195201
} finally {
196202
if (failure != null) {
197-
val serializedCause = serializeException(failure)
198-
val exceptionMessage = KrpcCallMessage.CallException(
199-
callId = callId,
200-
serviceType = descriptor.fqName,
201-
cause = serializedCause,
202-
connectionId = callData.connectionId,
203-
serviceId = callData.serviceId,
204-
)
203+
// flow cancellations are handled by the sendFlowMessages function
204+
if (!startedCollecting || !callable.isNonSuspendFunction) {
205+
val serializedCause = serializeException(failure)
206+
val exceptionMessage = KrpcCallMessage.CallException(
207+
callId = callId,
208+
serviceType = descriptor.fqName,
209+
cause = serializedCause,
210+
connectionId = callData.connectionId,
211+
serviceId = callData.serviceId,
212+
)
205213

206-
connector.sendMessageChecked(exceptionMessage) {
207-
// ignore, the client probably already disconnected
214+
connector.sendMessageChecked(exceptionMessage) {
215+
// ignore, the client probably already disconnected
216+
}
208217
}
209218

210219
closeReceiving(callId, "Server request failed", failure, fromJob = true)
@@ -268,6 +277,7 @@ internal class KrpcServerService<@Rpc T : Any>(
268277
flow: Flow<Any?>,
269278
callData: KrpcCallMessage.CallData,
270279
) {
280+
var failure: Throwable? = null
271281
try {
272282
flow.collect { value ->
273283
val result = when (serialFormat) {
@@ -315,20 +325,34 @@ internal class KrpcServerService<@Rpc T : Any>(
315325
// do nothing
316326
}
317327
} catch (cause: CancellationException) {
328+
serverScope.ensureActive()
329+
val request = requestMap[callData.callId]
330+
if (request == null || request.serviceCancelled) {
331+
throw cause
332+
}
333+
334+
failure = ManualCancellationException(cause)
335+
318336
throw cause
319337
} catch (cause: Throwable) {
320-
val serializedCause = serializeException(cause)
321-
connector.sendMessageChecked(
322-
KrpcCallMessage.StreamCancel(
323-
callId = callData.callId,
324-
serviceType = descriptor.fqName,
325-
connectionId = callData.connectionId,
326-
serviceId = callData.serviceId,
327-
streamId = SINGLE_STREAM_ID,
328-
cause = serializedCause,
329-
)
330-
) {
331-
// do nothing
338+
failure = cause
339+
340+
throw cause
341+
} finally {
342+
if (failure != null) {
343+
val serializedCause = serializeException(failure)
344+
connector.sendMessageChecked(
345+
KrpcCallMessage.StreamCancel(
346+
callId = callData.callId,
347+
serviceType = descriptor.fqName,
348+
connectionId = callData.connectionId,
349+
serviceId = callData.serviceId,
350+
streamId = SINGLE_STREAM_ID,
351+
cause = serializedCause,
352+
)
353+
) {
354+
// do nothing
355+
}
332356
}
333357
}
334358
}
@@ -393,12 +417,18 @@ internal class KrpcServerService<@Rpc T : Any>(
393417
}
394418

395419
internal class RpcRequest(val handlerJob: Job, val streamContext: ServerStreamContext) {
420+
// not user cancelled
421+
var serviceCancelled: Boolean = false
422+
private set
423+
396424
fun cancelAndClose(
397425
callId: String,
398426
message: String? = null,
399427
cause: Throwable? = null,
400428
fromJob: Boolean = false,
401429
) {
430+
serviceCancelled = true
431+
402432
if (!handlerJob.isCompleted && !fromJob) {
403433
when {
404434
message != null && cause != null -> handlerJob.cancel(message, cause)

krpc/krpc-server/src/commonMain/kotlin/kotlinx/rpc/krpc/server/internal/ServerStreamContext.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package kotlinx.rpc.krpc.server.internal
66

77
import kotlinx.atomicfu.atomic
8+
import kotlinx.coroutines.CancellationException
89
import kotlinx.coroutines.channels.Channel
910
import kotlinx.coroutines.flow.Flow
1011
import kotlinx.coroutines.flow.flow
@@ -49,6 +50,7 @@ internal class ServerStreamContext {
4950
fun removeCall(callId: String, cause: Throwable?) {
5051
streams.remove(callId)?.values?.forEach {
5152
it.channel.close(cause)
53+
it.channel.cancel(cause?.let { e -> e as? CancellationException ?: CancellationException(null, e) })
5254
}
5355
}
5456

krpc/krpc-test/src/commonMain/kotlin/kotlinx/rpc/krpc/test/KrpcTransportTestBase.kt

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,17 +113,19 @@ abstract class KrpcTransportTestBase {
113113

114114
@Test
115115
fun nonSuspendErrorOnEmit() = runTest {
116-
val flow = client.nonSuspendFlowErrorOnReturn()
117-
assertFails {
116+
val flow = client.nonSuspendFlowErrorOnEmit()
117+
val failure = assertFails {
118118
flow.toList()
119119
}
120+
assertFalse(failure is CancellationException)
120121
}
121122

122123
@Test
123124
fun nonSuspendErrorOnReturn() = runTest {
124-
assertFails {
125+
val failure = assertFails {
125126
client.nonSuspendFlowErrorOnReturn().toList()
126127
}
128+
assertFalse(failure is CancellationException)
127129
}
128130

129131
@Test

krpc/krpc-test/src/commonTest/kotlin/kotlinx/rpc/krpc/test/cancellation/CancellationService.kt

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ interface CancellationService {
2121

2222
fun cancellationInIncomingStream(): Flow<Int>
2323

24-
suspend fun cancellationInOutgoingStream(stream: Flow<Int>, cancelled: Flow<Int>)
24+
suspend fun cancellationInOutgoingStream(cancelled: Flow<Int>)
2525

2626
suspend fun outgoingStream(stream: Flow<Int>)
2727

@@ -74,16 +74,12 @@ class CancellationServiceImpl : CancellationService {
7474
}
7575
}
7676

77-
override suspend fun cancellationInOutgoingStream(stream: Flow<Int>, cancelled: Flow<Int>) {
77+
override suspend fun cancellationInOutgoingStream(cancelled: Flow<Int>) {
7878
supervisorScope {
79-
launch {
80-
consume(stream)
81-
}
82-
8379
launch {
8480
try {
8581
cancelled.collect {
86-
if (it == 0) {
82+
if (it == 1) {
8783
firstIncomingConsumed.complete(it)
8884
}
8985
}

krpc/krpc-test/src/commonTest/kotlin/kotlinx/rpc/krpc/test/cancellation/CancellationTest.kt

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -125,13 +125,7 @@ class CancellationTest {
125125
supervisorScope {
126126
val requestJob = launch {
127127
service.cancellationInOutgoingStream(
128-
stream = flow {
129-
emit(42)
130-
println("[testCancellationInClientStream] emit 42")
131-
emit(43)
132-
println("[testCancellationInClientStream] emit 43")
133-
},
134-
cancelled = flow {
128+
flow {
135129
emit(1)
136130
println("[testCancellationInClientStream] emit 1")
137131
serverInstance().firstIncomingConsumed.await()
@@ -143,11 +137,8 @@ class CancellationTest {
143137

144138
requestJob.join()
145139
println("[testCancellationInClientStream] Request job finished")
146-
serverInstance().consumedAll.await()
147-
println("[testCancellationInClientStream] Server consumed all")
148140

149141
assertFalse(requestJob.isCancelled, "Expected requestJob not to be cancelled")
150-
assertContentEquals(listOf(42, 43), serverInstance().consumedIncomingValues)
151142
}
152143
println("[testCancellationInClientStream] Scope finished")
153144

0 commit comments

Comments
 (0)