Skip to content

Commit 45de74d

Browse files
pikinier20Mr3zee
andauthored
KRPC-101 Check if the entire stream is not already closed. (#158)
* KRPC-101 Check if the entire stream is not already closed. In such case, the incomingChannels get cleared and closedStreams don't contain the streamId which leads to deadlock * Fix import and yarn.lock --------- Co-authored-by: Alexander Sysoev <[email protected]>
1 parent 582aa99 commit 45de74d

File tree

1 file changed

+18
-12
lines changed

1 file changed

+18
-12
lines changed

core/src/commonMain/kotlin/kotlinx/rpc/internal/RPCStreamContext.kt

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import kotlinx.coroutines.flow.FlowCollector
1111
import kotlinx.coroutines.flow.MutableSharedFlow
1212
import kotlinx.coroutines.flow.MutableStateFlow
1313
import kotlinx.coroutines.flow.flow
14+
import kotlinx.coroutines.selects.select
1415
import kotlinx.rpc.RPCConfig
1516
import kotlinx.rpc.StreamScope
1617
import kotlinx.rpc.internal.map.ConcurrentHashMap
@@ -55,9 +56,10 @@ public class RPCStreamContext(
5556
private companion object {
5657
private const val STREAM_ID_PREFIX = "stream:"
5758
}
59+
private val closed = CompletableDeferred<Unit>()
5860

5961
// thread-safe set
60-
private val closedStreams = ConcurrentHashMap<String, Unit>()
62+
private val closedStreams = ConcurrentHashMap<String, CompletableDeferred<Unit>>()
6163

6264
@InternalRPCApi
6365
public inline fun launchIf(
@@ -168,7 +170,7 @@ public class RPCStreamContext(
168170
fun onClose() {
169171
incoming.cancel()
170172

171-
closedStreams.put(streamId, Unit)
173+
closedStreams[streamId] = Unit
172174
incomingChannels.remove(streamId)?.complete(null)
173175
incomingStreams.remove(streamId)
174176
}
@@ -242,27 +244,31 @@ public class RPCStreamContext(
242244
}
243245

244246
public suspend fun send(message: RPCCallMessage.StreamMessage, serialFormat: SerialFormat) {
245-
val info = incomingStreams.getDeferred(message.streamId).await()
247+
val info: RPCStreamCall? = select {
248+
incomingStreams.getDeferred(message.streamId).onAwait { it }
249+
closedStreams.getDeferred(message.streamId).onAwait { null }
250+
closed.onAwait { null }
251+
}
252+
if (info == null) return
246253
val result = decodeMessageData(serialFormat, info.elementSerializer, message)
247-
incomingChannelOf(message.streamId)?.send(result)
254+
val channel = incomingChannelOf(message.streamId)
255+
channel?.send(result)
248256
}
249257

250258
private suspend fun incomingChannelOf(streamId: String): Channel<Any?>? {
251-
if (closedStreams.containsKey(streamId)) {
252-
return null
259+
return select {
260+
incomingChannels.getDeferred(streamId).onAwait { it }
261+
closedStreams.getDeferred(streamId).onAwait { null }
262+
closed.onAwait { null }
253263
}
254-
255-
return incomingChannels.getDeferred(streamId).await()
256264
}
257265

258-
private var closed = false
259-
260266
private fun close(cause: Throwable?) {
261-
if (closed) {
267+
if (closed.isCompleted) {
262268
return
263269
}
264270

265-
closed = true
271+
closed.complete(Unit)
266272

267273
if (incomingChannelsInitialized) {
268274
for (channel in incomingChannels.values) {

0 commit comments

Comments
 (0)