@@ -5,6 +5,7 @@ import com.apollographql.apollo3.api.http.HttpHeader
5
5
import com.apollographql.apollo3.exception.ApolloNetworkException
6
6
import com.apollographql.apollo3.exception.ApolloWebSocketClosedException
7
7
import io.ktor.client.HttpClient
8
+ import io.ktor.client.plugins.websocket.DefaultClientWebSocketSession
8
9
import io.ktor.client.plugins.websocket.WebSockets
9
10
import io.ktor.client.plugins.websocket.webSocket
10
11
import io.ktor.client.request.headers
@@ -14,14 +15,14 @@ import io.ktor.http.URLProtocol
14
15
import io.ktor.http.Url
15
16
import io.ktor.websocket.CloseReason
16
17
import io.ktor.websocket.Frame
17
- import io.ktor.websocket.close
18
18
import io.ktor.websocket.readText
19
19
import kotlinx.coroutines.CancellationException
20
20
import kotlinx.coroutines.CoroutineScope
21
- import kotlinx.coroutines.Deferred
22
21
import kotlinx.coroutines.Dispatchers
23
22
import kotlinx.coroutines.SupervisorJob
24
23
import kotlinx.coroutines.channels.Channel
24
+ import kotlinx.coroutines.channels.ReceiveChannel
25
+ import kotlinx.coroutines.coroutineScope
25
26
import kotlinx.coroutines.launch
26
27
import okio.ByteString
27
28
@@ -37,6 +38,8 @@ class KtorWebSocketEngine(
37
38
)
38
39
39
40
private val coroutineScope = CoroutineScope (Dispatchers .Default + SupervisorJob ())
41
+ private val receiveMessageChannel = Channel <String >(Channel .UNLIMITED )
42
+ private val sendFrameChannel = Channel <Frame >(Channel .UNLIMITED )
40
43
41
44
override suspend fun open (
42
45
url : String ,
@@ -52,8 +55,6 @@ class KtorWebSocketEngine(
52
55
/* URLProtocol.SOCKS */ else -> throw UnsupportedOperationException (" SOCKS is not a supported protocol" )
53
56
}
54
57
}.build()
55
- val receiveMessageChannel = Channel <String >(Channel .UNLIMITED )
56
- val sendFrameChannel = Channel <Frame >(Channel .UNLIMITED )
57
58
coroutineScope.launch {
58
59
try {
59
60
client.webSocket(
@@ -66,56 +67,36 @@ class KtorWebSocketEngine(
66
67
url(newUrl)
67
68
},
68
69
) {
69
- launch {
70
- while (true ) {
71
- val frame = sendFrameChannel.receive()
72
- try {
73
- send(frame)
74
-
75
- // Also close the connection if the sent frame is a close frame
76
- if (frame is Frame .Close ) {
77
- receiveMessageChannel.close()
78
- sendFrameChannel.close()
79
- break
80
- }
81
- } catch (e: Exception ) {
82
- handleNetworkException(e, closeReason, receiveMessageChannel, sendFrameChannel)
83
- break
84
- }
70
+ coroutineScope {
71
+ launch {
72
+ sendFrames(this @webSocket)
85
73
}
86
- }
87
- while (true ) {
88
- when (val frame = try {
89
- incoming.receive()
90
- } catch (e: Exception ) {
91
- handleNetworkException(e, closeReason, receiveMessageChannel, sendFrameChannel)
92
- break
93
- }) {
94
- is Frame .Text -> {
95
- receiveMessageChannel.send(frame.readText())
96
- }
97
-
98
- is Frame .Binary -> {
99
- receiveMessageChannel.send(frame.data.decodeToString())
74
+ try {
75
+ receiveFrames(incoming)
76
+ } catch (e: Throwable ) {
77
+ val closeReason = closeReasonOrNull()
78
+ val apolloException = if (closeReason != null ) {
79
+ ApolloWebSocketClosedException (
80
+ code = closeReason.code.toInt(),
81
+ reason = closeReason.message,
82
+ cause = e
83
+ )
84
+ } else {
85
+ ApolloNetworkException (
86
+ message = " Web socket communication error" ,
87
+ platformCause = e
88
+ )
100
89
}
101
-
102
- is Frame .Ping -> {
103
- send(Frame .Pong (frame.data))
104
- }
105
-
106
- is Frame .Pong -> {}
107
- is Frame .Close -> {
108
- close()
109
- receiveMessageChannel.close()
110
- }
111
-
112
- else -> error(" unknown frame type" )
90
+ receiveMessageChannel.close(apolloException)
91
+ throw e
113
92
}
114
93
}
115
94
}
116
- } catch (e: Exception ) {
95
+ } catch (e: Throwable ) {
117
96
receiveMessageChannel.close(ApolloNetworkException (message = " Web socket communication error" , platformCause = e))
118
- sendFrameChannel.close(e)
97
+ } finally {
98
+ // Not 100% sure this can happen. Better safe than sorry. close() is idempotent so it shouldn't hurt
99
+ receiveMessageChannel.close(ApolloNetworkException (message = " Web socket communication error" , platformCause = null ))
119
100
}
120
101
}
121
102
return object : WebSocketConnection {
@@ -137,31 +118,42 @@ class KtorWebSocketEngine(
137
118
}
138
119
}
139
120
140
- private suspend fun handleNetworkException (
141
- e : Exception ,
142
- deferredCloseReason : Deferred <CloseReason ?>,
143
- receiveMessageChannel : Channel <String >,
144
- sendFrameChannel : Channel <Frame >,
145
- ) {
146
- if (e is CancellationException ) throw e
147
- val closeReason = try {
148
- deferredCloseReason.await()
149
- } catch (e: Exception ) {
121
+ private suspend fun DefaultClientWebSocketSession.closeReasonOrNull (): CloseReason ? {
122
+ return try {
123
+ closeReason.await()
124
+ } catch (t: Throwable ) {
125
+ if (t is CancellationException ) {
126
+ throw t
127
+ }
150
128
null
151
129
}
152
- val apolloException = if (closeReason != null ) {
153
- ApolloWebSocketClosedException (
154
- code = closeReason.code.toInt(),
155
- reason = closeReason.message,
156
- cause = e
157
- )
158
- } else {
159
- ApolloNetworkException (
160
- message = " Web socket communication error" ,
161
- platformCause = e
162
- )
130
+ }
131
+
132
+ private suspend fun sendFrames (session : DefaultClientWebSocketSession ) {
133
+ while (true ) {
134
+ val frame = sendFrameChannel.receive()
135
+ session.send(frame)
136
+ if (frame is Frame .Close ) {
137
+ // normal termination
138
+ receiveMessageChannel.close()
139
+ }
140
+ }
141
+ }
142
+
143
+ private suspend fun receiveFrames (incoming : ReceiveChannel <Frame >) {
144
+ while (true ) {
145
+ val frame = incoming.receive()
146
+ when (frame) {
147
+ is Frame .Text -> {
148
+ receiveMessageChannel.trySend(frame.readText())
149
+ }
150
+
151
+ is Frame .Binary -> {
152
+ receiveMessageChannel.trySend(frame.data.decodeToString())
153
+ }
154
+
155
+ else -> error(" unknown frame type" )
156
+ }
163
157
}
164
- receiveMessageChannel.close(apolloException)
165
- sendFrameChannel.close(apolloException)
166
158
}
167
159
}
0 commit comments