Skip to content

Commit 1110043

Browse files
authored
KRPC-97 Race condition in stream cancellation locks the transport (#138)
1 parent abed7eb commit 1110043

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed

core/src/commonMain/kotlin/kotlinx/rpc/internal/RPCStreamContext.kt

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ public class RPCStreamContext(
5656
private const val STREAM_ID_PREFIX = "stream:"
5757
}
5858

59+
// thread-safe set
60+
private val closedStreams = ConcurrentHashMap<String, Unit>()
61+
5962
@InternalRPCApi
6063
public inline fun launchIf(
6164
condition: RPCStreamContext.() -> Boolean,
@@ -97,7 +100,7 @@ public class RPCStreamContext(
97100
private var incomingChannelsInitialized: Boolean = false
98101
private val incomingChannels by lazy {
99102
incomingChannelsInitialized = true
100-
ConcurrentHashMap<String, CompletableDeferred<Channel<Any?>>>()
103+
ConcurrentHashMap<String, CompletableDeferred<Channel<Any?>?>>()
101104
}
102105

103106
private var outgoingStreamsInitialized: Boolean = false
@@ -165,7 +168,8 @@ public class RPCStreamContext(
165168
fun onClose() {
166169
incoming.cancel()
167170

168-
incomingChannels.remove(streamId)
171+
closedStreams.put(streamId, Unit)
172+
incomingChannels.remove(streamId)?.complete(null)
169173
incomingStreams.remove(streamId)
170174
}
171175

@@ -230,20 +234,24 @@ public class RPCStreamContext(
230234
}
231235

232236
public suspend fun closeStream(message: RPCCallMessage.StreamFinished) {
233-
incomingChannelOf(message.streamId).send(StreamEnd)
237+
incomingChannelOf(message.streamId)?.send(StreamEnd)
234238
}
235239

236240
public suspend fun cancelStream(message: RPCCallMessage.StreamCancel) {
237-
incomingChannelOf(message.streamId).send(StreamCancel(message.cause.deserialize()))
241+
incomingChannelOf(message.streamId)?.send(StreamCancel(message.cause.deserialize()))
238242
}
239243

240244
public suspend fun send(message: RPCCallMessage.StreamMessage, serialFormat: SerialFormat) {
241245
val info = incomingStreams.getDeferred(message.streamId).await()
242246
val result = decodeMessageData(serialFormat, info.elementSerializer, message)
243-
incomingChannelOf(message.streamId).send(result)
247+
incomingChannelOf(message.streamId)?.send(result)
244248
}
245249

246-
private suspend fun incomingChannelOf(streamId: String): Channel<Any?> {
250+
private suspend fun incomingChannelOf(streamId: String): Channel<Any?>? {
251+
if (closedStreams.containsKey(streamId)) {
252+
return null
253+
}
254+
247255
return incomingChannels.getDeferred(streamId).await()
248256
}
249257

@@ -263,7 +271,7 @@ public class RPCStreamContext(
263271
}
264272

265273
@OptIn(ExperimentalCoroutinesApi::class)
266-
channel.getCompleted().apply {
274+
channel.getCompleted()?.apply {
267275
trySend(StreamEnd)
268276

269277
// close for sending, but not for receiving our cancel message, if possible.

0 commit comments

Comments
 (0)