Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public interface ManagedChannel {
*
* @return whether the channel is terminated, as would be done by [isTerminated].
*/
public suspend fun awaitTermination(duration: Duration): Boolean
public suspend fun awaitTermination(duration: Duration = Duration.INFINITE): Boolean

/**
* Initiates an orderly shutdown in which preexisting calls continue but new calls are immediately canceled.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ private fun <Request, Response> rpcImpl(
* there is room in the buffer.
*/
val responses = Channel<Response>(1)
val ready = Ready()
val ready = Ready { handler.isReady() }

handler.start(channelResponseListener(responses, ready), headers)

Expand Down Expand Up @@ -200,7 +200,7 @@ private fun <Request, Response> rpcImpl(
private fun <Response> channelResponseListener(
responses: Channel<Response>,
ready: Ready,
) = clientCallListener<Response>(
) = clientCallListener(
onHeaders = {
// todo check what happens here
},
Expand All @@ -226,7 +226,7 @@ private fun <Response> channelResponseListener(
// todo really needed?
internal fun <T> Flow<T>.singleOrStatusFlow(
expected: String,
descriptor: Any
descriptor: Any,
): Flow<T> = flow {
var found = false
collect {
Expand All @@ -252,16 +252,22 @@ internal suspend fun <T> Flow<T>.singleOrStatus(
descriptor: Any
): T = singleOrStatusFlow(expected, descriptor).single()

internal class Ready {
internal class Ready(private val isReallyReady: () -> Boolean) {
// A CONFLATED channel never suspends to send, and two notifications of readiness are equivalent
// to one
private val channel = Channel<Unit>(Channel.CONFLATED)

fun onReady() {
channel.trySend(Unit)
channel.trySend(Unit).onFailure { e ->
throw e ?: AssertionError(
"Should be impossible; a CONFLATED channel should never return false on offer"
)
}
}

suspend fun suspendUntilReady() {
channel.receive()
while (!isReallyReady()) {
channel.receive()
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,10 @@ public fun <Request, Response> CoroutineScope.serverStreamingServerMethodDefinit
flow {
requests
.singleOrStatusFlow("request", descriptor)
.collect { req ->
implementation(req).collect { resp -> emit(resp) }
.collect { request ->
implementation(request).collect { response ->
emit(response)
}
}
}
}
Expand Down Expand Up @@ -108,7 +110,7 @@ private fun <Request, Response> CoroutineScope.serverCallListenerImpl(
handler: ServerCall<Request, Response>,
implementation: (Flow<Request>) -> Flow<Response>,
): ServerCall.Listener<Request> {
val readiness = Ready()
val ready = Ready { handler.isReady()}
val requestsChannel = Channel<Request>(1)

val requestsStarted = AtomicBoolean(false) // enforces read-once
Expand All @@ -118,8 +120,8 @@ private fun <Request, Response> CoroutineScope.serverCallListenerImpl(
"requests flow can only be collected once"
}

handler.request(1)
try {
handler.request(1)
for (request in requestsChannel) {
emit(request)
handler.request(1)
Expand All @@ -144,8 +146,10 @@ private fun <Request, Response> CoroutineScope.serverCallListenerImpl(
handler.sendHeaders(GrpcTrailers())
}
}
readiness.suspendUntilReady()
mutex.withLock { handler.sendMessage(it) }
ready.suspendUntilReady()
mutex.withLock {
handler.sendMessage(it)
}
}
}.exceptionOrNull()
// check headers again once we're done collecting the response flow - if we received
Expand Down Expand Up @@ -180,7 +184,9 @@ private fun <Request, Response> CoroutineScope.serverCallListenerImpl(
}
} ?: GrpcTrailers()

mutex.withLock { handler.close(closeStatus, trailers) }
mutex.withLock {
handler.close(closeStatus, trailers)
}
}

return serverCallListener(
Expand Down Expand Up @@ -209,7 +215,7 @@ private fun <Request, Response> CoroutineScope.serverCallListenerImpl(
requestsChannel.close()
},
onReady = {
readiness.onReady()
ready.onReady()
},
onComplete = {}
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
package kotlinx.rpc.grpc

import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Job
import kotlinx.coroutines.cancelAndJoin
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.toList
Expand Down Expand Up @@ -98,6 +100,9 @@ class RawClientServerTest {
methodDefinition: CoroutineScope.(MethodDescriptor<String, String>) -> ServerMethodDefinition<String, String>,
block: suspend (GrpcChannel, MethodDescriptor<String, String>) -> Unit,
) = kotlinx.coroutines.test.runTest {
val serverJob = Job()
val serverScope = CoroutineScope(serverJob)

val clientChannel = ManagedChannelBuilder("localhost", PORT).apply {
usePlaintext()
}.buildChannel()
Expand All @@ -122,13 +127,19 @@ class RawClientServerTest {
methods = methods,
schemaDescriptor = Unit,
),
methods = methods.map { methodDefinition(it) },
methods = methods.map { serverScope.methodDefinition(it) },
)
)
val server = Server(builder)
server.start()

block(clientChannel.platformApi, descriptor)

serverJob.cancelAndJoin()
clientChannel.shutdown()
clientChannel.awaitTermination()
server.shutdown()
server.awaitTermination()
}

companion object {
Expand Down
Loading