Skip to content

Commit 68f5a19

Browse files
committed
grpc: Add server interceptor support
Signed-off-by: Johannes Zottele <[email protected]>
1 parent b27253a commit 68f5a19

File tree

11 files changed

+229
-37
lines changed

11 files changed

+229
-37
lines changed

grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/GrpcServer.kt

Lines changed: 52 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,14 @@ private typealias ResponseServer = Any
4444
*
4545
* @property port Specifies the port used by the server to listen for incoming connections.
4646
* @param parentContext
47-
* @param configure exposes platform-specific Server builder.
47+
* @param serverBuilder exposes platform-specific Server builder.
4848
*/
4949
public class GrpcServer internal constructor(
50-
override val port: Int = 8080,
51-
credentials: ServerCredentials? = null,
50+
override val port: Int,
51+
private val serverBuilder: ServerBuilder<*>,
52+
private val interceptors: List<ServerInterceptor>,
5253
messageCodecResolver: MessageCodecResolver = EmptyMessageCodecResolver,
5354
parentContext: CoroutineContext = EmptyCoroutineContext,
54-
configure: ServerBuilder<*>.() -> Unit,
5555
) : RpcServer, Server {
5656
private val internalContext = SupervisorJob(parentContext[Job])
5757
private val internalScope = CoroutineScope(parentContext + internalContext)
@@ -61,9 +61,8 @@ public class GrpcServer internal constructor(
6161
private var isBuilt = false
6262
private lateinit var internalServer: Server
6363

64-
private val serverBuilder: ServerBuilder<*> = ServerBuilder(port, credentials).apply(configure)
6564
private val registry: MutableHandlerRegistry by lazy {
66-
MutableHandlerRegistry().apply { serverBuilder.fallbackHandlerRegistry(this) }
65+
MutableHandlerRegistry().apply { this@GrpcServer.serverBuilder.fallbackHandlerRegistry(this) }
6766
}
6867

6968
private val localRegistry = RpcInternalConcurrentHashMap<KClass<*>, ServerServiceDefinition>()
@@ -79,7 +78,7 @@ public class GrpcServer internal constructor(
7978
if (isBuilt) {
8079
registry.addService(definition)
8180
} else {
82-
serverBuilder.addService(definition)
81+
this@GrpcServer.serverBuilder.addService(definition)
8382
}
8483
}
8584

@@ -105,7 +104,8 @@ public class GrpcServer internal constructor(
105104
as? MethodDescriptor<RequestServer, ResponseServer>
106105
?: error("Expected a gRPC method descriptor")
107106

108-
it.toDefinitionOn(methodDescriptor, service)
107+
// TODO: support per service and per method interceptors (KRPC-222)
108+
it.toDefinitionOn(methodDescriptor, service, interceptors)
109109
}
110110

111111
return serverServiceDefinition(delegate.serviceDescriptor, methods)
@@ -114,29 +114,40 @@ public class GrpcServer internal constructor(
114114
private fun <@Grpc Service : Any> RpcCallable<Service>.toDefinitionOn(
115115
descriptor: MethodDescriptor<RequestServer, ResponseServer>,
116116
service: Service,
117+
interceptors: List<ServerInterceptor>,
117118
): ServerMethodDefinition<RequestServer, ResponseServer> {
118119
return when (descriptor.type) {
119120
MethodType.UNARY -> {
120-
internalScope.unaryServerMethodDefinition(descriptor, returnType.kType) { request ->
121+
internalScope.unaryServerMethodDefinition(descriptor, returnType.kType, interceptors) { request ->
121122
unaryInvokator.call(service, arrayOf(request)) as ResponseServer
122123
}
123124
}
124125

125126
MethodType.CLIENT_STREAMING -> {
126-
internalScope.clientStreamingServerMethodDefinition(descriptor, returnType.kType) { requests ->
127+
internalScope.clientStreamingServerMethodDefinition(
128+
descriptor,
129+
returnType.kType,
130+
interceptors
131+
) { requests ->
127132
unaryInvokator.call(service, arrayOf(requests)) as ResponseServer
128133
}
129134
}
130135

131136
MethodType.SERVER_STREAMING -> {
132-
internalScope.serverStreamingServerMethodDefinition(descriptor, returnType.kType) { request ->
137+
internalScope.serverStreamingServerMethodDefinition(
138+
descriptor, returnType.kType, interceptors
139+
) { request ->
133140
@Suppress("UNCHECKED_CAST")
134141
flowInvokator.call(service, arrayOf(request)) as Flow<ResponseServer>
135142
}
136143
}
137144

138145
MethodType.BIDI_STREAMING -> {
139-
internalScope.bidiStreamingServerMethodDefinition(descriptor, returnType.kType) { requests ->
146+
internalScope.bidiStreamingServerMethodDefinition(
147+
descriptor,
148+
returnType.kType,
149+
interceptors
150+
) { requests ->
140151
@Suppress("UNCHECKED_CAST")
141152
flowInvokator.call(service, arrayOf(requests)) as Flow<ResponseServer>
142153
}
@@ -152,7 +163,7 @@ public class GrpcServer internal constructor(
152163

153164
internal fun build() {
154165
if (buildLock.compareAndSet(expect = false, update = true)) {
155-
internalServer = Server(serverBuilder)
166+
internalServer = Server(this@GrpcServer.serverBuilder)
156167
isBuilt = true
157168
}
158169
}
@@ -192,12 +203,36 @@ public class GrpcServer internal constructor(
192203
*/
193204
public fun GrpcServer(
194205
port: Int,
195-
credentials: ServerCredentials? = null,
196-
messageCodecResolver: MessageCodecResolver = EmptyMessageCodecResolver,
197206
parentContext: CoroutineContext = EmptyCoroutineContext,
198-
configure: ServerBuilder<*>.() -> Unit = {},
207+
configure: GrpcServerConfiguration.() -> Unit = {},
199208
builder: RpcServer.() -> Unit = {},
200209
): GrpcServer {
201-
return GrpcServer(port, credentials, messageCodecResolver, parentContext, configure).apply(builder)
210+
val config = GrpcServerConfiguration().apply(configure)
211+
val serverBuilder = ServerBuilder(port, config.credentials).apply {
212+
config.fallbackHandlerRegistry?.let { fallbackHandlerRegistry(it) }
213+
}
214+
return GrpcServer(port, serverBuilder, config.interceptors, config.messageCodecResolver, parentContext)
215+
.apply(builder)
202216
.apply { build() }
203217
}
218+
219+
public class GrpcServerConfiguration internal constructor() {
220+
internal var messageCodecResolver: MessageCodecResolver = EmptyMessageCodecResolver
221+
internal var credentials: ServerCredentials? = null
222+
internal val interceptors: MutableList<ServerInterceptor> = mutableListOf()
223+
internal var fallbackHandlerRegistry: HandlerRegistry? = null
224+
internal var services: ServerBuilder<*>? = null
225+
226+
public fun useCredentials(credentials: ServerCredentials) {
227+
this.credentials = credentials
228+
}
229+
230+
public fun useMessageCodecResolver(messageCodecResolver: MessageCodecResolver) {
231+
this.messageCodecResolver = messageCodecResolver
232+
}
233+
234+
public fun intercept(vararg interceptors: ServerInterceptor) {
235+
this.interceptors.addAll(interceptors)
236+
}
237+
238+
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
/*
2+
* Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package kotlinx.rpc.grpc
6+
7+
import kotlinx.coroutines.flow.Flow
8+
import kotlinx.rpc.grpc.internal.GrpcCallOptions
9+
import kotlinx.rpc.grpc.internal.MethodDescriptor
10+
11+
public interface ServerCallScope<Request, Response> {
12+
public val method: MethodDescriptor<Request, Response>
13+
public val responseHeaders: GrpcTrailers
14+
public fun onCancel(block: () -> Unit)
15+
public fun onComplete(block: () -> Unit)
16+
public fun proceed(request: Flow<Request>): Flow<Response>
17+
}
18+
19+
public interface ServerInterceptor {
20+
public fun <Request, Response> intercept(
21+
scope: ServerCallScope<Request, Response>,
22+
requestHeaders: GrpcTrailers,
23+
request: Flow<Request>,
24+
): Flow<Response>
25+
}

grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/credentials.kt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@ public expect abstract class ServerCredentials
1010
public expect class InsecureClientCredentials : ClientCredentials
1111
public expect class InsecureServerCredentials : ServerCredentials
1212

13+
// we need a wrapper for InsecureChannelCredentials as our constructor would conflict with the private
14+
// java constructor.
1315
internal expect fun createInsecureClientCredentials(): ClientCredentials
16+
internal expect fun createInsecureServerCredentials(): ServerCredentials
1417

1518
public expect class TlsClientCredentials : ClientCredentials
1619
public expect class TlsServerCredentials : ServerCredentials

grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/suspendServerCalls.kt

Lines changed: 66 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ import kotlinx.coroutines.launch
1717
import kotlinx.coroutines.sync.Mutex
1818
import kotlinx.coroutines.sync.withLock
1919
import kotlinx.rpc.grpc.GrpcTrailers
20+
import kotlinx.rpc.grpc.ServerCallScope
21+
import kotlinx.rpc.grpc.ServerInterceptor
2022
import kotlinx.rpc.grpc.Status
2123
import kotlinx.rpc.grpc.StatusCode
2224
import kotlinx.rpc.grpc.StatusException
@@ -29,14 +31,15 @@ import kotlin.reflect.typeOf
2931
public fun <Request, Response> CoroutineScope.unaryServerMethodDefinition(
3032
descriptor: MethodDescriptor<Request, Response>,
3133
responseKType: KType,
34+
interceptors: List<ServerInterceptor>,
3235
implementation: suspend (request: Request) -> Response,
3336
): ServerMethodDefinition<Request, Response> {
3437
val type = descriptor.type
3538
require(type == MethodType.UNARY) {
3639
"Expected a unary method descriptor but got $descriptor"
3740
}
3841

39-
return serverMethodDefinition(descriptor, responseKType) { requests ->
42+
return serverMethodDefinition(descriptor, responseKType, interceptors) { requests ->
4043
requests
4144
.singleOrStatusFlow("request", descriptor)
4245
.map { implementation(it) }
@@ -47,14 +50,15 @@ public fun <Request, Response> CoroutineScope.unaryServerMethodDefinition(
4750
public fun <Request, Response> CoroutineScope.clientStreamingServerMethodDefinition(
4851
descriptor: MethodDescriptor<Request, Response>,
4952
responseKType: KType,
53+
interceptors: List<ServerInterceptor>,
5054
implementation: suspend (requests: Flow<Request>) -> Response,
5155
): ServerMethodDefinition<Request, Response> {
5256
val type = descriptor.type
5357
require(type == MethodType.CLIENT_STREAMING) {
5458
"Expected a client streaming method descriptor but got $descriptor"
5559
}
5660

57-
return serverMethodDefinition(descriptor, responseKType) { requests ->
61+
return serverMethodDefinition(descriptor, responseKType, interceptors) { requests ->
5862
flow {
5963
val response = implementation(requests)
6064
emit(response)
@@ -66,14 +70,15 @@ public fun <Request, Response> CoroutineScope.clientStreamingServerMethodDefinit
6670
public fun <Request, Response> CoroutineScope.serverStreamingServerMethodDefinition(
6771
descriptor: MethodDescriptor<Request, Response>,
6872
responseKType: KType,
73+
interceptors: List<ServerInterceptor>,
6974
implementation: (request: Request) -> Flow<Response>,
7075
): ServerMethodDefinition<Request, Response> {
7176
val type = descriptor.type
7277
require(type == MethodType.SERVER_STREAMING) {
7378
"Expected a server streaming method descriptor but got $descriptor"
7479
}
7580

76-
return serverMethodDefinition(descriptor, responseKType) { requests ->
81+
return serverMethodDefinition(descriptor, responseKType, interceptors) { requests ->
7782
flow {
7883
requests
7984
.singleOrStatusFlow("request", descriptor)
@@ -90,36 +95,44 @@ public fun <Request, Response> CoroutineScope.serverStreamingServerMethodDefinit
9095
public fun <Request, Response> CoroutineScope.bidiStreamingServerMethodDefinition(
9196
descriptor: MethodDescriptor<Request, Response>,
9297
responseKType: KType,
98+
interceptors: List<ServerInterceptor>,
9399
implementation: (requests: Flow<Request>) -> Flow<Response>,
94100
): ServerMethodDefinition<Request, Response> {
95101
val type = descriptor.type
96102
check(type == MethodType.BIDI_STREAMING) {
97103
"Expected a bidi streaming method descriptor but got $descriptor"
98104
}
99105

100-
return serverMethodDefinition(descriptor, responseKType, implementation)
106+
return serverMethodDefinition(descriptor, responseKType, interceptors, implementation)
101107
}
102108

103109
private fun <Request, Response> CoroutineScope.serverMethodDefinition(
104110
descriptor: MethodDescriptor<Request, Response>,
105111
responseKType: KType,
112+
interceptors: List<ServerInterceptor>,
106113
implementation: (Flow<Request>) -> Flow<Response>,
107-
): ServerMethodDefinition<Request, Response> = serverMethodDefinition(descriptor, serverCallHandler(responseKType, implementation))
114+
): ServerMethodDefinition<Request, Response> =
115+
serverMethodDefinition(descriptor, serverCallHandler(descriptor, responseKType, interceptors, implementation))
108116

109117
private fun <Request, Response> CoroutineScope.serverCallHandler(
118+
descriptor: MethodDescriptor<Request, Response>,
110119
responseKType: KType,
120+
interceptors: List<ServerInterceptor>,
111121
implementation: (Flow<Request>) -> Flow<Response>,
112122
): ServerCallHandler<Request, Response> =
113-
ServerCallHandler { call, _ ->
114-
serverCallListenerImpl(call, responseKType, implementation)
123+
ServerCallHandler { call, headers ->
124+
serverCallListenerImpl(descriptor, call, responseKType, interceptors, implementation, headers)
115125
}
116126

117127
private fun <Request, Response> CoroutineScope.serverCallListenerImpl(
128+
descriptor: MethodDescriptor<Request, Response>,
118129
handler: ServerCall<Request, Response>,
119130
responseKType: KType,
131+
interceptors: List<ServerInterceptor>,
120132
implementation: (Flow<Request>) -> Flow<Response>,
133+
requestHeaders: GrpcTrailers,
121134
): ServerCall.Listener<Request> {
122-
val ready = Ready { handler.isReady()}
135+
val ready = Ready { handler.isReady() }
123136
val requestsChannel = Channel<Request>(1)
124137

125138
val requestsStarted = AtomicBoolean(false) // enforces read-once
@@ -144,11 +157,18 @@ private fun <Request, Response> CoroutineScope.serverCallListenerImpl(
144157
}
145158
}
146159

160+
val serverCallScope = ServerCallScopeImpl(
161+
method = descriptor,
162+
responseHeaders = GrpcTrailers(),
163+
interceptors = interceptors,
164+
implementation = implementation,
165+
requestHeaders = requestHeaders,
166+
)
147167
val rpcJob = launch(GrpcContextElement.current()) {
148168
val mutex = Mutex()
149169
val headersSent = AtomicBoolean(false) // enforces only sending headers once
150170
val failure = runCatching {
151-
implementation(requests).collect { response ->
171+
serverCallScope.proceed(requests).collect { response ->
152172
@Suppress("UNCHECKED_CAST")
153173
// fix for KRPC-173
154174
val value = if (responseKType == unitKType) Unit as Response else response
@@ -205,6 +225,7 @@ private fun <Request, Response> CoroutineScope.serverCallListenerImpl(
205225
return serverCallListener(
206226
state = ServerCallListenerState(),
207227
onCancel = {
228+
serverCallScope.onCancelFuture.complete(Unit)
208229
rpcJob.cancel("Cancellation received from client")
209230
},
210231
onMessage = { state, message: Request ->
@@ -230,7 +251,9 @@ private fun <Request, Response> CoroutineScope.serverCallListenerImpl(
230251
onReady = {
231252
ready.onReady()
232253
},
233-
onComplete = {}
254+
onComplete = {
255+
serverCallScope.onCompleteFuture.complete(Unit)
256+
}
234257
)
235258
}
236259

@@ -242,4 +265,36 @@ private class ServerCallListenerState {
242265
var isReceiving = true
243266
}
244267

245-
private val unitKType = typeOf<Unit>()
268+
private val unitKType = typeOf<Unit>()
269+
270+
271+
private class ServerCallScopeImpl<Request, Response>(
272+
override val method: MethodDescriptor<Request, Response>,
273+
override val responseHeaders: GrpcTrailers,
274+
val interceptors: List<ServerInterceptor>,
275+
val implementation: (Flow<Request>) -> Flow<Response>,
276+
val requestHeaders: GrpcTrailers,
277+
) : ServerCallScope<Request, Response> {
278+
279+
val onCancelFuture = CallbackFuture<Unit>()
280+
val onCompleteFuture = CallbackFuture<Unit>()
281+
val onCloseFuture = CallbackFuture<Pair<Status, GrpcTrailers>>()
282+
var interceptorIndex = 0
283+
284+
override fun onCancel(block: () -> Unit) {
285+
onCancelFuture.onComplete { block() }
286+
}
287+
288+
override fun onComplete(block: () -> Unit) {
289+
onCompleteFuture.onComplete { block() }
290+
}
291+
292+
override fun proceed(request: Flow<Request>): Flow<Response> {
293+
return if (interceptorIndex < interceptors.size) {
294+
interceptors[interceptorIndex++]
295+
.intercept(this, requestHeaders, request)
296+
} else {
297+
implementation(request)
298+
}
299+
}
300+
}

grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/BaseGrpcServiceTest.kt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,10 @@ abstract class BaseGrpcServiceTest {
3030
) = runTest {
3131
val server = GrpcServer(
3232
port = PORT,
33-
messageCodecResolver = resolver,
3433
parentContext = coroutineContext,
34+
configure = {
35+
useMessageCodecResolver(resolver)
36+
},
3537
builder = {
3638
registerService(kClass) { impl }
3739
}

0 commit comments

Comments
 (0)