@@ -17,6 +17,8 @@ import kotlinx.coroutines.launch
17
17
import kotlinx.coroutines.sync.Mutex
18
18
import kotlinx.coroutines.sync.withLock
19
19
import kotlinx.rpc.grpc.GrpcTrailers
20
+ import kotlinx.rpc.grpc.ServerCallScope
21
+ import kotlinx.rpc.grpc.ServerInterceptor
20
22
import kotlinx.rpc.grpc.Status
21
23
import kotlinx.rpc.grpc.StatusCode
22
24
import kotlinx.rpc.grpc.StatusException
@@ -29,14 +31,15 @@ import kotlin.reflect.typeOf
29
31
public fun <Request , Response > CoroutineScope.unaryServerMethodDefinition (
30
32
descriptor : MethodDescriptor <Request , Response >,
31
33
responseKType : KType ,
34
+ interceptors : List <ServerInterceptor >,
32
35
implementation : suspend (request: Request ) -> Response ,
33
36
): ServerMethodDefinition <Request , Response > {
34
37
val type = descriptor.type
35
38
require(type == MethodType .UNARY ) {
36
39
" Expected a unary method descriptor but got $descriptor "
37
40
}
38
41
39
- return serverMethodDefinition(descriptor, responseKType) { requests ->
42
+ return serverMethodDefinition(descriptor, responseKType, interceptors ) { requests ->
40
43
requests
41
44
.singleOrStatusFlow(" request" , descriptor)
42
45
.map { implementation(it) }
@@ -47,14 +50,15 @@ public fun <Request, Response> CoroutineScope.unaryServerMethodDefinition(
47
50
public fun <Request , Response > CoroutineScope.clientStreamingServerMethodDefinition (
48
51
descriptor : MethodDescriptor <Request , Response >,
49
52
responseKType : KType ,
53
+ interceptors : List <ServerInterceptor >,
50
54
implementation : suspend (requests: Flow <Request >) -> Response ,
51
55
): ServerMethodDefinition <Request , Response > {
52
56
val type = descriptor.type
53
57
require(type == MethodType .CLIENT_STREAMING ) {
54
58
" Expected a client streaming method descriptor but got $descriptor "
55
59
}
56
60
57
- return serverMethodDefinition(descriptor, responseKType) { requests ->
61
+ return serverMethodDefinition(descriptor, responseKType, interceptors ) { requests ->
58
62
flow {
59
63
val response = implementation(requests)
60
64
emit(response)
@@ -66,14 +70,15 @@ public fun <Request, Response> CoroutineScope.clientStreamingServerMethodDefinit
66
70
public fun <Request , Response > CoroutineScope.serverStreamingServerMethodDefinition (
67
71
descriptor : MethodDescriptor <Request , Response >,
68
72
responseKType : KType ,
73
+ interceptors : List <ServerInterceptor >,
69
74
implementation : (request: Request ) -> Flow <Response >,
70
75
): ServerMethodDefinition <Request , Response > {
71
76
val type = descriptor.type
72
77
require(type == MethodType .SERVER_STREAMING ) {
73
78
" Expected a server streaming method descriptor but got $descriptor "
74
79
}
75
80
76
- return serverMethodDefinition(descriptor, responseKType) { requests ->
81
+ return serverMethodDefinition(descriptor, responseKType, interceptors ) { requests ->
77
82
flow {
78
83
requests
79
84
.singleOrStatusFlow(" request" , descriptor)
@@ -90,36 +95,44 @@ public fun <Request, Response> CoroutineScope.serverStreamingServerMethodDefinit
90
95
public fun <Request , Response > CoroutineScope.bidiStreamingServerMethodDefinition (
91
96
descriptor : MethodDescriptor <Request , Response >,
92
97
responseKType : KType ,
98
+ interceptors : List <ServerInterceptor >,
93
99
implementation : (requests: Flow <Request >) -> Flow <Response >,
94
100
): ServerMethodDefinition <Request , Response > {
95
101
val type = descriptor.type
96
102
check(type == MethodType .BIDI_STREAMING ) {
97
103
" Expected a bidi streaming method descriptor but got $descriptor "
98
104
}
99
105
100
- return serverMethodDefinition(descriptor, responseKType, implementation)
106
+ return serverMethodDefinition(descriptor, responseKType, interceptors, implementation)
101
107
}
102
108
103
109
private fun <Request , Response > CoroutineScope.serverMethodDefinition (
104
110
descriptor : MethodDescriptor <Request , Response >,
105
111
responseKType : KType ,
112
+ interceptors : List <ServerInterceptor >,
106
113
implementation : (Flow <Request >) -> Flow <Response >,
107
- ): ServerMethodDefinition <Request , Response > = serverMethodDefinition(descriptor, serverCallHandler(responseKType, implementation))
114
+ ): ServerMethodDefinition <Request , Response > =
115
+ serverMethodDefinition(descriptor, serverCallHandler(descriptor, responseKType, interceptors, implementation))
108
116
109
117
private fun <Request , Response > CoroutineScope.serverCallHandler (
118
+ descriptor : MethodDescriptor <Request , Response >,
110
119
responseKType : KType ,
120
+ interceptors : List <ServerInterceptor >,
111
121
implementation : (Flow <Request >) -> Flow <Response >,
112
122
): ServerCallHandler <Request , Response > =
113
- ServerCallHandler { call, _ ->
114
- serverCallListenerImpl(call, responseKType, implementation)
123
+ ServerCallHandler { call, headers ->
124
+ serverCallListenerImpl(descriptor, call, responseKType, interceptors, implementation, headers )
115
125
}
116
126
117
127
private fun <Request , Response > CoroutineScope.serverCallListenerImpl (
128
+ descriptor : MethodDescriptor <Request , Response >,
118
129
handler : ServerCall <Request , Response >,
119
130
responseKType : KType ,
131
+ interceptors : List <ServerInterceptor >,
120
132
implementation : (Flow <Request >) -> Flow <Response >,
133
+ requestHeaders : GrpcTrailers ,
121
134
): ServerCall .Listener <Request > {
122
- val ready = Ready { handler.isReady()}
135
+ val ready = Ready { handler.isReady() }
123
136
val requestsChannel = Channel <Request >(1 )
124
137
125
138
val requestsStarted = AtomicBoolean (false ) // enforces read-once
@@ -144,11 +157,18 @@ private fun <Request, Response> CoroutineScope.serverCallListenerImpl(
144
157
}
145
158
}
146
159
160
+ val serverCallScope = ServerCallScopeImpl (
161
+ method = descriptor,
162
+ responseHeaders = GrpcTrailers (),
163
+ interceptors = interceptors,
164
+ implementation = implementation,
165
+ requestHeaders = requestHeaders,
166
+ )
147
167
val rpcJob = launch(GrpcContextElement .current()) {
148
168
val mutex = Mutex ()
149
169
val headersSent = AtomicBoolean (false ) // enforces only sending headers once
150
170
val failure = runCatching {
151
- implementation (requests).collect { response ->
171
+ serverCallScope.proceed (requests).collect { response ->
152
172
@Suppress(" UNCHECKED_CAST" )
153
173
// fix for KRPC-173
154
174
val value = if (responseKType == unitKType) Unit as Response else response
@@ -205,6 +225,7 @@ private fun <Request, Response> CoroutineScope.serverCallListenerImpl(
205
225
return serverCallListener(
206
226
state = ServerCallListenerState (),
207
227
onCancel = {
228
+ serverCallScope.onCancelFuture.complete(Unit )
208
229
rpcJob.cancel(" Cancellation received from client" )
209
230
},
210
231
onMessage = { state, message: Request ->
@@ -230,7 +251,9 @@ private fun <Request, Response> CoroutineScope.serverCallListenerImpl(
230
251
onReady = {
231
252
ready.onReady()
232
253
},
233
- onComplete = {}
254
+ onComplete = {
255
+ serverCallScope.onCompleteFuture.complete(Unit )
256
+ }
234
257
)
235
258
}
236
259
@@ -242,4 +265,36 @@ private class ServerCallListenerState {
242
265
var isReceiving = true
243
266
}
244
267
245
- private val unitKType = typeOf<Unit >()
268
+ private val unitKType = typeOf<Unit >()
269
+
270
+
271
+ private class ServerCallScopeImpl <Request , Response >(
272
+ override val method : MethodDescriptor <Request , Response >,
273
+ override val responseHeaders : GrpcTrailers ,
274
+ val interceptors : List <ServerInterceptor >,
275
+ val implementation : (Flow <Request >) -> Flow <Response >,
276
+ val requestHeaders : GrpcTrailers ,
277
+ ) : ServerCallScope<Request, Response> {
278
+
279
+ val onCancelFuture = CallbackFuture <Unit >()
280
+ val onCompleteFuture = CallbackFuture <Unit >()
281
+ val onCloseFuture = CallbackFuture <Pair <Status , GrpcTrailers >>()
282
+ var interceptorIndex = 0
283
+
284
+ override fun onCancel (block : () -> Unit ) {
285
+ onCancelFuture.onComplete { block() }
286
+ }
287
+
288
+ override fun onComplete (block : () -> Unit ) {
289
+ onCompleteFuture.onComplete { block() }
290
+ }
291
+
292
+ override fun proceed (request : Flow <Request >): Flow <Response > {
293
+ return if (interceptorIndex < interceptors.size) {
294
+ interceptors[interceptorIndex++ ]
295
+ .intercept(this , requestHeaders, request)
296
+ } else {
297
+ implementation(request)
298
+ }
299
+ }
300
+ }
0 commit comments