@@ -56,6 +56,9 @@ public class RPCStreamContext(
56
56
private const val STREAM_ID_PREFIX = " stream:"
57
57
}
58
58
59
+ // thread-safe set
60
+ private val closedStreams = ConcurrentHashMap <String , Unit >()
61
+
59
62
@InternalRPCApi
60
63
public inline fun launchIf (
61
64
condition : RPCStreamContext .() -> Boolean ,
@@ -97,7 +100,7 @@ public class RPCStreamContext(
97
100
private var incomingChannelsInitialized: Boolean = false
98
101
private val incomingChannels by lazy {
99
102
incomingChannelsInitialized = true
100
- ConcurrentHashMap <String , CompletableDeferred <Channel <Any ?>>>()
103
+ ConcurrentHashMap <String , CompletableDeferred <Channel <Any ?>? >> ()
101
104
}
102
105
103
106
private var outgoingStreamsInitialized: Boolean = false
@@ -165,7 +168,8 @@ public class RPCStreamContext(
165
168
fun onClose () {
166
169
incoming.cancel()
167
170
168
- incomingChannels.remove(streamId)
171
+ closedStreams.put(streamId, Unit )
172
+ incomingChannels.remove(streamId)?.complete(null )
169
173
incomingStreams.remove(streamId)
170
174
}
171
175
@@ -230,20 +234,24 @@ public class RPCStreamContext(
230
234
}
231
235
232
236
public suspend fun closeStream (message : RPCCallMessage .StreamFinished ) {
233
- incomingChannelOf(message.streamId).send(StreamEnd )
237
+ incomingChannelOf(message.streamId)? .send(StreamEnd )
234
238
}
235
239
236
240
public suspend fun cancelStream (message : RPCCallMessage .StreamCancel ) {
237
- incomingChannelOf(message.streamId).send(StreamCancel (message.cause.deserialize()))
241
+ incomingChannelOf(message.streamId)? .send(StreamCancel (message.cause.deserialize()))
238
242
}
239
243
240
244
public suspend fun send (message : RPCCallMessage .StreamMessage , serialFormat : SerialFormat ) {
241
245
val info = incomingStreams.getDeferred(message.streamId).await()
242
246
val result = decodeMessageData(serialFormat, info.elementSerializer, message)
243
- incomingChannelOf(message.streamId).send(result)
247
+ incomingChannelOf(message.streamId)? .send(result)
244
248
}
245
249
246
- private suspend fun incomingChannelOf (streamId : String ): Channel <Any ?> {
250
+ private suspend fun incomingChannelOf (streamId : String ): Channel <Any ?>? {
251
+ if (closedStreams.containsKey(streamId)) {
252
+ return null
253
+ }
254
+
247
255
return incomingChannels.getDeferred(streamId).await()
248
256
}
249
257
@@ -263,7 +271,7 @@ public class RPCStreamContext(
263
271
}
264
272
265
273
@OptIn(ExperimentalCoroutinesApi ::class )
266
- channel.getCompleted().apply {
274
+ channel.getCompleted()? .apply {
267
275
trySend(StreamEnd )
268
276
269
277
// close for sending, but not for receiving our cancel message, if possible.
0 commit comments