Skip to content

Commit d601dda

Browse files
committed
I have no slightest idea why this makes tests work (#427)
1 parent 1fb903e commit d601dda

File tree

4 files changed

+39
-16
lines changed

4 files changed

+39
-16
lines changed

grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/ManagedChannel.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ public interface ManagedChannel {
4040
*
4141
* @return whether the channel is terminated, as would be done by [isTerminated].
4242
*/
43-
public suspend fun awaitTermination(duration: Duration): Boolean
43+
public suspend fun awaitTermination(duration: Duration = Duration.INFINITE): Boolean
4444

4545
/**
4646
* Initiates an orderly shutdown in which preexisting calls continue but new calls are immediately canceled.

grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/suspendClientCalls.kt

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ private fun <Request, Response> rpcImpl(
159159
* there is room in the buffer.
160160
*/
161161
val responses = Channel<Response>(1)
162-
val ready = Ready()
162+
val ready = Ready { handler.isReady() }
163163

164164
handler.start(channelResponseListener(responses, ready), headers)
165165

@@ -200,7 +200,7 @@ private fun <Request, Response> rpcImpl(
200200
private fun <Response> channelResponseListener(
201201
responses: Channel<Response>,
202202
ready: Ready,
203-
) = clientCallListener<Response>(
203+
) = clientCallListener(
204204
onHeaders = {
205205
// todo check what happens here
206206
},
@@ -226,7 +226,7 @@ private fun <Response> channelResponseListener(
226226
// todo really needed?
227227
internal fun <T> Flow<T>.singleOrStatusFlow(
228228
expected: String,
229-
descriptor: Any
229+
descriptor: Any,
230230
): Flow<T> = flow {
231231
var found = false
232232
collect {
@@ -252,16 +252,22 @@ internal suspend fun <T> Flow<T>.singleOrStatus(
252252
descriptor: Any
253253
): T = singleOrStatusFlow(expected, descriptor).single()
254254

255-
internal class Ready {
255+
internal class Ready(private val isReallyReady: () -> Boolean) {
256256
// A CONFLATED channel never suspends to send, and two notifications of readiness are equivalent
257257
// to one
258258
private val channel = Channel<Unit>(Channel.CONFLATED)
259259

260260
fun onReady() {
261-
channel.trySend(Unit)
261+
channel.trySend(Unit).onFailure { e ->
262+
throw e ?: AssertionError(
263+
"Should be impossible; a CONFLATED channel should never return false on offer"
264+
)
265+
}
262266
}
263267

264268
suspend fun suspendUntilReady() {
265-
channel.receive()
269+
while (!isReallyReady()) {
270+
channel.receive()
271+
}
266272
}
267273
}

grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/suspendServerCalls.kt

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,10 @@ public fun <Request, Response> CoroutineScope.serverStreamingServerMethodDefinit
7272
flow {
7373
requests
7474
.singleOrStatusFlow("request", descriptor)
75-
.collect { req ->
76-
implementation(req).collect { resp -> emit(resp) }
75+
.collect { request ->
76+
implementation(request).collect { response ->
77+
emit(response)
78+
}
7779
}
7880
}
7981
}
@@ -108,7 +110,7 @@ private fun <Request, Response> CoroutineScope.serverCallListenerImpl(
108110
handler: ServerCall<Request, Response>,
109111
implementation: (Flow<Request>) -> Flow<Response>,
110112
): ServerCall.Listener<Request> {
111-
val readiness = Ready()
113+
val ready = Ready { handler.isReady()}
112114
val requestsChannel = Channel<Request>(1)
113115

114116
val requestsStarted = AtomicBoolean(false) // enforces read-once
@@ -118,8 +120,8 @@ private fun <Request, Response> CoroutineScope.serverCallListenerImpl(
118120
"requests flow can only be collected once"
119121
}
120122

121-
handler.request(1)
122123
try {
124+
handler.request(1)
123125
for (request in requestsChannel) {
124126
emit(request)
125127
handler.request(1)
@@ -144,8 +146,10 @@ private fun <Request, Response> CoroutineScope.serverCallListenerImpl(
144146
handler.sendHeaders(GrpcTrailers())
145147
}
146148
}
147-
readiness.suspendUntilReady()
148-
mutex.withLock { handler.sendMessage(it) }
149+
ready.suspendUntilReady()
150+
mutex.withLock {
151+
handler.sendMessage(it)
152+
}
149153
}
150154
}.exceptionOrNull()
151155
// check headers again once we're done collecting the response flow - if we received
@@ -180,7 +184,9 @@ private fun <Request, Response> CoroutineScope.serverCallListenerImpl(
180184
}
181185
} ?: GrpcTrailers()
182186

183-
mutex.withLock { handler.close(closeStatus, trailers) }
187+
mutex.withLock {
188+
handler.close(closeStatus, trailers)
189+
}
184190
}
185191

186192
return serverCallListener(
@@ -209,7 +215,7 @@ private fun <Request, Response> CoroutineScope.serverCallListenerImpl(
209215
requestsChannel.close()
210216
},
211217
onReady = {
212-
readiness.onReady()
218+
ready.onReady()
213219
},
214220
onComplete = {}
215221
)

grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/RawClientServerTest.kt

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
package kotlinx.rpc.grpc
66

77
import kotlinx.coroutines.CoroutineScope
8+
import kotlinx.coroutines.Job
9+
import kotlinx.coroutines.cancelAndJoin
810
import kotlinx.coroutines.flow.flowOf
911
import kotlinx.coroutines.flow.map
1012
import kotlinx.coroutines.flow.toList
@@ -98,6 +100,9 @@ class RawClientServerTest {
98100
methodDefinition: CoroutineScope.(MethodDescriptor<String, String>) -> ServerMethodDefinition<String, String>,
99101
block: suspend (GrpcChannel, MethodDescriptor<String, String>) -> Unit,
100102
) = kotlinx.coroutines.test.runTest {
103+
val serverJob = Job()
104+
val serverScope = CoroutineScope(serverJob)
105+
101106
val clientChannel = ManagedChannelBuilder("localhost", PORT).apply {
102107
usePlaintext()
103108
}.buildChannel()
@@ -122,13 +127,19 @@ class RawClientServerTest {
122127
methods = methods,
123128
schemaDescriptor = Unit,
124129
),
125-
methods = methods.map { methodDefinition(it) },
130+
methods = methods.map { serverScope.methodDefinition(it) },
126131
)
127132
)
128133
val server = Server(builder)
129134
server.start()
130135

131136
block(clientChannel.platformApi, descriptor)
137+
138+
serverJob.cancelAndJoin()
139+
clientChannel.shutdown()
140+
clientChannel.awaitTermination()
141+
server.shutdown()
142+
server.awaitTermination()
132143
}
133144

134145
companion object {

0 commit comments

Comments
 (0)