@@ -8,6 +8,7 @@ package kotlinx.rpc.krpc.ktor
88
99import io.ktor.client.*
1010import io.ktor.client.engine.cio.*
11+ import io.ktor.client.plugins.HttpRequestRetry
1112import io.ktor.client.request.*
1213import io.ktor.client.statement.*
1314import io.ktor.server.application.*
@@ -17,9 +18,7 @@ import io.ktor.server.response.*
1718import io.ktor.server.routing.*
1819import io.ktor.server.testing.*
1920import kotlinx.coroutines.*
20- import kotlinx.coroutines.debug.DebugProbes
2121import kotlinx.rpc.annotations.Rpc
22- import kotlinx.rpc.krpc.client.KrpcClient
2322import kotlinx.rpc.krpc.internal.logging.RpcInternalCommonLogger
2423import kotlinx.rpc.krpc.internal.logging.RpcInternalDumpLoggerContainer
2524import kotlinx.rpc.krpc.internal.logging.dumpLogger
@@ -32,12 +31,12 @@ import kotlinx.rpc.krpc.serialization.json.json
3231import kotlinx.rpc.test.runTestWithCoroutinesProbes
3332import kotlinx.rpc.withService
3433import org.junit.Assert.assertEquals
34+ import org.junit.Assert.assertTrue
3535import java.net.ServerSocket
3636import java.util.concurrent.Executors
37- import java.util.concurrent.TimeUnit
3837import kotlin.coroutines.cancellation.CancellationException
39- import kotlin.test.Ignore
4038import kotlin.test.Test
39+ import kotlin.test.fail
4140import kotlin.time.Duration.Companion.seconds
4241
4342@Rpc
@@ -62,13 +61,14 @@ interface SlowService {
6261
6362class SlowServiceImpl : SlowService {
6463 val received = CompletableDeferred <Unit >()
64+ val fence = CompletableDeferred <Unit >()
6565
6666 override suspend fun verySlow (): String {
6767 received.complete(Unit )
6868
69- delay( Int . MAX_VALUE .toLong() )
69+ fence.await( )
7070
71- error( " Must not be called " )
71+ return " hello "
7272 }
7373}
7474
@@ -134,10 +134,7 @@ class KtorTransportTest {
134134
135135 @OptIn(DelicateCoroutinesApi ::class , ExperimentalCoroutinesApi ::class )
136136 @Test
137- @Ignore(" Wait for Ktor fix (https://github.com/ktorio/ktor/pull/4927) or apply workaround if rejected" )
138- fun testEndpointsTerminateWhenWsDoes () = runTestWithCoroutinesProbes(timeout = 15 .seconds) {
139- DebugProbes .install()
140-
137+ fun testClientTerminatesWhenServerWsDoes () = runTestWithCoroutinesProbes(timeout = 60 .seconds) {
141138 val logger = setupLogger()
142139
143140 val port: Int = findFreePort()
@@ -147,7 +144,7 @@ class KtorTransportTest {
147144 val serverReady = CompletableDeferred <Unit >()
148145 val dropServer = CompletableDeferred <Unit >()
149146
150- val service = SlowServiceImpl ()
147+ val impl = SlowServiceImpl ()
151148
152149 @Suppress(" detekt.GlobalCoroutineUsage" )
153150 val serverJob = GlobalScope .launch(CoroutineName (" server" )) {
@@ -171,22 +168,27 @@ class KtorTransportTest {
171168 }
172169 }
173170
174- registerService<SlowService > { service }
171+ registerService<SlowService > { impl }
175172 }
176173 }
177- }.start (wait = false )
174+ }.startSuspend (wait = false )
178175
179176 serverReady.complete(Unit )
180177
181178 dropServer.await()
182179
183- server.stop(shutdownGracePeriod = 100L , shutdownTimeout = 100L , timeUnit = TimeUnit . MILLISECONDS )
180+ server.stopSuspend(gracePeriodMillis = 100L , timeoutMillis = 300L )
184181 }
185182
186183 logger.info { " Server stopped" }
187184 }
188185
189186 val ktorClient = HttpClient (CIO ) {
187+ install(HttpRequestRetry ) {
188+ retryOnServerErrors(maxRetries = 5 )
189+ exponentialDelay()
190+ }
191+
190192 installKrpc {
191193 serialization {
192194 json()
@@ -200,32 +202,151 @@ class KtorTransportTest {
200202
201203 val rpcClient = ktorClient.rpc(" ws://0.0.0.0:$port /rpc" )
202204
203- launch {
205+ var cancellationExceptionCaught = false
206+ val job = launch {
204207 try {
205208 rpcClient.withService<SlowService >().verySlow()
206- error (" Must not be called" )
209+ fail (" Must not be called" )
207210 } catch (_: CancellationException ) {
208- logger.info { " Cancellation exception caught for RPC request " }
211+ cancellationExceptionCaught = true
209212 ensureActive()
210213 }
211214 }
212215
213- service .received.await()
216+ impl .received.await()
214217
215218 logger.info { " Received RPC request" }
216219
217220 dropServer.complete(Unit )
218221
219222 logger.info { " Waiting for RPC client to complete" }
220223
221- (rpcClient as KrpcClient ).awaitCompletion()
224+ rpcClient.awaitCompletion()
225+
226+ job.join()
227+
228+ assertTrue(cancellationExceptionCaught)
222229
223230 logger.info { " RPC client completed" }
224231
225232 ktorClient.close()
233+
234+ serverJob.join()
226235 newPool.close()
236+ }
237+
238+ @OptIn(DelicateCoroutinesApi ::class , ExperimentalCoroutinesApi ::class )
239+ @Test
240+ fun testServerTerminatesWhenClientWsDoes () = runTestWithCoroutinesProbes(timeout = 60 .seconds) {
241+ val logger = setupLogger()
242+
243+ val port: Int = findFreePort()
244+
245+ val newPool = Executors .newCachedThreadPool().asCoroutineDispatcher()
246+
247+ val serverReady = CompletableDeferred <Unit >()
248+ val dropServer = CompletableDeferred <Unit >()
227249
228- serverJob.cancel()
250+ val impl = SlowServiceImpl ()
251+ val sessionFinished = CompletableDeferred <Unit >()
252+
253+ @Suppress(" detekt.GlobalCoroutineUsage" )
254+ val serverJob = GlobalScope .launch(CoroutineName (" server" )) {
255+ withContext(newPool) {
256+ val server = embeddedServer(
257+ factory = Netty ,
258+ port = port,
259+ parentCoroutineContext = newPool,
260+ ) {
261+ install(Krpc )
262+
263+ routing {
264+ get {
265+ call.respondText(" hello" )
266+ }
267+
268+ rpc(" /rpc" ) {
269+ coroutineContext.job.invokeOnCompletion {
270+ sessionFinished.complete(Unit )
271+ }
272+
273+ rpcConfig {
274+ serialization {
275+ json()
276+ }
277+ }
278+
279+ registerService<SlowService > { impl }
280+ }
281+ }
282+ }.startSuspend(wait = false )
283+
284+ serverReady.complete(Unit )
285+
286+ dropServer.await()
287+
288+ server.stopSuspend(gracePeriodMillis = 100L , timeoutMillis = 300L )
289+ }
290+
291+ logger.info { " Server stopped" }
292+ }
293+
294+ val ktorClient = HttpClient (CIO ) {
295+ install(HttpRequestRetry ) {
296+ retryOnServerErrors(maxRetries = 5 )
297+ exponentialDelay()
298+ }
299+
300+ installKrpc {
301+ serialization {
302+ json()
303+ }
304+ }
305+ }
306+
307+ serverReady.await()
308+
309+ assertEquals(" hello" , ktorClient.get(" http://0.0.0.0:$port " ).bodyAsText())
310+
311+ val rpcClient = ktorClient.rpc(" ws://0.0.0.0:$port /rpc" )
312+
313+ var cancellationExceptionCaught = false
314+ val job = launch {
315+ try {
316+ rpcClient.withService<SlowService >().verySlow()
317+ fail(" Must not be called" )
318+ } catch (_: CancellationException ) {
319+ cancellationExceptionCaught = true
320+ ensureActive()
321+ }
322+ }
323+
324+ impl.received.await()
325+
326+ logger.info { " Received RPC request, Dropping client" }
327+
328+ ktorClient.cancel()
329+
330+ logger.info { " Waiting for RPC client to complete" }
331+
332+ rpcClient.awaitCompletion()
333+
334+ logger.info { " Waiting for request to complete" }
335+
336+ job.join()
337+
338+ assertTrue(cancellationExceptionCaught)
339+
340+ logger.info { " RPC client and request completed" }
341+
342+ sessionFinished.await()
343+
344+ logger.info { " Session finished" }
345+
346+ dropServer.complete(Unit )
347+ serverJob.join()
348+
349+ newPool.close()
229350 }
230351
231352 private fun findFreePort (): Int {
0 commit comments