Skip to content

Commit 8523fb3

Browse files
committed
Added BackPressure tests and fixed client side streaming
1 parent 7c54fee commit 8523fb3

File tree

2 files changed

+198
-1
lines changed

2 files changed

+198
-1
lines changed

krpc/krpc-server/src/commonMain/kotlin/kotlinx/rpc/krpc/server/internal/ServerStreamContext.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ internal class ServerStreamContext {
5555
fun prepareClientStream(streamId: String, elementKind: KSerializer<Any?>): Flow<Any?> {
5656
val callId = currentCallId.get() ?: error("No call id")
5757

58-
val channel = Channel<Any?>(Channel.UNLIMITED)
58+
val channel = Channel<Any?>()
5959

6060
@Suppress("UNCHECKED_CAST")
6161
val map = streams.computeIfAbsent(callId) { RpcInternalConcurrentHashMap() }
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
/*
2+
* Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package kotlinx.rpc.krpc.test
6+
7+
import kotlinx.atomicfu.atomic
8+
import kotlinx.coroutines.CompletableDeferred
9+
import kotlinx.coroutines.async
10+
import kotlinx.coroutines.cancelAndJoin
11+
import kotlinx.coroutines.flow.Flow
12+
import kotlinx.coroutines.flow.flow
13+
import kotlinx.coroutines.flow.map
14+
import kotlinx.coroutines.flow.toList
15+
import kotlinx.coroutines.job
16+
import kotlinx.coroutines.test.TestResult
17+
import kotlinx.coroutines.test.TestScope
18+
import kotlinx.coroutines.yield
19+
import kotlinx.rpc.annotations.Rpc
20+
import kotlinx.rpc.krpc.internal.logging.RpcInternalDumpLoggerContainer
21+
import kotlinx.rpc.krpc.rpcClientConfig
22+
import kotlinx.rpc.krpc.rpcServerConfig
23+
import kotlinx.rpc.krpc.serialization.json.json
24+
import kotlinx.rpc.registerService
25+
import kotlinx.rpc.withService
26+
import kotlin.test.Test
27+
import kotlin.test.assertEquals
28+
import kotlin.time.Duration
29+
import kotlin.time.Duration.Companion.seconds
30+
31+
@Rpc
32+
interface BackPressure {
33+
suspend fun plain()
34+
35+
fun serverStream(num: Int): Flow<Int>
36+
37+
suspend fun clientStream(flow: Flow<Int>)
38+
}
39+
40+
class BackPressureImpl : BackPressure {
41+
val plainCounter = atomic(0)
42+
val serverStreamCounter = atomic(0)
43+
val clientStreamCounter = atomic(0)
44+
val entered = CompletableDeferred<Unit>()
45+
val fence = CompletableDeferred<Unit>()
46+
47+
suspend fun awaitCounter(value: Int, counter: BackPressureImpl.() -> Int) {
48+
while (counter() != value) {
49+
yield()
50+
}
51+
}
52+
53+
override suspend fun plain() {
54+
plainCounter.incrementAndGet()
55+
}
56+
57+
override fun serverStream(num: Int): Flow<Int> {
58+
return flow {
59+
repeat(num) {
60+
serverStreamCounter.incrementAndGet()
61+
emit(it)
62+
}
63+
}
64+
}
65+
66+
val consumed = mutableListOf<Int>()
67+
override suspend fun clientStream(flow: Flow<Int>) {
68+
flow.collect {
69+
if (it == 0) {
70+
entered.complete(Unit)
71+
fence.await()
72+
}
73+
consumed.add(it)
74+
}
75+
}
76+
}
77+
78+
class BackPressureTest : BackPressureTestBase() {
79+
@Test
80+
fun buffer_size_1_server() = runServerTest(perCallBufferSize = 1)
81+
82+
@Test
83+
fun buffer_size_10_server() = runServerTest(perCallBufferSize = 10)
84+
85+
@Test
86+
fun buffer_size_1_client() = runClientTest(perCallBufferSize = 1)
87+
88+
@Test
89+
fun buffer_size_10_client() = runClientTest(perCallBufferSize = 10)
90+
}
91+
92+
// `+2` explanation:
93+
// 1. the first element is sent and processed by the client
94+
// 2. the second element is sent and is suspended on the client
95+
// because the processing of the first element is not finished yet
96+
// 3. the third element is n from the flow and suspended on the server side
97+
abstract class BackPressureTestBase {
98+
protected fun runServerTest(
99+
perCallBufferSize: Int,
100+
timeout: Duration = 10.seconds,
101+
) = runTest(perCallBufferSize, timeout) { service, impl ->
102+
val flowList = async {
103+
service.serverStream(1000).map {
104+
if (it == 0) {
105+
impl.entered.complete(Unit)
106+
impl.fence.await()
107+
}
108+
}.toList()
109+
}
110+
111+
impl.entered.await()
112+
impl.awaitCounter(perCallBufferSize + 2) { serverStreamCounter.value }
113+
114+
repeat(1000) {
115+
service.plain()
116+
}
117+
118+
impl.awaitCounter(1000) { plainCounter.value }
119+
120+
assertEquals(perCallBufferSize + 2, impl.serverStreamCounter.value)
121+
impl.fence.complete(Unit)
122+
impl.awaitCounter(1000) { serverStreamCounter.value }
123+
assertEquals(1000, flowList.await().size)
124+
}
125+
126+
protected fun runClientTest(
127+
perCallBufferSize: Int,
128+
timeout: Duration = 10.seconds,
129+
) = runTest(perCallBufferSize, timeout) { service, impl ->
130+
val flowList = async {
131+
service.clientStream(flow {
132+
repeat(1000) {
133+
impl.clientStreamCounter.incrementAndGet()
134+
emit(it)
135+
}
136+
})
137+
}
138+
139+
impl.entered.await()
140+
impl.awaitCounter(perCallBufferSize + 2) { clientStreamCounter.value }
141+
142+
repeat(1000) {
143+
service.plain()
144+
}
145+
146+
impl.awaitCounter(1000) { plainCounter.value }
147+
148+
assertEquals(0, impl.consumed.size)
149+
assertEquals(perCallBufferSize + 2, impl.clientStreamCounter.value)
150+
impl.fence.complete(Unit)
151+
impl.awaitCounter(1000) { clientStreamCounter.value }
152+
flowList.await()
153+
assertEquals(1000, impl.consumed.size)
154+
}
155+
156+
protected fun runTest(
157+
perCallBufferSize: Int,
158+
timeout: Duration = 10.seconds,
159+
body: suspend TestScope.(BackPressure, BackPressureImpl) -> Unit,
160+
): TestResult = kotlinx.coroutines.test.runTest(timeout = timeout) {
161+
val transport = LocalTransport(coroutineContext, recordTimestamps = false)
162+
val clientConfig = rpcClientConfig {
163+
serialization {
164+
json()
165+
}
166+
167+
connector {
168+
this.perCallBufferSize = perCallBufferSize
169+
}
170+
}
171+
val client = KrpcTestClient(clientConfig, transport.client)
172+
val serverConfig = rpcServerConfig {
173+
serialization {
174+
json()
175+
}
176+
177+
connector {
178+
this.perCallBufferSize = perCallBufferSize
179+
}
180+
}
181+
val server = KrpcTestServer(serverConfig, transport.server)
182+
val impl = BackPressureImpl()
183+
server.registerService<BackPressure> { impl }
184+
val service = client.withService<BackPressure>()
185+
186+
try {
187+
body(service, impl)
188+
} finally {
189+
RpcInternalDumpLoggerContainer.set(null)
190+
client.close()
191+
server.close()
192+
client.awaitCompletion()
193+
server.awaitCompletion()
194+
transport.coroutineContext.job.cancelAndJoin()
195+
}
196+
}
197+
}

0 commit comments

Comments
 (0)