@@ -11,6 +11,7 @@ import kotlinx.coroutines.flow.FlowCollector
11
11
import kotlinx.coroutines.flow.MutableSharedFlow
12
12
import kotlinx.coroutines.flow.MutableStateFlow
13
13
import kotlinx.coroutines.flow.flow
14
+ import kotlinx.coroutines.selects.select
14
15
import kotlinx.rpc.RPCConfig
15
16
import kotlinx.rpc.StreamScope
16
17
import kotlinx.rpc.internal.map.ConcurrentHashMap
@@ -55,9 +56,10 @@ public class RPCStreamContext(
55
56
private companion object {
56
57
private const val STREAM_ID_PREFIX = " stream:"
57
58
}
59
+ private val closed = CompletableDeferred <Unit >()
58
60
59
61
// thread-safe set
60
- private val closedStreams = ConcurrentHashMap <String , Unit >()
62
+ private val closedStreams = ConcurrentHashMap <String , CompletableDeferred < Unit > >()
61
63
62
64
@InternalRPCApi
63
65
public inline fun launchIf (
@@ -168,7 +170,7 @@ public class RPCStreamContext(
168
170
fun onClose () {
169
171
incoming.cancel()
170
172
171
- closedStreams.put( streamId, Unit )
173
+ closedStreams[ streamId] = Unit
172
174
incomingChannels.remove(streamId)?.complete(null )
173
175
incomingStreams.remove(streamId)
174
176
}
@@ -242,27 +244,31 @@ public class RPCStreamContext(
242
244
}
243
245
244
246
public suspend fun send (message : RPCCallMessage .StreamMessage , serialFormat : SerialFormat ) {
245
- val info = incomingStreams.getDeferred(message.streamId).await()
247
+ val info: RPCStreamCall ? = select {
248
+ incomingStreams.getDeferred(message.streamId).onAwait { it }
249
+ closedStreams.getDeferred(message.streamId).onAwait { null }
250
+ closed.onAwait { null }
251
+ }
252
+ if (info == null ) return
246
253
val result = decodeMessageData(serialFormat, info.elementSerializer, message)
247
- incomingChannelOf(message.streamId)?.send(result)
254
+ val channel = incomingChannelOf(message.streamId)
255
+ channel?.send(result)
248
256
}
249
257
250
258
private suspend fun incomingChannelOf (streamId : String ): Channel <Any ?>? {
251
- if (closedStreams.containsKey(streamId)) {
252
- return null
259
+ return select {
260
+ incomingChannels.getDeferred(streamId).onAwait { it }
261
+ closedStreams.getDeferred(streamId).onAwait { null }
262
+ closed.onAwait { null }
253
263
}
254
-
255
- return incomingChannels.getDeferred(streamId).await()
256
264
}
257
265
258
- private var closed = false
259
-
260
266
private fun close (cause : Throwable ? ) {
261
- if (closed) {
267
+ if (closed.isCompleted ) {
262
268
return
263
269
}
264
270
265
- closed = true
271
+ closed.complete( Unit )
266
272
267
273
if (incomingChannelsInitialized) {
268
274
for (channel in incomingChannels.values) {
0 commit comments