Skip to content

Commit b27253a

Browse files
committed
grpc: Add client interceptor support
Signed-off-by: Johannes Zottele <[email protected]>
1 parent 2319098 commit b27253a

File tree

4 files changed

+52
-22
lines changed

4 files changed

+52
-22
lines changed

grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/suspendClientCalls.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ private class ClientCallScopeImpl<Request, Response>(
307307
} catch (exception: Throwable) {
308308
cause = exception
309309
if (exception !is StatusException) {
310-
val status = Status(StatusCode.INTERNAL, "Interceptor threw an error", exception)
310+
val status = Status(StatusCode.CANCELLED, "Interceptor threw an error", exception)
311311
cause = StatusException(status)
312312
}
313313
}

grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/proto/ClientInterceptorTest.kt

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ class ClientInterceptorTest: GrpcProtoTest() {
6767
runGrpcTest(clientInterceptors = interceptor, test = ::unaryCall)
6868
}
6969

70+
assertEquals(StatusCode.CANCELLED, error.getStatus().statusCode)
7071
assertIs<IllegalStateException>(error.cause)
7172
assertEquals("Failing in onHeader", error.cause?.message)
7273
}
@@ -83,6 +84,7 @@ class ClientInterceptorTest: GrpcProtoTest() {
8384
runGrpcTest(clientInterceptors = interceptor, test = ::unaryCall)
8485
}
8586

87+
assertEquals(StatusCode.CANCELLED, error.getStatus().statusCode)
8688
assertIs<IllegalStateException>(error.cause)
8789
assertEquals("Failing in onClose", error.cause?.message)
8890
}
@@ -129,13 +131,13 @@ class ClientInterceptorTest: GrpcProtoTest() {
129131
}
130132

131133
@Test
132-
fun `append a response message once closed`() = repeat(1000) {
134+
fun `append a response message once closed`() {
133135
val interceptor = interceptor { scope, req -> channelFlow {
134136
scope.proceed(req).collect {
135137
trySend(it)
136138
}
137-
scope.onClose { _, _ ->
138-
trySend(EchoResponse { message = "Appended-after-close" })
139+
scope.onClose { status, _ ->
140+
trySend(EchoResponse { message = "Appended-after-close-with-${status.statusCode}" })
139141
}
140142
} }
141143

@@ -148,10 +150,8 @@ class ClientInterceptorTest: GrpcProtoTest() {
148150
emit(EchoRequest { message = "Eccchhooo" })
149151
}
150152
}).toList()
151-
152-
println("Respone messages: ${responses.map { it.message }}")
153153
assertEquals(6, responses.size)
154-
assertTrue(responses.any { it.message == "Appended-after-close" })
154+
assertTrue(responses.any { it.message == "Appended-after-close-with-OK" })
155155
}
156156
}
157157

grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/Status.native.kt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ public actual class Status internal constructor(
1212
public actual fun getDescription(): String? = description
1313

1414
public actual fun getCause(): Throwable? = cause
15+
16+
override fun toString(): String {
17+
return "Status(description=$description, statusCode=$statusCode, cause=$cause)"
18+
}
1519
}
1620

1721
public actual val Status.statusCode: StatusCode

grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/NativeClientCall.kt

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ internal class NativeClientCall<Request, Response>(
8181
private var listener: Listener<Response>? = null
8282
private var halfClosed = false
8383
private var cancelled = false
84-
private var closed = atomic(false)
84+
private val closed = atomic(false)
8585

8686
// tracks how many operations are in flight (not yet completed by the listener).
8787
// if 0 and we got a closeInfo (containing the status), there are no more ongoing operations.
@@ -133,7 +133,9 @@ internal class NativeClientCall<Request, Response>(
133133
val lst = checkNotNull(listener) { internalError("Not yet started") }
134134
// allows the managed channel to join for the call to finish.
135135
callJob.complete()
136-
lst.onClose(info.first, info.second)
136+
safeUserCode("Failed to call onClose.") {
137+
lst.onClose(info.first, info.second)
138+
}
137139
}
138140
}
139141

@@ -142,9 +144,8 @@ internal class NativeClientCall<Request, Response>(
142144
* This is called as soon as the RECV_STATUS_ON_CLIENT batch (started with [startRecvStatus]) finished.
143145
*/
144146
private fun markClosePending(status: Status, trailers: GrpcTrailers) {
145-
if (closeInfo.compareAndSet(null, Pair(status, trailers))) {
146-
tryToCloseCall()
147-
}
147+
closeInfo.compareAndSet(null, Pair(status, trailers))
148+
tryToCloseCall()
148149
}
149150

150151
/**
@@ -153,7 +154,9 @@ internal class NativeClientCall<Request, Response>(
153154
*/
154155
private fun turnReady() {
155156
if (ready.compareAndSet(expect = false, update = true)) {
156-
listener?.onReady()
157+
safeUserCode("Failed to call onClose.") {
158+
listener?.onReady()
159+
}
157160
}
158161
}
159162

@@ -163,7 +166,6 @@ internal class NativeClientCall<Request, Response>(
163166
headers: GrpcTrailers,
164167
) {
165168
check(listener == null) { internalError("Already started") }
166-
check(!cancelled) { internalError("Already cancelled.") }
167169

168170
listener = responseListener
169171

@@ -254,7 +256,8 @@ internal class NativeClientCall<Request, Response>(
254256
is BatchResult.Submitted -> {
255257
callResult.future.onComplete {
256258
val details = statusDetails.toByteArray().toKString()
257-
val status = Status(statusCode.value.toKotlin(), details, null)
259+
val kStatusCode = statusCode.value.toKotlin()
260+
val status = Status(kStatusCode, details, null)
258261
val trailers = GrpcTrailers()
259262

260263
// cleanup
@@ -306,7 +309,9 @@ internal class NativeClientCall<Request, Response>(
306309
grpc_metadata_array_destroy(meta.ptr)
307310
arena.clear()
308311
}) {
309-
// TODO: Send headers to listener
312+
safeUserCode("Failed to call onHeaders.") {
313+
listener?.onHeaders(GrpcTrailers())
314+
}
310315
}
311316
}
312317

@@ -319,7 +324,10 @@ internal class NativeClientCall<Request, Response>(
319324
// limit numMessages to prevent potential stack overflows
320325
check(numMessages <= 16) { internalError("numMessages must be <= 16") }
321326
val listener = checkNotNull(listener) { internalError("Not yet started") }
322-
check(!cancelled) { internalError("Already cancelled") }
327+
if (cancelled) {
328+
// no need to send message if the call got already cancelled.
329+
return
330+
}
323331

324332
var remainingMessages = numMessages
325333

@@ -342,7 +350,9 @@ internal class NativeClientCall<Request, Response>(
342350
val buf = recvPtr.value ?: return@runBatch
343351
val msg = methodDescriptor.getResponseMarshaller()
344352
.parse(buf.toKotlin().asInputStream())
345-
listener.onMessage(msg)
353+
safeUserCode("Failed to call onClose.") {
354+
listener.onMessage(msg)
355+
}
346356
post()
347357
}
348358
}
@@ -353,8 +363,11 @@ internal class NativeClientCall<Request, Response>(
353363

354364
override fun cancel(message: String?, cause: Throwable?) {
355365
cancelled = true
356-
val message = if (cause != null) "$message: ${cause.message}" else message
357-
cancelInternal(grpc_status_code.GRPC_STATUS_CANCELLED, message ?: "Call cancelled")
366+
val status = Status(StatusCode.CANCELLED, message ?: "Call cancelled", cause)
367+
// user side cancellation must always win over any other status (even if the call is already completed).
368+
// this will also preserve the cancellation cause, which cannot be passed to the grpc-core.
369+
closeInfo.value = Pair(status, GrpcTrailers())
370+
cancelInternal(grpc_status_code.GRPC_STATUS_CANCELLED, message ?: "Call cancelled with cause: ${cause?.message}")
358371
}
359372

360373
private fun cancelInternal(statusCode: grpc_status_code, message: String) {
@@ -366,7 +379,7 @@ internal class NativeClientCall<Request, Response>(
366379

367380
override fun halfClose() {
368381
check(!halfClosed) { internalError("Already half closed.") }
369-
check(!cancelled) { internalError("Already cancelled.") }
382+
if (cancelled) return
370383
halfClosed = true
371384

372385
val arena = Arena()
@@ -384,9 +397,10 @@ internal class NativeClientCall<Request, Response>(
384397
override fun sendMessage(message: Request) {
385398
checkNotNull(listener) { internalError("Not yet started") }
386399
check(!halfClosed) { internalError("Already half closed.") }
387-
check(!cancelled) { internalError("Already cancelled.") }
388400
check(isReady()) { internalError("Not yet ready.") }
389401

402+
if (cancelled) return
403+
390404
// set ready false, as only one message can be sent at a time.
391405
ready.value = false
392406

@@ -408,6 +422,18 @@ internal class NativeClientCall<Request, Response>(
408422
turnReady()
409423
}
410424
}
425+
426+
/**
427+
* Safely executes the provided block of user code, catching any thrown exceptions or errors.
428+
* If an exception is caught, it cancels the operation with the specified message and cause.
429+
*/
430+
private fun safeUserCode(cancelMsg: String, block: () -> Unit) {
431+
try {
432+
block()
433+
} catch (e: Throwable) {
434+
cancel(cancelMsg, e)
435+
}
436+
}
411437
}
412438

413439

0 commit comments

Comments
 (0)