Skip to content

Commit 4a62ed2

Browse files
committed
grpc: Refactor server scope API
Signed-off-by: Johannes Zottele <[email protected]>
1 parent 68f5a19 commit 4a62ed2

File tree

3 files changed

+113
-32
lines changed

3 files changed

+113
-32
lines changed

grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/ServerInterceptor.kt

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,28 @@
55
package kotlinx.rpc.grpc
66

77
import kotlinx.coroutines.flow.Flow
8-
import kotlinx.rpc.grpc.internal.GrpcCallOptions
8+
import kotlinx.coroutines.flow.FlowCollector
99
import kotlinx.rpc.grpc.internal.MethodDescriptor
1010

1111
public interface ServerCallScope<Request, Response> {
1212
public val method: MethodDescriptor<Request, Response>
13+
public val requestHeaders: GrpcTrailers
1314
public val responseHeaders: GrpcTrailers
14-
public fun onCancel(block: () -> Unit)
15-
public fun onComplete(block: () -> Unit)
15+
public val responseTrailers: GrpcTrailers
16+
17+
public fun onClose(block: (Status, GrpcTrailers) -> Unit)
18+
public fun close(status: Status, trailers: GrpcTrailers = GrpcTrailers()): Nothing
1619
public fun proceed(request: Flow<Request>): Flow<Response>
20+
21+
public suspend fun FlowCollector<Response>.proceedFlow(request: Flow<Request>) {
22+
proceed(request).collect {
23+
emit(it)
24+
}
25+
}
1726
}
1827

1928
public interface ServerInterceptor {
20-
public fun <Request, Response> intercept(
21-
scope: ServerCallScope<Request, Response>,
22-
requestHeaders: GrpcTrailers,
29+
public fun <Request, Response> ServerCallScope<Request, Response>.intercept(
2330
request: Flow<Request>,
2431
): Flow<Response>
2532
}

grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/suspendServerCalls.kt

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -159,11 +159,12 @@ private fun <Request, Response> CoroutineScope.serverCallListenerImpl(
159159

160160
val serverCallScope = ServerCallScopeImpl(
161161
method = descriptor,
162-
responseHeaders = GrpcTrailers(),
163162
interceptors = interceptors,
164163
implementation = implementation,
165164
requestHeaders = requestHeaders,
165+
serverCall = handler,
166166
)
167+
167168
val rpcJob = launch(GrpcContextElement.current()) {
168169
val mutex = Mutex()
169170
val headersSent = AtomicBoolean(false) // enforces only sending headers once
@@ -219,13 +220,13 @@ private fun <Request, Response> CoroutineScope.serverCallListenerImpl(
219220

220221
mutex.withLock {
221222
handler.close(closeStatus, trailers)
223+
serverCallScope.onCloseFuture.complete(Pair(closeStatus, trailers))
222224
}
223225
}
224226

225227
return serverCallListener(
226228
state = ServerCallListenerState(),
227229
onCancel = {
228-
serverCallScope.onCancelFuture.complete(Unit)
229230
rpcJob.cancel("Cancellation received from client")
230231
},
231232
onMessage = { state, message: Request ->
@@ -251,9 +252,7 @@ private fun <Request, Response> CoroutineScope.serverCallListenerImpl(
251252
onReady = {
252253
ready.onReady()
253254
},
254-
onComplete = {
255-
serverCallScope.onCompleteFuture.complete(Unit)
256-
}
255+
onComplete = { }
257256
)
258257
}
259258

@@ -270,29 +269,33 @@ private val unitKType = typeOf<Unit>()
270269

271270
private class ServerCallScopeImpl<Request, Response>(
272271
override val method: MethodDescriptor<Request, Response>,
273-
override val responseHeaders: GrpcTrailers,
274272
val interceptors: List<ServerInterceptor>,
275273
val implementation: (Flow<Request>) -> Flow<Response>,
276-
val requestHeaders: GrpcTrailers,
274+
override val requestHeaders: GrpcTrailers,
275+
val serverCall: ServerCall<Request, Response>,
277276
) : ServerCallScope<Request, Response> {
278277

279-
val onCancelFuture = CallbackFuture<Unit>()
280-
val onCompleteFuture = CallbackFuture<Unit>()
281-
val onCloseFuture = CallbackFuture<Pair<Status, GrpcTrailers>>()
278+
override val responseHeaders: GrpcTrailers = GrpcTrailers()
279+
override val responseTrailers: GrpcTrailers = GrpcTrailers()
280+
281+
// keeps track of already processed interceptors
282282
var interceptorIndex = 0
283+
val onCloseFuture = CallbackFuture<Pair<Status, GrpcTrailers>>()
283284

284-
override fun onCancel(block: () -> Unit) {
285-
onCancelFuture.onComplete { block() }
285+
override fun onClose(block: (Status, GrpcTrailers) -> Unit) {
286+
onCloseFuture.onComplete { block(it.first, it.second) }
286287
}
287288

288-
override fun onComplete(block: () -> Unit) {
289-
onCompleteFuture.onComplete { block() }
289+
override fun close(status: Status, trailers: GrpcTrailers): Nothing {
290+
// this will be cached by the rpcImpl() runCatching{} and turns it into a close()
291+
throw StatusException(status, trailers)
290292
}
291293

292294
override fun proceed(request: Flow<Request>): Flow<Response> {
293295
return if (interceptorIndex < interceptors.size) {
294-
interceptors[interceptorIndex++]
295-
.intercept(this, requestHeaders, request)
296+
with(interceptors[interceptorIndex++]) {
297+
intercept(request)
298+
}
296299
} else {
297300
implementation(request)
298301
}

grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/proto/ServerInterceptorTest.kt

Lines changed: 81 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,27 @@
55
package kotlinx.rpc.grpc.test.proto
66

77
import kotlinx.coroutines.flow.Flow
8+
import kotlinx.coroutines.flow.map
89
import kotlinx.rpc.RpcServer
910
import kotlinx.rpc.grpc.GrpcClient
1011
import kotlinx.rpc.grpc.GrpcTrailers
1112
import kotlinx.rpc.grpc.ServerCallScope
1213
import kotlinx.rpc.grpc.ServerInterceptor
14+
import kotlinx.rpc.grpc.Status
15+
import kotlinx.rpc.grpc.StatusCode
16+
import kotlinx.rpc.grpc.StatusException
17+
import kotlinx.rpc.grpc.statusCode
1318
import kotlinx.rpc.grpc.test.EchoRequest
1419
import kotlinx.rpc.grpc.test.EchoService
1520
import kotlinx.rpc.grpc.test.EchoServiceImpl
1621
import kotlinx.rpc.grpc.test.invoke
1722
import kotlinx.rpc.registerService
1823
import kotlinx.rpc.withService
1924
import kotlin.test.Test
25+
import kotlin.test.assertContains
2026
import kotlin.test.assertEquals
2127
import kotlin.test.assertFailsWith
28+
import kotlin.test.assertIs
2229

2330
class ServerInterceptorTest : GrpcProtoTest() {
2431

@@ -27,18 +34,82 @@ class ServerInterceptorTest : GrpcProtoTest() {
2734
}
2835

2936
@Test
30-
fun `throw during intercept - should fail with thrown exception`() {
31-
val error = assertFailsWith<IllegalStateException> {
32-
val interceptor = interceptor { scope, headers, request ->
33-
scope.proceed(request)
37+
fun `throw during intercept - should fail with unknown status on client`() {
38+
var cause: Throwable? = null
39+
val error = assertFailsWith<StatusException> {
40+
val interceptor = interceptor {
41+
onClose { status, _ -> cause = status.getCause() }
42+
// this exception is not propagated to the client (only as UNKNOWN status code)
43+
throw IllegalStateException("Failing in interceptor")
3444
}
3545
runGrpcTest(serverInterceptors = interceptor, test = ::unaryCall)
3646
}
3747

38-
assertEquals(error.message, "Failing in interceptor")
48+
assertEquals(StatusCode.UNKNOWN, error.getStatus().statusCode)
49+
assertIs<IllegalStateException>(cause)
50+
assertEquals("Failing in interceptor", cause?.message)
3951
}
4052

4153

54+
@Test
55+
fun `close during intercept - should fail with correct status on client`() {
56+
val error = assertFailsWith<StatusException> {
57+
val interceptor = interceptor {
58+
close(Status(StatusCode.UNAUTHENTICATED, "Close in interceptor"), GrpcTrailers())
59+
}
60+
runGrpcTest(serverInterceptors = interceptor, test = ::unaryCall)
61+
}
62+
63+
assertEquals(StatusCode.UNAUTHENTICATED, error.getStatus().statusCode)
64+
assertContains(error.message!!, "Close in interceptor")
65+
}
66+
67+
@Test
68+
fun `close during request flow - should fail with correct status on client`() {
69+
val error = assertFailsWith<StatusException> {
70+
val interceptor = interceptor {
71+
proceed(
72+
it.map {
73+
close(Status(StatusCode.UNAUTHENTICATED, "Close in request flow"), GrpcTrailers())
74+
}
75+
)
76+
}
77+
runGrpcTest(serverInterceptors = interceptor, test = ::unaryCall)
78+
}
79+
80+
assertEquals(StatusCode.UNAUTHENTICATED, error.getStatus().statusCode)
81+
assertContains(error.message!!, "Close in request flow")
82+
}
83+
84+
@Test
85+
fun `close during response flow - should fail with correct status on client`() {
86+
val error = assertFailsWith<StatusException> {
87+
val interceptor = interceptor {
88+
proceed(it).map {
89+
close(Status(StatusCode.UNAUTHENTICATED, "Close in response flow"), GrpcTrailers())
90+
}
91+
}
92+
runGrpcTest(serverInterceptors = interceptor, test = ::unaryCall)
93+
}
94+
95+
assertEquals(StatusCode.UNAUTHENTICATED, error.getStatus().statusCode)
96+
assertContains(error.message!!, "Close in response flow")
97+
}
98+
99+
@Test
100+
fun `close during onClose - should fail with correct status on client`() {
101+
val error = assertFailsWith<StatusException> {
102+
val interceptor = interceptor {
103+
onClose { _, _ -> close(Status(StatusCode.UNAUTHENTICATED, "Close in onClose"), GrpcTrailers()) }
104+
proceed(it)
105+
}
106+
runGrpcTest(serverInterceptors = interceptor, test = ::unaryCall)
107+
}
108+
109+
assertEquals(StatusCode.UNAUTHENTICATED, error.getStatus().statusCode)
110+
assertContains(error.message!!, "Close in onClose")
111+
}
112+
42113
private suspend fun unaryCall(grpcClient: GrpcClient) {
43114
val service = grpcClient.withService<EchoService>()
44115
val response = service.UnaryEcho(EchoRequest { message = "Hello" })
@@ -48,16 +119,16 @@ class ServerInterceptorTest : GrpcProtoTest() {
48119

49120

50121
private fun interceptor(
51-
block: (ServerCallScope<Any, Any>, GrpcTrailers, Flow<Any>) -> Flow<Any>,
122+
block: ServerCallScope<Any, Any>.(Flow<Any>) -> Flow<Any>,
52123
): List<ServerInterceptor> {
53124
return listOf(object : ServerInterceptor {
54125
@Suppress("UNCHECKED_CAST")
55-
override fun <Req, Resp> intercept(
56-
scope: ServerCallScope<Req, Resp>,
57-
requestHeaders: GrpcTrailers,
126+
override fun <Req, Resp> ServerCallScope<Req, Resp>.intercept(
58127
request: Flow<Req>,
59128
): Flow<Resp> {
60-
return block(scope as ServerCallScope<Any, Any>, requestHeaders, request as Flow<Any>) as Flow<Resp>
129+
with(this as ServerCallScope<Any, Any>) {
130+
return block(request as Flow<Any>) as Flow<Resp>
131+
}
61132
}
62133
})
63134
}

0 commit comments

Comments
 (0)