diff --git a/cinterop-c/MODULE.bazel.lock b/cinterop-c/MODULE.bazel.lock index 4da8305f4..43b3be81c 100644 --- a/cinterop-c/MODULE.bazel.lock +++ b/cinterop-c/MODULE.bazel.lock @@ -245,8 +245,8 @@ "https://bcr.bazel.build/modules/rules_java/7.3.2/MODULE.bazel": "50dece891cfdf1741ea230d001aa9c14398062f2b7c066470accace78e412bc2", "https://bcr.bazel.build/modules/rules_java/7.4.0/MODULE.bazel": "a592852f8a3dd539e82ee6542013bf2cadfc4c6946be8941e189d224500a8934", "https://bcr.bazel.build/modules/rules_java/7.6.1/MODULE.bazel": "2f14b7e8a1aa2f67ae92bc69d1ec0fa8d9f827c4e17ff5e5f02e91caa3b2d0fe", - "https://bcr.bazel.build/modules/rules_java/8.14.0/MODULE.bazel": "717717ed40cc69994596a45aec6ea78135ea434b8402fb91b009b9151dd65615", - "https://bcr.bazel.build/modules/rules_java/8.14.0/source.json": "8a88c4ca9e8759da53cddc88123880565c520503321e2566b4e33d0287a3d4bc", + "https://bcr.bazel.build/modules/rules_java/8.12.0/MODULE.bazel": "8e6590b961f2defdfc2811c089c75716cb2f06c8a4edeb9a8d85eaa64ee2a761", + "https://bcr.bazel.build/modules/rules_java/8.12.0/source.json": "cbd5d55d9d38d4008a7d00bee5b5a5a4b6031fcd4a56515c9accbcd42c7be2ba", "https://bcr.bazel.build/modules/rules_java/8.3.2/MODULE.bazel": "7336d5511ad5af0b8615fdc7477535a2e4e723a357b6713af439fe8cf0195017", "https://bcr.bazel.build/modules/rules_java/8.5.1/MODULE.bazel": "d8a9e38cc5228881f7055a6079f6f7821a073df3744d441978e7a43e20226939", "https://bcr.bazel.build/modules/rules_java/8.6.1/MODULE.bazel": "f4808e2ab5b0197f094cabce9f4b006a27766beb6a9975931da07099560ca9c2", @@ -534,7 +534,7 @@ }, "@@rules_foreign_cc+//foreign_cc:extensions.bzl%tools": { "general": { - "bzlTransitiveDigest": "jO6HNyY7/eIylNs2RYABjCfbAgUNb1oiXpl3aY4V/hI=", + "bzlTransitiveDigest": "ginC3lIGOKKivBi0nyv2igKvSiz42Thm8yaX2RwVaHg=", "usagesDigest": "9LXdVp01HkdYQT8gYPjYLO6VLVJHo9uFfxWaU1ymiRE=", "recordedFileInputs": {}, "recordedDirentsInputs": {}, @@ -848,7 +848,7 @@ }, "@@rules_kotlin+//src/main/starlark/core/repositories:bzlmod_setup.bzl%rules_kotlin_extensions": { "general": { - "bzlTransitiveDigest": "OlvsB0HsvxbR8ZN+J9Vf00X/+WVz/Y/5Xrq2LgcVfdo=", + "bzlTransitiveDigest": "hUTp2w+RUVdL7ma5esCXZJAFnX7vLbVfLd7FwnQI6bU=", "usagesDigest": "QI2z8ZUR+mqtbwsf2fLqYdJAkPOHdOV+tF2yVAUgRzw=", "recordedFileInputs": {}, "recordedDirentsInputs": {}, diff --git a/compiler-plugin/compiler-plugin-backend/src/main/kotlin/kotlinx/rpc/codegen/extension/RpcDeclarationScanner.kt b/compiler-plugin/compiler-plugin-backend/src/main/kotlin/kotlinx/rpc/codegen/extension/RpcDeclarationScanner.kt index 80cdf8620..259b1b48d 100644 --- a/compiler-plugin/compiler-plugin-backend/src/main/kotlin/kotlinx/rpc/codegen/extension/RpcDeclarationScanner.kt +++ b/compiler-plugin/compiler-plugin-backend/src/main/kotlin/kotlinx/rpc/codegen/extension/RpcDeclarationScanner.kt @@ -18,6 +18,7 @@ import org.jetbrains.kotlin.ir.expressions.IrExpression import org.jetbrains.kotlin.ir.util.dumpKotlinLike import org.jetbrains.kotlin.ir.util.getAnnotation import org.jetbrains.kotlin.ir.util.hasDefaultValue +import org.jetbrains.kotlin.ir.util.packageFqName /** * This class scans user declared RPC service @@ -31,7 +32,9 @@ internal object RpcDeclarationScanner { var stubClass: IrClass? = null val grpcAnnotation = service.getAnnotation(RpcClassId.grpcAnnotation.asSingleFqName()) - val protoPackage = grpcAnnotation?.arguments?.getOrNull(0)?.asConstString() ?: "" + // if the protoPackage is not set by the annotation, we use the service kotlin package name + val protoPackage = grpcAnnotation?.arguments?.getOrNull(0)?.asConstString() + ?: service.packageFqName?.asString() ?: "" val declarations = service.declarations.memoryOptimizedMap { declaration -> when (declaration) { diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/ClientInterceptor.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/ClientInterceptor.kt new file mode 100644 index 000000000..c81d17f90 --- /dev/null +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/ClientInterceptor.kt @@ -0,0 +1,131 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc + +import kotlinx.coroutines.flow.Flow +import kotlinx.rpc.grpc.internal.GrpcCallOptions +import kotlinx.rpc.grpc.internal.MethodDescriptor + +/** + * The scope of a single outgoing gRPC client call observed by a [ClientInterceptor]. + * + * An interceptor receives this scope instance for every call and can: + * - Inspect the RPC [method] being invoked. + * - Read or populate [requestHeaders] before the request is sent. + * - Read [callOptions] that affect transport-level behavior. + * - Register callbacks with [onHeaders] and [onClose] to observe response metadata and final status. + * - Cancel the call early via [cancel]. + * - Continue the call by calling [proceed] with a (possibly transformed) request [Flow]. + * - Transform the response by modifying the returned [Flow]. + * + * ```kt + * val interceptor = object : ClientInterceptor { + * override fun ClientCallScope.intercept( + * request: Flow + * ): Flow { + * // Example: add a header before proceeding + * requestHeaders[MyKeys.Authorization] = token + * + * // Example: observe response metadata + * onHeaders { headers -> /* inspect headers */ } + * onClose { status, trailers -> /* log status/trailers */ } + * + * // IMPORTANT: proceed forwards the call to the next interceptor/transport. + * // If you do not call proceed, no request will be sent and the call is short-circuited. + * return proceed(request) + * } + * } + * ``` + * + * @param Request the request message type of the RPC. + * @param Response the response message type of the RPC. + */ +public interface ClientCallScope { + /** Descriptor of the RPC method (name, marshalling, type) being invoked. */ + public val method: MethodDescriptor + + /** + * Outgoing request headers for this call. + * + * Interceptors may read and mutate this metadata + * before calling [proceed] so the headers are sent to the server. Headers added after + * the call has already been proceeded may not be reflected on the wire. + */ + public val requestHeaders: GrpcMetadata + + /** + * Transport/engine options used for this call (deadlines, compression, etc.). + * Modifying this object is only possible before the call is proceeded. + */ + public val callOptions: GrpcCallOptions + + /** + * Register a callback invoked when the initial response headers are received. + * Typical gRPC semantics guarantee headers are delivered at most once per call + * and before the first message is received. + */ + public fun onHeaders(block: (responseHeaders: GrpcMetadata) -> Unit) + + /** + * Register a callback invoked when the call completes, successfully or not. + * The final `status` and trailing `responseTrailers` are provided. + */ + public fun onClose(block: (status: Status, responseTrailers: GrpcMetadata) -> Unit) + + /** + * Cancel the call locally, providing a human-readable [message] and an optional [cause]. + * This method won't return and abort all further processing. + * + * We made cancel throw a [StatusException] instead of returning, so control flow is explicit and + * race conditions between interceptors and the transport layer are avoided. + */ + public fun cancel(message: String, cause: Throwable? = null): Nothing + + /** + * Continue the invocation by forwarding it to the next interceptor or to the underlying transport. + * + * This function is the heart of an interceptor: + * - It must be called to actually perform the RPC. If you never call [proceed], the request is not sent + * and the call is effectively short-circuited by the interceptor. + * - You may transform the [request] flow before passing it to [proceed] (e.g., logging, retry orchestration, + * compression, metrics). The returned [Flow] yields response messages and can also be transformed + * before being returned to the caller. + * - Call [proceed] at most once per intercepted call. Calling it multiple times or after cancellation + * is not supported. + */ + public fun proceed(request: Flow): Flow +} + +/** + * Client-side interceptor for gRPC calls. + * + * Implementations can observe and modify client calls in a structured way. The primary entry point is the + * [intercept] extension function on [ClientCallScope], which receives the inbound request [Flow] and must + * call [ClientCallScope.proceed] to forward the call. + * + * Common use-cases include: + * - Adding authentication or custom headers. + * - Implementing logging/metrics. + * - Observing headers/trailers and final status. + * - Transforming request/response flows (e.g., mapping, buffering, throttling). + */ +public interface ClientInterceptor { + /** + * Intercept a client call. + * + * You can: + * - Inspect [ClientCallScope.method] and [ClientCallScope.callOptions]. + * - Read or populate [ClientCallScope.requestHeaders]. + * - Register [ClientCallScope.onHeaders] and [ClientCallScope.onClose] callbacks. + * - Transform the [request] flow or wrap the resulting response flow. + * + * IMPORTANT: [ClientCallScope.proceed] must eventually be called to actually execute the RPC and obtain + * the response [Flow]. If [ClientCallScope.proceed] is omitted, the call will not reach the server. + */ + public fun ClientCallScope.intercept( + request: Flow, + ): Flow + +} \ No newline at end of file diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/GrpcClient.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/GrpcClient.kt index 8204b6f7f..657a98631 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/GrpcClient.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/GrpcClient.kt @@ -32,8 +32,9 @@ private typealias RequestClient = Any * @field channel The [ManagedChannel] used to communicate with remote gRPC services. */ public class GrpcClient internal constructor( - private val channel: ManagedChannel, + internal val channel: ManagedChannel, messageCodecResolver: MessageCodecResolver = EmptyMessageCodecResolver, + internal val interceptors: List, ) : RpcClient { private val delegates = RpcInternalConcurrentHashMap() private val messageCodecResolver = messageCodecResolver + ThrowingMessageCodecResolver @@ -54,11 +55,10 @@ public class GrpcClient internal constructor( override suspend fun call(call: RpcCall): T = withGrpcCall(call) { methodDescriptor, request -> val callOptions = GrpcDefaultCallOptions - val trailers = GrpcTrailers() + val trailers = GrpcMetadata() return when (methodDescriptor.type) { MethodType.UNARY -> unaryRpc( - channel = channel.platformApi, descriptor = methodDescriptor, request = request, callOptions = callOptions, @@ -66,7 +66,6 @@ public class GrpcClient internal constructor( ) MethodType.CLIENT_STREAMING -> @Suppress("UNCHECKED_CAST") clientStreamingRpc( - channel = channel.platformApi, descriptor = methodDescriptor, requests = request as Flow, callOptions = callOptions, @@ -79,11 +78,10 @@ public class GrpcClient internal constructor( override fun callServerStreaming(call: RpcCall): Flow = withGrpcCall(call) { methodDescriptor, request -> val callOptions = GrpcDefaultCallOptions - val trailers = GrpcTrailers() + val trailers = GrpcMetadata() when (methodDescriptor.type) { MethodType.SERVER_STREAMING -> serverStreamingRpc( - channel = channel.platformApi, descriptor = methodDescriptor, request = request, callOptions = callOptions, @@ -91,7 +89,6 @@ public class GrpcClient internal constructor( ) MethodType.BIDI_STREAMING -> @Suppress("UNCHECKED_CAST") bidirectionalStreamingRpc( - channel = channel.platformApi, descriptor = methodDescriptor, requests = request as Flow, callOptions = callOptions, @@ -126,28 +123,164 @@ public class GrpcClient internal constructor( } /** - * Constructor function for the [GrpcClient] class. + * Creates and configures a gRPC client instance. + * + * This function initializes a new gRPC client with the specified target server + * and allows optional customization of the client's configuration through a configuration block. + * + * @param hostname The gRPC server hostname to connect to. + * @param port The gRPC server port to connect to. + * @param configure An optional configuration block to customize the [GrpcClientConfiguration]. + * This can include setting up interceptors, specifying credentials, customizing message codec + * resolution, and overriding default authority. + * + * @return A new instance of [GrpcClient] configured with the specified target and options. + * + * @see [GrpcClientConfiguration] */ public fun GrpcClient( hostname: String, port: Int, - credentials: ClientCredentials? = null, - messageCodecResolver: MessageCodecResolver = EmptyMessageCodecResolver, - configure: ManagedChannelBuilder<*>.() -> Unit = {}, + configure: GrpcClientConfiguration.() -> Unit = {}, ): GrpcClient { - val channel = ManagedChannelBuilder(hostname, port, credentials).apply(configure).buildChannel() - return GrpcClient(channel, messageCodecResolver) + val config = GrpcClientConfiguration().apply(configure) + return GrpcClient(ManagedChannelBuilder(hostname, port, config.credentials), config) } + /** - * Constructor function for the [GrpcClient] class. + * Creates and configures a gRPC client instance. + * + * This function initializes a new gRPC client with the specified target server + * and allows optional customization of the client's configuration through a configuration block. + * + * @param target The gRPC server endpoint to connect to, typically specified in + * the format `hostname:port`. + * @param configure An optional configuration block to customize the [GrpcClientConfiguration]. + * This can include setting up interceptors, specifying credentials, customizing message codec + * resolution, and overriding default authority. + * + * @return A new instance of [GrpcClient] configured with the specified target and options. + * + * @see [GrpcClientConfiguration] */ public fun GrpcClient( target: String, - credentials: ClientCredentials? = null, - messageCodecResolver: MessageCodecResolver = EmptyMessageCodecResolver, - configure: ManagedChannelBuilder<*>.() -> Unit = {}, + configure: GrpcClientConfiguration.() -> Unit = {}, +): GrpcClient { + val config = GrpcClientConfiguration().apply(configure) + return GrpcClient(ManagedChannelBuilder(target, config.credentials), config) +} + +private fun GrpcClient( + builder: ManagedChannelBuilder<*>, + config: GrpcClientConfiguration, ): GrpcClient { - val channel = ManagedChannelBuilder(target, credentials).apply(configure).buildChannel() - return GrpcClient(channel, messageCodecResolver) + val channel = builder.apply { + config.overrideAuthority?.let { overrideAuthority(it) } + }.buildChannel() + return GrpcClient(channel, config.messageCodecResolver, config.interceptors) } + + +/** + * Configuration class for a gRPC client, providing customization options + * for client behavior, including interceptors, credentials, codec resolution, + * and authority overrides. + */ +public class GrpcClientConfiguration internal constructor() { + internal val interceptors: MutableList = mutableListOf() + + /** + * Configurable resolver used to determine the appropriate codec for a given Kotlin type + * during message serialization and deserialization in gRPC calls. + * + * Custom implementations of [MessageCodecResolver] can be provided to handle specific serialization + * for arbitrary types. + * For custom types prefer using the [kotlinx.rpc.grpc.codec.WithCodec] annotation. + * + * @see MessageCodecResolver + * @see kotlinx.rpc.grpc.codec.SourcedMessageCodec + * @see kotlinx.rpc.grpc.codec.WithCodec + */ + public var messageCodecResolver: MessageCodecResolver = EmptyMessageCodecResolver + + + /** + * Configures the client credentials used for secure gRPC requests made by the client. + * + * By default, the client uses default TLS credentials. + * To use custom TLS credentials, use the [tls] constructor function which returns a + * [TlsClientCredentials] instance. + * + * To use plaintext communication, use the [plaintext] constructor function. + * Should only be used for testing or for APIs where the use of such API or + * the data exchanged is not sensitive. + * + * ``` + * GrpcClient("localhost", 50051) { + * credentials = plaintext() // for testing purposes only! + * } + * ``` + */ + public var credentials: ClientCredentials? = null + + /** + * Overrides the authority used with TLS and HTTP virtual hosting. + * It does not change what the host is actually connected to. + * Is commonly in the form `host:port`. + */ + public var overrideAuthority: String? = null + + + /** + * Adds one or more client-side interceptors to the current gRPC client configuration. + * Interceptors enable extended customization of gRPC calls + * by observing or altering the behaviors of requests and responses. + * + * The order of interceptors added via this method is significant. + * Interceptors are executed in the order they are added, + * while one interceptor has to invoke the next interceptor to proceed with the call. + * + * @param interceptors Interceptors to be added to the current configuration. + * Each provided instance of [ClientInterceptor] may perform operations such as modifying headers, + * observing call metadata, logging, or transforming data flows. + * + * @see ClientInterceptor + * @see ClientCallScope + */ + public fun intercept(vararg interceptors: ClientInterceptor) { + this.interceptors.addAll(interceptors) + } + + /** + * Provides insecure client credentials for the gRPC client configuration. + * + * Typically, this would be used for local development, testing, or other + * environments where security is not a concern. + * + * @return An insecure [ClientCredentials] instance that must be passed to [credentials]. + */ + public fun plaintext(): ClientCredentials = createInsecureClientCredentials() + + /** + * Configures and creates secure client credentials for the gRPC client. + * + * This method takes a configuration block in which TLS-related parameters, + * such as trust managers and key managers, can be defined. The resulting + * credentials are used to establish secure communication between the gRPC client + * and server, ensuring encrypted transmission of data and mutual authentication + * if configured. + * + * Alternatively, you can use the [TlsClientCredentials] constructor. + * + * @param configure A configuration block that allows setting up the TLS parameters + * using the [TlsClientCredentialsBuilder]. + * @return A secure [ClientCredentials] instance that must be passed to [credentials]. + * + * @see credentials + */ + public fun tls(configure: TlsClientCredentialsBuilder.() -> Unit): ClientCredentials = + TlsClientCredentials(configure) + +} \ No newline at end of file diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/GrpcTrailers.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/GrpcMetadata.kt similarity index 67% rename from grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/GrpcTrailers.kt rename to grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/GrpcMetadata.kt index e8f2fb903..d5f5749e3 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/GrpcTrailers.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/GrpcMetadata.kt @@ -5,6 +5,6 @@ package kotlinx.rpc.grpc @Suppress("RedundantConstructorKeyword") -public expect class GrpcTrailers constructor() { - public fun merge(trailers: GrpcTrailers) +public expect class GrpcMetadata constructor() { + public fun merge(trailers: GrpcMetadata) } diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/GrpcServer.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/GrpcServer.kt index 1a73e9c74..dc21c8c6d 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/GrpcServer.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/GrpcServer.kt @@ -44,14 +44,14 @@ private typealias ResponseServer = Any * * @property port Specifies the port used by the server to listen for incoming connections. * @param parentContext - * @param configure exposes platform-specific Server builder. + * @param serverBuilder exposes platform-specific Server builder. */ public class GrpcServer internal constructor( - override val port: Int = 8080, - credentials: ServerCredentials? = null, + override val port: Int, + private val serverBuilder: ServerBuilder<*>, + private val interceptors: List, messageCodecResolver: MessageCodecResolver = EmptyMessageCodecResolver, parentContext: CoroutineContext = EmptyCoroutineContext, - configure: ServerBuilder<*>.() -> Unit, ) : RpcServer, Server { private val internalContext = SupervisorJob(parentContext[Job]) private val internalScope = CoroutineScope(parentContext + internalContext) @@ -61,9 +61,8 @@ public class GrpcServer internal constructor( private var isBuilt = false private lateinit var internalServer: Server - private val serverBuilder: ServerBuilder<*> = ServerBuilder(port, credentials).apply(configure) private val registry: MutableHandlerRegistry by lazy { - MutableHandlerRegistry().apply { serverBuilder.fallbackHandlerRegistry(this) } + MutableHandlerRegistry().apply { this@GrpcServer.serverBuilder.fallbackHandlerRegistry(this) } } private val localRegistry = RpcInternalConcurrentHashMap, ServerServiceDefinition>() @@ -79,7 +78,7 @@ public class GrpcServer internal constructor( if (isBuilt) { registry.addService(definition) } else { - serverBuilder.addService(definition) + this@GrpcServer.serverBuilder.addService(definition) } } @@ -105,7 +104,8 @@ public class GrpcServer internal constructor( as? MethodDescriptor ?: error("Expected a gRPC method descriptor") - it.toDefinitionOn(methodDescriptor, service) + // TODO: support per service and per method interceptors (KRPC-222) + it.toDefinitionOn(methodDescriptor, service, interceptors) } return serverServiceDefinition(delegate.serviceDescriptor, methods) @@ -114,29 +114,40 @@ public class GrpcServer internal constructor( private fun <@Grpc Service : Any> RpcCallable.toDefinitionOn( descriptor: MethodDescriptor, service: Service, + interceptors: List, ): ServerMethodDefinition { return when (descriptor.type) { MethodType.UNARY -> { - internalScope.unaryServerMethodDefinition(descriptor, returnType.kType) { request -> + internalScope.unaryServerMethodDefinition(descriptor, returnType.kType, interceptors) { request -> unaryInvokator.call(service, arrayOf(request)) as ResponseServer } } MethodType.CLIENT_STREAMING -> { - internalScope.clientStreamingServerMethodDefinition(descriptor, returnType.kType) { requests -> + internalScope.clientStreamingServerMethodDefinition( + descriptor, + returnType.kType, + interceptors + ) { requests -> unaryInvokator.call(service, arrayOf(requests)) as ResponseServer } } MethodType.SERVER_STREAMING -> { - internalScope.serverStreamingServerMethodDefinition(descriptor, returnType.kType) { request -> + internalScope.serverStreamingServerMethodDefinition( + descriptor, returnType.kType, interceptors + ) { request -> @Suppress("UNCHECKED_CAST") flowInvokator.call(service, arrayOf(request)) as Flow } } MethodType.BIDI_STREAMING -> { - internalScope.bidiStreamingServerMethodDefinition(descriptor, returnType.kType) { requests -> + internalScope.bidiStreamingServerMethodDefinition( + descriptor, + returnType.kType, + interceptors + ) { requests -> @Suppress("UNCHECKED_CAST") flowInvokator.call(service, arrayOf(requests)) as Flow } @@ -152,7 +163,7 @@ public class GrpcServer internal constructor( internal fun build() { if (buildLock.compareAndSet(expect = false, update = true)) { - internalServer = Server(serverBuilder) + internalServer = Server(this@GrpcServer.serverBuilder) isBuilt = true } } @@ -188,16 +199,145 @@ public class GrpcServer internal constructor( } /** - * Constructor function for the [GrpcServer] class. + * Creates and configures a gRPC server instance. + * + * This function initializes a gRPC server with the provided port and a configuration block + * ([GrpcServerConfiguration]). + * + * To start the server, call the [GrpcServer.start] method. + * To clean up resources, call the [GrpcServer.shutdown] or [GrpcServer.shutdownNow] methods. + * + * ```kt + * GrpcServer(port) { + * credentials = tls(myCertChain, myPrivateKey) + * services { + * registerService { MyServiceImpl() } + * registerService { MyOtherServiceImpl() } + * } + * } + * ``` + * + * @param port The port number where the gRPC server will listen for incoming connections. + * This must be a valid and available port on the host system. + * @param parentContext The parent coroutine context used for managing server-related operations. + * Defaults to an empty coroutine context if not specified. + * @param configure A configuration lambda receiver, + * allowing customization of server behavior such as credentials, interceptors, + * codecs, and service registration logic. + * @return A fully configured `GrpcServer` instance, which must be started explicitly to handle requests. */ public fun GrpcServer( port: Int, - credentials: ServerCredentials? = null, - messageCodecResolver: MessageCodecResolver = EmptyMessageCodecResolver, parentContext: CoroutineContext = EmptyCoroutineContext, - configure: ServerBuilder<*>.() -> Unit = {}, - builder: RpcServer.() -> Unit = {}, + configure: GrpcServerConfiguration.() -> Unit = {}, ): GrpcServer { - return GrpcServer(port, credentials, messageCodecResolver, parentContext, configure).apply(builder) + val config = GrpcServerConfiguration().apply(configure) + val serverBuilder = ServerBuilder(port, config.credentials).apply { + config.fallbackHandlerRegistry?.let { fallbackHandlerRegistry(it) } + } + return GrpcServer(port, serverBuilder, config.interceptors, config.messageCodecResolver, parentContext) + .apply(config.serviceBuilder) .apply { build() } } + +/** + * A configuration class for setting up a gRPC server. + * + * This class provides an API to configure various server parameters, such as message codecs, + * security credentials, server-side interceptors, and service registration. + */ +public class GrpcServerConfiguration internal constructor() { + + internal val interceptors: MutableList = mutableListOf() + internal var serviceBuilder: RpcServer.() -> Unit = { } + + + /** + * Sets the credentials to be used by the gRPC server for secure communication. + * + * By default, the server does not have any credentials configured and the communication is plaintext. + * To set up transport-layer security provide a [TlsServerCredentials] by constructing it with the + * [tls] function. + * + * @see TlsServerCredentials + * @see tls + */ + public var credentials: ServerCredentials? = null + + /** + * Sets a custom [MessageCodecResolver] to be used by the gRPC server for resolving the appropriate + * codec for message serialization and deserialization. + * + * When not explicitly set, a default [EmptyMessageCodecResolver] is used, which may not perform + * any specific resolution. + * Provide a custom [MessageCodecResolver] to resolve codecs based on the message's `KType`. + */ + public var messageCodecResolver: MessageCodecResolver = EmptyMessageCodecResolver + + + /** + * Sets a custom [HandlerRegistry] to be used by the gRPC server for resolving service implementations + * that were not registered before via the [services] configuration block. + * + * If not set, unknown services not registered will cause a `UNIMPLEMENTED` status + * to be returned to the client. + */ + public var fallbackHandlerRegistry: HandlerRegistry? = null + + /** + * Registers one or more server-side interceptors for the gRPC server. + * + * Interceptors allow observing and modifying incoming gRPC calls before they reach the service + * implementation logic. + * They are commonly used to implement cross-cutting concerns like + * authentication, logging, metrics, or custom request/response transformations. + * + * @param interceptors One or more instances of [ServerInterceptor] to be applied to incoming calls. + * @see ServerInterceptor + */ + public fun intercept(vararg interceptors: ServerInterceptor) { + this.interceptors.addAll(interceptors) + } + + /** + * Configures the gRPC server to register services. + * + * This method allows defining a block of logic to configure an [RpcServer] instance, + * where multiple services can be registered: + * ```kt + * GrpcServer(port) { + * services { + * registerService { MyServiceImpl() } + * registerService { MyOtherServiceImpl() } + * } + * } + * ``` + * + * @param block A lambda with [RpcServer] as its receiver, allowing service registration. + */ + public fun services(block: RpcServer.() -> Unit) { + serviceBuilder = block + } + + /** + * Configures and creates TLS (Transport Layer Security) credentials for the gRPC server. + * + * This method allows specifying the server's certificate chain, private key, and additional + * configurations needed for setting up a secure communication channel over TLS. + * + * @param certificateChain A string representing the PEM-encoded certificate chain for the server. + * @param privateKey A string representing the PKCS#8 formatted private key corresponding to the certificate. + * @param configure A lambda to further customize the [TlsServerCredentialsBuilder], enabling configurations + * like setting trusted root certificates or enabling client authentication. + * @return An instance of [ServerCredentials] representing the configured TLS credentials that must be passed + * to [credentials]. + * + * @see credentials + */ + public fun tls( + certificateChain: String, + privateKey: String, + configure: TlsServerCredentialsBuilder.() -> Unit, + ): ServerCredentials = + TlsServerCredentials(certificateChain, privateKey, configure) +} \ No newline at end of file diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/ManagedChannel.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/ManagedChannel.kt index bcfa015f5..ea10366d3 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/ManagedChannel.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/ManagedChannel.kt @@ -68,8 +68,6 @@ public interface ManagedChannel { * Builder class for [ManagedChannel]. */ public expect abstract class ManagedChannelBuilder> { - public fun usePlaintext(): T - public abstract fun overrideAuthority(authority: String): T } diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/ServerInterceptor.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/ServerInterceptor.kt new file mode 100644 index 000000000..5f24f09f8 --- /dev/null +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/ServerInterceptor.kt @@ -0,0 +1,143 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc + +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.FlowCollector +import kotlinx.rpc.grpc.internal.GrpcContext +import kotlinx.rpc.grpc.internal.MethodDescriptor + +/** + * Th scope of a single incoming gRPC server call observed by a [ServerInterceptor]. + * + * An interceptor receives this scope instance for every RPC invocation arriving to the server and can: + * - Inspect the target RPC [method]. + * - Read client-provided [requestHeaders]. + * - Populate [responseHeaders] (sent before the first response message) and [responseTrailers] + * (sent when the call completes). + * - Register a completion callback with [onClose]. + * - Abort the call early with [close]. + * - Continue handling by calling [proceed] with the inbound request [Flow] and optionally transform + * the returned response [Flow]. + * + * @param Request the request message type of the RPC. + * @param Response the response message type of the RPC. + */ +public interface ServerCallScope { + /** Descriptor of the RPC method (name, marshalling, type) being executed. */ + public val method: MethodDescriptor + + /** Metadata received from the client with the initial request headers. Read-only from the server perspective. */ + public val requestHeaders: GrpcMetadata + + /** + * Initial response headers to be sent to the client. + * Interceptors and handlers may add entries before the first response element is emitted + * (i.e., before proceeding or before producing output), otherwise headers might have already been sent. + */ + public val responseHeaders: GrpcMetadata + + /** + * Trailing metadata to be sent with the final status when the call completes. + * Interceptors can add diagnostics or custom metadata here. + */ + public val responseTrailers: GrpcMetadata + + /** + * The [GrpcContext] associated with this call. + * + * It can be used by the interceptor to provide call-scoped information about + * the current call, such as the identity of the caller or the current authentication state. + */ + public val context: GrpcContext + + /** + * Register a callback invoked when the call is closed (successfully or exceptionally). + * Provides the final [Status] and the sent [GrpcMetadata] trailers. + */ + public fun onClose(block: (Status, GrpcMetadata) -> Unit) + + /** + * Immediately terminate the call with the given [status] and optional [trailers]. + * + * This method does not return (declared as [Nothing]). After calling it, no further messages will be processed + * or sent. Prefer setting [responseHeaders]/[responseTrailers] before closing if you need to include metadata. + * + * We made close throw a [StatusException] instead of returning, so control flow is explicit and race conditions + * between interceptors and the service implementation are avoided. + */ + public fun close(status: Status, trailers: GrpcMetadata = GrpcMetadata()): Nothing + + /** + * Continue processing by forwarding the request to the next interceptor or the actual service implementation. + * + * IMPORTANT: + * - You must call [proceed] exactly once to actually handle the RPC; otherwise, the call will be short-circuited + * and the service method will not be invoked. + * - You may transform the incoming [request] flow (e.g., validation, logging, metering) before passing it to + * [proceed]. You may also transform the resulting response [Flow] before returning it to the framework. + * - The interceptor must ensure to provide and return a valid number of messages, depending on the method type. + * - The interceptor must not throw an exception. Use [close] to terminate the call with an error. + */ + public fun proceed(request: Flow): Flow + + /** + * Convenience for flow builders: proceeds with [request] and emits the resulting response elements into this + * [FlowCollector]. Useful inside `flow {}` blocks within interceptors. + * + * ``` + * val myAuthInterceptor = object : ServerInterceptor { + * override fun ServerCallScope.intercept(request: Flow): Flow = + * flow { + * val authorized = mySuspendAuth(requestHeaders) + * if (!authorized) { + * close(Status(StatusCode.PERMISSION_DENIED, "Not authorized")) + * } + * + * proceedUnmodified(request) + * } + * } + * ``` + */ + public suspend fun FlowCollector.proceedUnmodified(request: Flow) { + proceed(request).collect { + emit(it) + } + } +} + +/** + * Server-side interceptor for gRPC calls. + * + * Implementations can observe and modify server handling in a structured way. The entry point is the + * [intercept] extension function on [ServerCallScope], which receives the inbound request [Flow] and must + * call [ServerCallScope.proceed] to forward the call to the next interceptor or the target service method. + * + * Common use-cases include: + * - Authentication/authorization checks and context propagation. + * - Setting response headers and trailers. + * - Structured logging and metrics. + * - Transforming request/response flows (e.g., validation, mapping, throttling). + * + * See ServerInterceptorTest for practical usage patterns. + */ +public interface ServerInterceptor { + /** + * Intercept a server call. + * + * You can: + * - Inspect [ServerCallScope.method]. + * - Read [ServerCallScope.requestHeaders] and populate [ServerCallScope.responseHeaders]/[ServerCallScope.responseTrailers]. + * - Register [ServerCallScope.onClose] callbacks. + * - Transform the [request] flow or wrap the resulting response flow. + * - Append information to the [ServerCallScope.context]. + * + * IMPORTANT: You must eventually call [ServerCallScope.proceed] to actually invoke the service logic and produce + * the response [Flow]. If [ServerCallScope.proceed] is omitted, the call will never reach the service. + */ + public fun ServerCallScope.intercept( + request: Flow, + ): Flow +} \ No newline at end of file diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/StatusException.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/StatusException.kt index 0cc1fd20b..eae3df45e 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/StatusException.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/StatusException.kt @@ -9,16 +9,16 @@ package kotlinx.rpc.grpc */ public expect class StatusException : Exception { public constructor(status: Status) - public constructor(status: Status, trailers: GrpcTrailers?) + public constructor(status: Status, trailers: GrpcMetadata?) public fun getStatus(): Status - public fun getTrailers(): GrpcTrailers? + public fun getTrailers(): GrpcMetadata? } public expect class StatusRuntimeException : RuntimeException { public constructor(status: Status) - public constructor(status: Status, trailers: GrpcTrailers?) + public constructor(status: Status, trailers: GrpcMetadata?) public fun getStatus(): Status - public fun getTrailers(): GrpcTrailers? + public fun getTrailers(): GrpcMetadata? } diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/credentials.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/credentials.kt index dcebf9d5f..22717ed3b 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/credentials.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/credentials.kt @@ -10,6 +10,11 @@ public expect abstract class ServerCredentials public expect class InsecureClientCredentials : ClientCredentials public expect class InsecureServerCredentials : ServerCredentials +// we need a wrapper for InsecureChannelCredentials as our constructor would conflict with the private +// java constructor. +internal expect fun createInsecureClientCredentials(): ClientCredentials +internal expect fun createInsecureServerCredentials(): ServerCredentials + public expect class TlsClientCredentials : ClientCredentials public expect class TlsServerCredentials : ServerCredentials diff --git a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/CallbackFuture.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/CallbackFuture.kt similarity index 95% rename from grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/CallbackFuture.kt rename to grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/CallbackFuture.kt index 6dbdb5b74..1cd0b907b 100644 --- a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/CallbackFuture.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/CallbackFuture.kt @@ -46,4 +46,6 @@ internal class CallbackFuture { } } } + + val isCompleted: Boolean get() = state.value is State.Done } \ No newline at end of file diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/ClientCall.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/ClientCall.kt index 2c3146bd2..7eac10ff4 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/ClientCall.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/ClientCall.kt @@ -4,7 +4,7 @@ package kotlinx.rpc.grpc.internal -import kotlinx.rpc.grpc.GrpcTrailers +import kotlinx.rpc.grpc.GrpcMetadata import kotlinx.rpc.grpc.Status import kotlinx.rpc.internal.utils.InternalRpcApi @@ -31,13 +31,13 @@ import kotlinx.rpc.internal.utils.InternalRpcApi public expect abstract class ClientCall { @InternalRpcApi public abstract class Listener { - public open fun onHeaders(headers: GrpcTrailers) + public open fun onHeaders(headers: GrpcMetadata) public open fun onMessage(message: Message) - public open fun onClose(status: Status, trailers: GrpcTrailers) + public open fun onClose(status: Status, trailers: GrpcMetadata) public open fun onReady() } - public abstract fun start(responseListener: Listener, headers: GrpcTrailers) + public abstract fun start(responseListener: Listener, headers: GrpcMetadata) public abstract fun request(numMessages: Int) public abstract fun cancel(message: String?, cause: Throwable?) public abstract fun halfClose() @@ -47,8 +47,8 @@ public expect abstract class ClientCall { @InternalRpcApi public expect fun clientCallListener( - onHeaders: (headers: GrpcTrailers) -> Unit, + onHeaders: (headers: GrpcMetadata) -> Unit, onMessage: (message: Message) -> Unit, - onClose: (status: Status, trailers: GrpcTrailers) -> Unit, + onClose: (status: Status, trailers: GrpcMetadata) -> Unit, onReady: () -> Unit, ): ClientCall.Listener diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/GrpcContext.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/GrpcContext.kt index 2ec2e4487..2e624556d 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/GrpcContext.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/GrpcContext.kt @@ -6,10 +6,13 @@ package kotlinx.rpc.grpc.internal import kotlin.coroutines.CoroutineContext -internal expect class GrpcContext +public expect class GrpcContext + internal expect val CurrentGrpcContext: GrpcContext internal expect class GrpcContextElement : CoroutineContext.Element { + val grpcContext: GrpcContext + companion object Key : CoroutineContext.Key { fun current(): GrpcContextElement } diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/ServerCall.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/ServerCall.kt index 8b390e756..cdd32aacd 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/ServerCall.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/ServerCall.kt @@ -4,13 +4,13 @@ package kotlinx.rpc.grpc.internal -import kotlinx.rpc.grpc.GrpcTrailers +import kotlinx.rpc.grpc.GrpcMetadata import kotlinx.rpc.grpc.Status import kotlinx.rpc.internal.utils.InternalRpcApi @InternalRpcApi public expect fun interface ServerCallHandler { - public fun startCall(call: ServerCall, headers: GrpcTrailers): ServerCall.Listener + public fun startCall(call: ServerCall, headers: GrpcMetadata): ServerCall.Listener } @InternalRpcApi @@ -25,9 +25,9 @@ public expect abstract class ServerCall { } public abstract fun request(numMessages: Int) - public abstract fun sendHeaders(headers: GrpcTrailers) + public abstract fun sendHeaders(headers: GrpcMetadata) public abstract fun sendMessage(message: Response) - public abstract fun close(status: Status, trailers: GrpcTrailers) + public abstract fun close(status: Status, trailers: GrpcMetadata) public open fun isReady(): Boolean public abstract fun isCancelled(): Boolean diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/suspendClientCalls.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/suspendClientCalls.kt index 39635bdc9..0621b9810 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/suspendClientCalls.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/suspendClientCalls.kt @@ -4,25 +4,37 @@ package kotlinx.rpc.grpc.internal -import kotlinx.coroutines.* +import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.CoroutineName +import kotlinx.coroutines.NonCancellable +import kotlinx.coroutines.cancel import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.onFailure +import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.flowOf import kotlinx.coroutines.flow.single -import kotlinx.rpc.grpc.* +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext +import kotlinx.rpc.grpc.ClientCallScope +import kotlinx.rpc.grpc.GrpcClient +import kotlinx.rpc.grpc.GrpcMetadata +import kotlinx.rpc.grpc.Status +import kotlinx.rpc.grpc.StatusCode +import kotlinx.rpc.grpc.StatusException +import kotlinx.rpc.grpc.statusCode import kotlinx.rpc.internal.utils.InternalRpcApi // heavily inspired by // https://github.com/grpc/grpc-kotlin/blob/master/stub/src/main/java/io/grpc/kotlin/ClientCalls.kt @InternalRpcApi -public suspend fun unaryRpc( - channel: GrpcChannel, +public suspend fun GrpcClient.unaryRpc( descriptor: MethodDescriptor, request: Request, callOptions: GrpcCallOptions = GrpcDefaultCallOptions, - trailers: GrpcTrailers = GrpcTrailers(), + trailers: GrpcMetadata = GrpcMetadata(), ): Response { val type = descriptor.type require(type == MethodType.UNARY) { @@ -30,21 +42,19 @@ public suspend fun unaryRpc( } return rpcImpl( - channel = channel, descriptor = descriptor, callOptions = callOptions, trailers = trailers, - request = ClientRequest.Unary(request) + request = flowOf(request) ).singleOrStatus("request", descriptor) } @InternalRpcApi -public fun serverStreamingRpc( - channel: GrpcChannel, +public fun GrpcClient.serverStreamingRpc( descriptor: MethodDescriptor, request: Request, callOptions: GrpcCallOptions = GrpcDefaultCallOptions, - trailers: GrpcTrailers = GrpcTrailers(), + trailers: GrpcMetadata = GrpcMetadata(), ): Flow { val type = descriptor.type require(type == MethodType.SERVER_STREAMING) { @@ -52,21 +62,19 @@ public fun serverStreamingRpc( } return rpcImpl( - channel = channel, descriptor = descriptor, callOptions = callOptions, trailers = trailers, - request = ClientRequest.Unary(request) + request = flowOf(request) ) } @InternalRpcApi -public suspend fun clientStreamingRpc( - channel: GrpcChannel, +public suspend fun GrpcClient.clientStreamingRpc( descriptor: MethodDescriptor, requests: Flow, callOptions: GrpcCallOptions = GrpcDefaultCallOptions, - trailers: GrpcTrailers = GrpcTrailers(), + trailers: GrpcMetadata = GrpcMetadata(), ): Response { val type = descriptor.type require(type == MethodType.CLIENT_STREAMING) { @@ -74,21 +82,19 @@ public suspend fun clientStreamingRpc( } return rpcImpl( - channel = channel, descriptor = descriptor, callOptions = callOptions, trailers = trailers, - request = ClientRequest.Flowing(requests) + request = requests ).singleOrStatus("response", descriptor) } @InternalRpcApi -public fun bidirectionalStreamingRpc( - channel: GrpcChannel, +public fun GrpcClient.bidirectionalStreamingRpc( descriptor: MethodDescriptor, requests: Flow, callOptions: GrpcCallOptions = GrpcDefaultCallOptions, - trailers: GrpcTrailers = GrpcTrailers(), + trailers: GrpcMetadata = GrpcMetadata(), ): Flow { val type = descriptor.type check(type == MethodType.BIDI_STREAMING) { @@ -96,11 +102,10 @@ public fun bidirectionalStreamingRpc( } return rpcImpl( - channel = channel, descriptor = descriptor, callOptions = callOptions, trailers = trailers, - request = ClientRequest.Flowing(requests) + request = requests ) } @@ -133,86 +138,21 @@ private sealed interface ClientRequest { } } -private fun rpcImpl( - channel: GrpcChannel, +private fun GrpcClient.rpcImpl( descriptor: MethodDescriptor, callOptions: GrpcCallOptions, - trailers: GrpcTrailers, - request: ClientRequest, -): Flow = flow { - coroutineScope { - val handler = channel.newCall(descriptor, callOptions) - - /* - * We maintain a buffer of size 1 so onMessage never has to block: it only gets called after - * we request a response from the server, which only happens when responses is empty and - * there is room in the buffer. - */ - val responses = Channel(1) - val ready = Ready { handler.isReady() } - - handler.start(channelResponseListener(responses, ready), trailers) - - val fullMethodName = descriptor.getFullMethodName() - val sender = launch(CoroutineName("grpc-send-message-$fullMethodName")) { - try { - request.sendTo(handler, ready) - handler.halfClose() - } catch (ex: Exception) { - handler.cancel("Collection of requests completed exceptionally", ex) - throw ex // propagate failure upward - } - } - - try { - handler.request(1) - for (response in responses) { - emit(response) - handler.request(1) - } - } catch (e: Exception) { - withContext(NonCancellable) { - sender.cancel("Collection of responses completed exceptionally", e) - sender.join() - // we want the sender to be done cancelling before we cancel the handler, or it might try - // sending to a dead call, which results in ugly exception messages - handler.cancel("Collection of responses completed exceptionally", e) - } - throw e - } - - if (!sender.isCompleted) { - sender.cancel("Collection of responses completed before collection of requests") - } - } + trailers: GrpcMetadata, + request: Flow, +): Flow { + val clientCallScope = ClientCallScopeImpl( + client = this, + method = descriptor, + requestHeaders = trailers, + callOptions = callOptions, + ) + return clientCallScope.proceed(request) } -private fun channelResponseListener( - responses: Channel, - ready: Ready, -) = clientCallListener( - onHeaders = { - // todo check what happens here - }, - onMessage = { message: Response -> - responses.trySend(message).onFailure { e -> - throw e ?: AssertionError("onMessage should never be called until responses is ready") - } - }, - onClose = { status: Status, trailers: GrpcTrailers -> - val cause = when { - status.statusCode == StatusCode.OK -> null - status.getCause() is CancellationException -> status.getCause() - else -> StatusException(status, trailers) - } - - responses.close(cause = cause) - }, - onReady = { - ready.onReady() - }, -) - // todo really needed? internal fun Flow.singleOrStatusFlow( expected: String, @@ -261,3 +201,142 @@ internal class Ready(private val isReallyReady: () -> Boolean) { } } } + +private class ClientCallScopeImpl( + val client: GrpcClient, + override val method: MethodDescriptor, + override val requestHeaders: GrpcMetadata, + override val callOptions: GrpcCallOptions, +) : ClientCallScope { + + val call = client.channel.platformApi.newCall(method, callOptions) + val interceptors = client.interceptors + val onHeadersFuture = CallbackFuture() + val onCloseFuture = CallbackFuture>() + + var interceptorIndex = 0 + + override fun onHeaders(block: (GrpcMetadata) -> Unit) { + onHeadersFuture.onComplete { block(it) } + } + + override fun onClose(block: (Status, GrpcMetadata) -> Unit) { + onCloseFuture.onComplete { block(it.first, it.second) } + } + + override fun cancel(message: String, cause: Throwable?): Nothing { + throw StatusException(Status(StatusCode.CANCELLED, message, cause)) + } + + override fun proceed(request: Flow): Flow { + return if (interceptorIndex < interceptors.size) { + with(interceptors[interceptorIndex++]) { + intercept(request) + } + } else { + // if the interceptor chain is exhausted, we start the actual call + doCall(request) + } + } + + private fun doCall(request: Flow): Flow = flow { + coroutineScope { + + /* + * We maintain a buffer of size 1 so onMessage never has to block: it only gets called after + * we request a response from the server, which only happens when responses is empty and + * there is room in the buffer. + */ + val responses = Channel(1) + val ready = Ready { call.isReady() } + + call.start(channelResponseListener(responses, ready), requestHeaders) + + suspend fun Flow.send() { + if (method.type == MethodType.UNARY || method.type == MethodType.SERVER_STREAMING) { + call.sendMessage(single()) + } else { + ready.suspendUntilReady() + this.collect { request -> + call.sendMessage(request) + ready.suspendUntilReady() + } + } + } + + val fullMethodName = method.getFullMethodName() + val sender = launch(CoroutineName("grpc-send-message-$fullMethodName")) { + try { + request.send() + call.halfClose() + } catch (ex: Exception) { + call.cancel("Collection of requests completed exceptionally", ex) + throw ex // propagate failure upward + } + } + + try { + call.request(1) + for (response in responses) { + emit(response) + call.request(1) + } + } catch (e: Exception) { + withContext(NonCancellable) { + sender.cancel("Collection of responses completed exceptionally", e) + sender.join() + // we want the sender to be done cancelling before we cancel the handler, or it might try + // sending to a dead call, which results in ugly exception messages + call.cancel("Collection of responses completed exceptionally", e) + } + throw e + } + + if (!sender.isCompleted) { + sender.cancel("Collection of responses completed before collection of requests") + } + } + } + + private fun channelResponseListener( + responses: Channel, + ready: Ready, + ) = clientCallListener( + onHeaders = { + try { + onHeadersFuture.complete(it) + } catch (e: StatusException) { + // if a client interceptor called cancel, we throw a StatusException. + // as the JVM implementation treats them differently, we need to catch them here. + call.cancel(e.message, e.cause) + } + }, + onMessage = { message: Response -> + responses.trySend(message).onFailure { e -> + throw e ?: AssertionError("onMessage should never be called until responses is ready") + } + }, + onClose = { status: Status, trailers: GrpcMetadata -> + var cause = when { + status.statusCode == StatusCode.OK -> null + status.getCause() is CancellationException -> status.getCause() + else -> StatusException(status, trailers) + } + + try { + onCloseFuture.complete(status to trailers) + } catch (exception: Throwable) { + cause = exception + if (exception !is StatusException) { + val status = Status(StatusCode.CANCELLED, "Interceptor threw an error", exception) + cause = StatusException(status) + } + } + + responses.close(cause = cause) + }, + onReady = { + ready.onReady() + }, + ) +} diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/suspendServerCalls.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/suspendServerCalls.kt index 2c16c6193..b56076544 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/suspendServerCalls.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/suspendServerCalls.kt @@ -16,7 +16,9 @@ import kotlinx.coroutines.flow.map import kotlinx.coroutines.launch import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock -import kotlinx.rpc.grpc.GrpcTrailers +import kotlinx.rpc.grpc.GrpcMetadata +import kotlinx.rpc.grpc.ServerCallScope +import kotlinx.rpc.grpc.ServerInterceptor import kotlinx.rpc.grpc.Status import kotlinx.rpc.grpc.StatusCode import kotlinx.rpc.grpc.StatusException @@ -29,6 +31,7 @@ import kotlin.reflect.typeOf public fun CoroutineScope.unaryServerMethodDefinition( descriptor: MethodDescriptor, responseKType: KType, + interceptors: List, implementation: suspend (request: Request) -> Response, ): ServerMethodDefinition { val type = descriptor.type @@ -36,7 +39,7 @@ public fun CoroutineScope.unaryServerMethodDefinition( "Expected a unary method descriptor but got $descriptor" } - return serverMethodDefinition(descriptor, responseKType) { requests -> + return serverMethodDefinition(descriptor, responseKType, interceptors) { requests -> requests .singleOrStatusFlow("request", descriptor) .map { implementation(it) } @@ -47,6 +50,7 @@ public fun CoroutineScope.unaryServerMethodDefinition( public fun CoroutineScope.clientStreamingServerMethodDefinition( descriptor: MethodDescriptor, responseKType: KType, + interceptors: List, implementation: suspend (requests: Flow) -> Response, ): ServerMethodDefinition { val type = descriptor.type @@ -54,7 +58,7 @@ public fun CoroutineScope.clientStreamingServerMethodDefinit "Expected a client streaming method descriptor but got $descriptor" } - return serverMethodDefinition(descriptor, responseKType) { requests -> + return serverMethodDefinition(descriptor, responseKType, interceptors) { requests -> flow { val response = implementation(requests) emit(response) @@ -66,6 +70,7 @@ public fun CoroutineScope.clientStreamingServerMethodDefinit public fun CoroutineScope.serverStreamingServerMethodDefinition( descriptor: MethodDescriptor, responseKType: KType, + interceptors: List, implementation: (request: Request) -> Flow, ): ServerMethodDefinition { val type = descriptor.type @@ -73,7 +78,7 @@ public fun CoroutineScope.serverStreamingServerMethodDefinit "Expected a server streaming method descriptor but got $descriptor" } - return serverMethodDefinition(descriptor, responseKType) { requests -> + return serverMethodDefinition(descriptor, responseKType, interceptors) { requests -> flow { requests .singleOrStatusFlow("request", descriptor) @@ -90,6 +95,7 @@ public fun CoroutineScope.serverStreamingServerMethodDefinit public fun CoroutineScope.bidiStreamingServerMethodDefinition( descriptor: MethodDescriptor, responseKType: KType, + interceptors: List, implementation: (requests: Flow) -> Flow, ): ServerMethodDefinition { val type = descriptor.type @@ -97,29 +103,36 @@ public fun CoroutineScope.bidiStreamingServerMethodDefinitio "Expected a bidi streaming method descriptor but got $descriptor" } - return serverMethodDefinition(descriptor, responseKType, implementation) + return serverMethodDefinition(descriptor, responseKType, interceptors, implementation) } private fun CoroutineScope.serverMethodDefinition( descriptor: MethodDescriptor, responseKType: KType, + interceptors: List, implementation: (Flow) -> Flow, -): ServerMethodDefinition = serverMethodDefinition(descriptor, serverCallHandler(responseKType, implementation)) +): ServerMethodDefinition = + serverMethodDefinition(descriptor, serverCallHandler(descriptor, responseKType, interceptors, implementation)) private fun CoroutineScope.serverCallHandler( + descriptor: MethodDescriptor, responseKType: KType, + interceptors: List, implementation: (Flow) -> Flow, ): ServerCallHandler = - ServerCallHandler { call, _ -> - serverCallListenerImpl(call, responseKType, implementation) + ServerCallHandler { call, headers -> + serverCallListenerImpl(descriptor, call, responseKType, interceptors, implementation, headers) } private fun CoroutineScope.serverCallListenerImpl( + descriptor: MethodDescriptor, handler: ServerCall, responseKType: KType, + interceptors: List, implementation: (Flow) -> Flow, + requestHeaders: GrpcMetadata, ): ServerCall.Listener { - val ready = Ready { handler.isReady()} + val ready = Ready { handler.isReady() } val requestsChannel = Channel(1) val requestsStarted = AtomicBoolean(false) // enforces read-once @@ -144,11 +157,21 @@ private fun CoroutineScope.serverCallListenerImpl( } } - val rpcJob = launch(GrpcContextElement.current()) { + val context = GrpcContextElement.current() + val serverCallScope = ServerCallScopeImpl( + method = descriptor, + interceptors = interceptors, + implementation = implementation, + requestHeaders = requestHeaders, + serverCall = handler, + context = context.grpcContext, + ) + + val rpcJob = launch(context) { val mutex = Mutex() val headersSent = AtomicBoolean(false) // enforces only sending headers once val failure = runCatching { - implementation(requests).collect { response -> + serverCallScope.proceed(requests).collect { response -> @Suppress("UNCHECKED_CAST") // fix for KRPC-173 val value = if (responseKType == unitKType) Unit as Response else response @@ -156,7 +179,7 @@ private fun CoroutineScope.serverCallListenerImpl( // once we have a response message, check if we've sent headers yet - if not, do so if (headersSent.value.compareAndSet(expect = false, update = true)) { mutex.withLock { - handler.sendHeaders(GrpcTrailers()) + handler.sendHeaders(GrpcMetadata()) } } ready.suspendUntilReady() @@ -169,7 +192,7 @@ private fun CoroutineScope.serverCallListenerImpl( // no elements or threw an exception, then we wouldn't have sent them if (failure == null && headersSent.value.compareAndSet(expect = false, update = true)) { mutex.withLock { - handler.sendHeaders(GrpcTrailers()) + handler.sendHeaders(GrpcMetadata()) } } @@ -195,10 +218,11 @@ private fun CoroutineScope.serverCallListenerImpl( null } } - } ?: GrpcTrailers() + } ?: GrpcMetadata() mutex.withLock { handler.close(closeStatus, trailers) + serverCallScope.onCloseFuture.complete(Pair(closeStatus, trailers)) } } @@ -230,7 +254,7 @@ private fun CoroutineScope.serverCallListenerImpl( onReady = { ready.onReady() }, - onComplete = {} + onComplete = { } ) } @@ -242,4 +266,41 @@ private class ServerCallListenerState { var isReceiving = true } -private val unitKType = typeOf() +private val unitKType = typeOf() + + +private class ServerCallScopeImpl( + override val method: MethodDescriptor, + val interceptors: List, + val implementation: (Flow) -> Flow, + override val requestHeaders: GrpcMetadata, + val serverCall: ServerCall, + override val context: GrpcContext, +) : ServerCallScope { + + override val responseHeaders: GrpcMetadata = GrpcMetadata() + override val responseTrailers: GrpcMetadata = GrpcMetadata() + + // keeps track of already processed interceptors + var interceptorIndex = 0 + val onCloseFuture = CallbackFuture>() + + override fun onClose(block: (Status, GrpcMetadata) -> Unit) { + onCloseFuture.onComplete { block(it.first, it.second) } + } + + override fun close(status: Status, trailers: GrpcMetadata): Nothing { + // this will be cached by the rpcImpl() runCatching{} and turns it into a close() + throw StatusException(status, trailers) + } + + override fun proceed(request: Flow): Flow { + return if (interceptorIndex < interceptors.size) { + with(interceptors[interceptorIndex++]) { + intercept(request) + } + } else { + implementation(request) + } + } +} \ No newline at end of file diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/BaseGrpcServiceTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/BaseGrpcServiceTest.kt index 08ed33cd4..36ac0815b 100644 --- a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/BaseGrpcServiceTest.kt +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/BaseGrpcServiceTest.kt @@ -30,17 +30,19 @@ abstract class BaseGrpcServiceTest { ) = runTest { val server = GrpcServer( port = PORT, - messageCodecResolver = resolver, parentContext = coroutineContext, - builder = { + ) { + messageCodecResolver = resolver + services { registerService(kClass) { impl } } - ) + } server.start() - val client = GrpcClient("localhost", PORT, messageCodecResolver = resolver) { - usePlaintext() + val client = GrpcClient("localhost", PORT) { + messageCodecResolver = resolver + credentials = plaintext() } val service = client.withService(kClass) diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/CoreClientTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/CoreClientTest.kt index 4b2e08c4c..622ce833d 100644 --- a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/CoreClientTest.kt +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/CoreClientTest.kt @@ -8,13 +8,14 @@ import kotlinx.coroutines.delay import kotlinx.coroutines.runBlocking import kotlinx.coroutines.test.runTest import kotlinx.coroutines.withTimeout +import kotlinx.rpc.grpc.GrpcMetadata import kotlinx.rpc.grpc.GrpcServer -import kotlinx.rpc.grpc.GrpcTrailers import kotlinx.rpc.grpc.ManagedChannel import kotlinx.rpc.grpc.ManagedChannelBuilder import kotlinx.rpc.grpc.Status import kotlinx.rpc.grpc.StatusCode import kotlinx.rpc.grpc.buildChannel +import kotlinx.rpc.grpc.createInsecureClientCredentials import kotlinx.rpc.grpc.internal.ClientCall import kotlinx.rpc.grpc.internal.GrpcDefaultCallOptions import kotlinx.rpc.grpc.internal.MethodDescriptor @@ -52,9 +53,10 @@ class GrpcCoreClientTest { private fun ManagedChannel.newHelloCall(fullName: String = "kotlinx.rpc.grpc.test.GreeterService/SayHello"): ClientCall = platformApi.newCall(descriptorFor(fullName), GrpcDefaultCallOptions) - private fun createChannel(): ManagedChannel = ManagedChannelBuilder("localhost:$PORT") - .usePlaintext() - .buildChannel() + private fun createChannel(): ManagedChannel = ManagedChannelBuilder( + target = "localhost:$PORT", + credentials = createInsecureClientCredentials() + ).buildChannel() private fun helloReq(timeout: UInt = 0u): HelloRequest = HelloRequest { @@ -84,7 +86,7 @@ class GrpcCoreClientTest { onClose = { status, _ -> statusDeferred.complete(status) } ) - call.start(listener, GrpcTrailers()) + call.start(listener, GrpcMetadata()) call.sendMessage(req) call.halfClose() call.request(1) @@ -108,8 +110,8 @@ class GrpcCoreClientTest { val listener = createClientCallListener( onClose = { status, _ -> statusDeferred.complete(status) } ) - call.start(listener, GrpcTrailers()) - assertFailsWith { call.start(listener, GrpcTrailers()) } + call.start(listener, GrpcMetadata()) + assertFailsWith { call.start(listener, GrpcMetadata()) } // cancel to finish the call quickly call.cancel("Double start test", null) runBlocking { withTimeout(5000) { statusDeferred.await() } } @@ -125,7 +127,7 @@ class GrpcCoreClientTest { val listener = createClientCallListener( onClose = { status, _ -> statusDeferred.complete(status) } ) - call.start(listener, GrpcTrailers()) + call.start(listener, GrpcMetadata()) call.halfClose() assertFailsWith { call.sendMessage(req) } // Ensure call completes @@ -142,7 +144,7 @@ class GrpcCoreClientTest { val listener = createClientCallListener( onClose = { status, _ -> statusDeferred.complete(status) } ) - call.start(listener, GrpcTrailers()) + call.start(listener, GrpcMetadata()) assertFails { call.request(-1) } call.cancel("cleanup", null) runBlocking { withTimeout(5000) { statusDeferred.await() } } @@ -157,7 +159,7 @@ class GrpcCoreClientTest { val listener = createClientCallListener( onClose = { status, _ -> statusDeferred.complete(status) } ) - call.start(listener, GrpcTrailers()) + call.start(listener, GrpcMetadata()) call.cancel("user cancel", null) runBlocking { withTimeout(10000) { @@ -177,7 +179,7 @@ class GrpcCoreClientTest { onClose = { status, _ -> statusDeferred.complete(status) } ) - call.start(listener, GrpcTrailers()) + call.start(listener, GrpcMetadata()) call.sendMessage(helloReq()) call.halfClose() call.request(1) @@ -198,7 +200,7 @@ class GrpcCoreClientTest { val listener = createClientCallListener() assertFailsWith { try { - call.start(listener, GrpcTrailers()) + call.start(listener, GrpcMetadata()) call.halfClose() call.sendMessage(helloReq()) } finally { @@ -218,7 +220,7 @@ class GrpcCoreClientTest { channel.shutdown() runBlocking { channel.awaitTermination() } - call.start(listener, GrpcTrailers()) + call.start(listener, GrpcMetadata()) call.sendMessage(helloReq()) call.halfClose() call.request(1) @@ -240,7 +242,7 @@ class GrpcCoreClientTest { onClose = { status, _ -> statusDeferred.complete(status) } ) - call.start(listener, GrpcTrailers()) + call.start(listener, GrpcMetadata()) // set timeout on the server to 1000 ms, to simulate a long-running call call.sendMessage(helloReq(1000u)) call.halfClose() @@ -274,8 +276,11 @@ class GreeterServiceImpl : GreeterService { fun runServer() = runTest { val server = GrpcServer( port = PORT, - builder = { registerService { GreeterServiceImpl() } } - ) + ) { + services { + registerService { GreeterServiceImpl() } + } + } try { server.start() @@ -292,9 +297,9 @@ class GreeterServiceImpl : GreeterService { private fun createClientCallListener( - onHeaders: (headers: GrpcTrailers) -> Unit = {}, + onHeaders: (headers: GrpcMetadata) -> Unit = {}, onMessage: (message: T) -> Unit = {}, - onClose: (status: Status, trailers: GrpcTrailers) -> Unit = { _, _ -> }, + onClose: (status: Status, trailers: GrpcMetadata) -> Unit = { _, _ -> }, onReady: () -> Unit = {}, ) = clientCallListener( onHeaders = onHeaders, diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/CustomResolverGrpcServiceTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/CustomResolverGrpcServiceTest.kt index 7fcc9af5e..0f2983606 100644 --- a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/CustomResolverGrpcServiceTest.kt +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/CustomResolverGrpcServiceTest.kt @@ -20,24 +20,25 @@ import kotlin.test.Test import kotlin.test.assertContentEquals import kotlin.test.assertEquals -@WithCodec(Message.Companion::class) -class Message(val value: String) { - companion object : SourcedMessageCodec { - override fun encodeToSource(value: Message): Source { +@WithCodec(CustomResolverMessage.Companion::class) +class CustomResolverMessage(val value: String) { + companion object Companion : SourcedMessageCodec { + override fun encodeToSource(value: CustomResolverMessage): Source { return Buffer().apply { writeString(value.value) } } - override fun decodeFromSource(stream: Source): Message { - return Message(stream.readString()) + override fun decodeFromSource(stream: Source): CustomResolverMessage { + return CustomResolverMessage(stream.readString()) } } } + @Grpc interface GrpcService { suspend fun plainString(value: String): String - suspend fun message(value: Message): Message + suspend fun message(value: CustomResolverMessage): CustomResolverMessage suspend fun krpc173() @@ -53,8 +54,8 @@ class GrpcServiceImpl : GrpcService { return "$value $value" } - override suspend fun message(value: Message): Message { - return Message("${value.value} ${value.value}") + override suspend fun message(value: CustomResolverMessage): CustomResolverMessage { + return CustomResolverMessage("${value.value} ${value.value}") } override suspend fun krpc173() { @@ -88,7 +89,7 @@ class CustomResolverGrpcServiceTest : BaseGrpcServiceTest() { resolver = simpleResolver, impl = GrpcServiceImpl(), ) { service -> - assertEquals("test test", service.message(Message("test")).value) + assertEquals("test test", service.message(CustomResolverMessage("test")).value) } @Test diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/RawClientServerTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/RawClientServerTest.kt index 2702d32a2..d86dfd4a2 100644 --- a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/RawClientServerTest.kt +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/RawClientServerTest.kt @@ -14,12 +14,10 @@ import kotlinx.io.Buffer import kotlinx.io.Source import kotlinx.io.readString import kotlinx.io.writeString -import kotlinx.rpc.grpc.ManagedChannelBuilder +import kotlinx.rpc.grpc.GrpcClient import kotlinx.rpc.grpc.Server import kotlinx.rpc.grpc.ServerBuilder -import kotlinx.rpc.grpc.buildChannel import kotlinx.rpc.grpc.codec.SourcedMessageCodec -import kotlinx.rpc.grpc.internal.GrpcChannel import kotlinx.rpc.grpc.internal.MethodDescriptor import kotlinx.rpc.grpc.internal.MethodType import kotlinx.rpc.grpc.internal.ServerMethodDefinition @@ -46,10 +44,10 @@ class RawClientServerTest { methodName = "unary", type = MethodType.UNARY, methodDefinition = { descriptor -> - unaryServerMethodDefinition(descriptor, typeOf()) { it + it } + unaryServerMethodDefinition(descriptor, typeOf(), emptyList()) { it + it } }, - ) { channel, descriptor -> - val response = unaryRpc(channel, descriptor, "Hello") + ) { client, descriptor -> + val response = client.unaryRpc(descriptor, "Hello") assertEquals("HelloHello", response) } @@ -59,12 +57,12 @@ class RawClientServerTest { methodName = "serverStreaming", type = MethodType.SERVER_STREAMING, methodDefinition = { descriptor -> - serverStreamingServerMethodDefinition(descriptor, typeOf()) { + serverStreamingServerMethodDefinition(descriptor, typeOf(), emptyList()) { flowOf(it, it) } } - ) { channel, descriptor -> - val response = serverStreamingRpc(channel, descriptor, "Hello") + ) { client, descriptor -> + val response = client.serverStreamingRpc(descriptor, "Hello") assertEquals(listOf("Hello", "Hello"), response.toList()) } @@ -74,44 +72,46 @@ class RawClientServerTest { methodName = "clientStreaming", type = MethodType.CLIENT_STREAMING, methodDefinition = { descriptor -> - clientStreamingServerMethodDefinition(descriptor, typeOf()) { + clientStreamingServerMethodDefinition(descriptor, typeOf(), emptyList()) { it.toList().joinToString(separator = "") } } - ) { channel, descriptor -> - val response = clientStreamingRpc(channel, descriptor, flowOf("Hello", "World")) + ) { client, descriptor -> + val response = client.clientStreamingRpc(descriptor, flowOf("Hello", "World")) assertEquals("HelloWorld", response) } @Test - fun bidirectionalStreamingCall() = runTest( - methodName = "bidirectionalStreaming", - type = MethodType.BIDI_STREAMING, - methodDefinition = { descriptor -> - bidiStreamingServerMethodDefinition(descriptor, typeOf()) { - it.map { str -> str + str } + fun bidirectionalStreamingCall() { + runTest( + methodName = "bidirectionalStreaming", + type = MethodType.BIDI_STREAMING, + methodDefinition = { descriptor -> + bidiStreamingServerMethodDefinition(descriptor, typeOf(), emptyList()) { + it.map { str -> str + str } + } } - } - ) { channel, descriptor -> - val response = bidirectionalStreamingRpc(channel, descriptor, flowOf("Hello", "World")) - .toList() + ) { client, descriptor -> + val response = client.bidirectionalStreamingRpc(descriptor, flowOf("Hello", "World")) + .toList() - assertEquals(listOf("HelloHello", "WorldWorld"), response) + assertEquals(listOf("HelloHello", "WorldWorld"), response) + } } private fun runTest( methodName: String, type: MethodType, methodDefinition: CoroutineScope.(MethodDescriptor) -> ServerMethodDefinition, - block: suspend (GrpcChannel, MethodDescriptor) -> Unit, + block: suspend (GrpcClient, MethodDescriptor) -> Unit, ) = kotlinx.coroutines.test.runTest { val serverJob = Job() val serverScope = CoroutineScope(serverJob) - val clientChannel = ManagedChannelBuilder("localhost", PORT).apply { - usePlaintext() - }.buildChannel() + val client = GrpcClient("localhost", PORT) { + credentials = plaintext() + } val descriptor = methodDescriptor( fullMethodName = "${SERVICE_NAME}/$methodName", @@ -139,11 +139,11 @@ class RawClientServerTest { val server = Server(builder) server.start() - block(clientChannel.platformApi, descriptor) + block(client, descriptor) serverJob.cancelAndJoin() - clientChannel.shutdown() - clientChannel.awaitTermination() + client.shutdown() + client.awaitTermination() server.shutdown() server.awaitTermination() } diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/RawClientTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/RawClientTest.kt index f50105f5a..bde3d3741 100644 --- a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/RawClientTest.kt +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/RawClientTest.kt @@ -9,10 +9,15 @@ import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.toList import kotlinx.coroutines.test.runTest +import kotlinx.rpc.grpc.GrpcClient import kotlinx.rpc.grpc.GrpcServer -import kotlinx.rpc.grpc.ManagedChannelBuilder -import kotlinx.rpc.grpc.buildChannel -import kotlinx.rpc.grpc.internal.* +import kotlinx.rpc.grpc.internal.MethodDescriptor +import kotlinx.rpc.grpc.internal.MethodType +import kotlinx.rpc.grpc.internal.bidirectionalStreamingRpc +import kotlinx.rpc.grpc.internal.clientStreamingRpc +import kotlinx.rpc.grpc.internal.methodDescriptor +import kotlinx.rpc.grpc.internal.serverStreamingRpc +import kotlinx.rpc.grpc.internal.unaryRpc import kotlinx.rpc.registerService import kotlin.test.Test import kotlin.test.assertEquals @@ -32,8 +37,8 @@ class RawClientTest { fun unaryEchoTest() = runTest( methodName = "UnaryEcho", type = MethodType.UNARY, - ) { channel, descriptor -> - val response = unaryRpc(channel, descriptor, EchoRequest { message = "Eccchhooo" }) + ) { client, descriptor -> + val response = client.unaryRpc(descriptor, EchoRequest { message = "Eccchhooo" }) assertEquals("Eccchhooo", response.message) } @@ -41,8 +46,8 @@ class RawClientTest { fun serverStreamingEchoTest() = runTest( methodName = "ServerStreamingEcho", type = MethodType.SERVER_STREAMING, - ) { channel, descriptor -> - val response = serverStreamingRpc(channel, descriptor, EchoRequest { message = "Eccchhooo" }) + ) { client, descriptor -> + val response = client.serverStreamingRpc(descriptor, EchoRequest { message = "Eccchhooo" }) var i = 0 response.collect { println("Received: ${i++}") @@ -54,10 +59,9 @@ class RawClientTest { fun clientStreamingEchoTest() = runTest( methodName = "ClientStreamingEcho", type = MethodType.CLIENT_STREAMING, - ) { channel, descriptor -> - val response = clientStreamingRpc(channel, descriptor, flow { + ) { client, descriptor -> + val response = client.clientStreamingRpc(descriptor, flow { repeat(5) { - delay(100) println("Sending: ${it + 1}") emit(EchoRequest { message = "Eccchhooo" }) } @@ -70,8 +74,8 @@ class RawClientTest { fun bidirectionalStreamingEchoTest() = runTest( methodName = "BidirectionalStreamingEcho", type = MethodType.BIDI_STREAMING, - ) { channel, descriptor -> - val response = bidirectionalStreamingRpc(channel, descriptor, flow { + ) { client, descriptor -> + val response = client.bidirectionalStreamingRpc(descriptor, flow { repeat(5) { emit(EchoRequest { message = "Eccchhooo" }) } @@ -88,11 +92,11 @@ class RawClientTest { fun runTest( methodName: String, type: MethodType, - block: suspend (GrpcChannel, MethodDescriptor) -> Unit, + block: suspend (GrpcClient, MethodDescriptor) -> Unit, ) = runTest { - val channel = ManagedChannelBuilder("localhost:50051") - .usePlaintext() - .buildChannel() + val client = GrpcClient("localhost:50051") { + credentials = plaintext() + } val methodDescriptor = methodDescriptor( fullMethodName = "kotlinx.rpc.grpc.test.EchoService/$methodName", @@ -106,10 +110,10 @@ class RawClientTest { ) try { - block(channel.platformApi, methodDescriptor) + block(client, methodDescriptor) } finally { - channel.shutdown() - channel.awaitTermination() + client.shutdown() + client.awaitTermination() } } } @@ -152,8 +156,7 @@ class EchoServiceImpl : EchoService { fun runServer() = runTest(timeout = Duration.INFINITE) { val server = GrpcServer( port = PORT, - builder = { registerService { EchoServiceImpl() } } - ) + ) { services { registerService { EchoServiceImpl() } } } try { server.start() diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/proto/ClientInterceptorTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/proto/ClientInterceptorTest.kt new file mode 100644 index 000000000..3ccf86132 --- /dev/null +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/proto/ClientInterceptorTest.kt @@ -0,0 +1,353 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc.test.proto + +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.channelFlow +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.toList +import kotlinx.rpc.RpcServer +import kotlinx.rpc.grpc.ClientCallScope +import kotlinx.rpc.grpc.ClientInterceptor +import kotlinx.rpc.grpc.GrpcClient +import kotlinx.rpc.grpc.StatusCode +import kotlinx.rpc.grpc.StatusException +import kotlinx.rpc.grpc.statusCode +import kotlinx.rpc.grpc.test.EchoRequest +import kotlinx.rpc.grpc.test.EchoResponse +import kotlinx.rpc.grpc.test.EchoService +import kotlinx.rpc.grpc.test.EchoServiceImpl +import kotlinx.rpc.grpc.test.invoke +import kotlinx.rpc.registerService +import kotlinx.rpc.withService +import kotlin.test.Test +import kotlin.test.assertContains +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertIs +import kotlin.test.assertTrue + +class ClientInterceptorTest : GrpcProtoTest() { + + override fun RpcServer.registerServices() { + registerService { EchoServiceImpl() } + } + + @Test + fun `throw during intercept - should fail with thrown exception`() { + val error = assertFailsWith { + val interceptor = interceptor { + throw IllegalStateException("Failing in interceptor") + } + runGrpcTest(clientInterceptors = interceptor, test = ::unaryCall) + } + + assertEquals(error.message, "Failing in interceptor") + } + + @Test + fun `throw during onHeader - should fail with status exception containing the thrown exception`() { + val error = assertFailsWith { + val interceptor = interceptor { + onHeaders { + throw IllegalStateException("Failing in onHeader") + } + proceed(it) + } + runGrpcTest(clientInterceptors = interceptor, test = ::unaryCall) + } + + assertEquals(StatusCode.CANCELLED, error.getStatus().statusCode) + assertIs(error.cause) + assertEquals("Failing in onHeader", error.cause?.message) + } + + @Test + fun `throw during onClose - should fail with status exception containing the thrown exception`() { + val error = assertFailsWith { + val interceptor = interceptor { + onClose { _, _ -> + throw IllegalStateException("Failing in onClose") + } + proceed(it) + } + runGrpcTest(clientInterceptors = interceptor, test = ::unaryCall) + } + + assertEquals(StatusCode.CANCELLED, error.getStatus().statusCode) + assertIs(error.cause) + assertEquals("Failing in onClose", error.cause?.message) + } + + @Test + fun `cancel in intercept - should fail with cancellation`() { + val error = assertFailsWith { + val interceptor = interceptor { + cancel("Canceling in interceptor", IllegalStateException("Cancellation cause")) + } + runGrpcTest(clientInterceptors = interceptor, test = ::unaryCall) + } + + assertEquals(StatusCode.CANCELLED, error.getStatus().statusCode) + assertContains(error.message!!, "Canceling in interceptor") + assertIs(error.cause) + assertEquals("Cancellation cause", error.cause?.message) + } + + @Test + fun `cancel in request flow - should fail with cancellation`() { + val error = assertFailsWith { + val interceptor = interceptor { + proceed(it.map { + val msg = it as EchoRequest + if (msg.message == "Echo-3") { + cancel("Canceling in request flow", IllegalStateException("Cancellation cause")) + } + it + }) + } + runGrpcTest(clientInterceptors = interceptor, test = ::bidiStream) + } + + assertEquals(StatusCode.CANCELLED, error.getStatus().statusCode) + assertContains(error.message!!, "Canceling in request flow") + assertIs(error.cause) + assertEquals("Cancellation cause", error.cause?.message) + } + + @Test + fun `cancel in response flow - should fail with cancellation`() { + val error = assertFailsWith { + val interceptor = interceptor { + flow { + proceed(it).collect { resp -> + val msg = resp as EchoResponse + if (msg.message == "Echo-3") { + cancel("Canceling in response flow", IllegalStateException("Cancellation cause")) + } + emit(resp) + } + } + } + runGrpcTest(clientInterceptors = interceptor, test = ::bidiStream) + } + + assertEquals(StatusCode.CANCELLED, error.getStatus().statusCode) + assertContains(error.message!!, "Canceling in response flow") + assertIs(error.cause) + assertEquals("Cancellation cause", error.cause?.message) + } + + @Test + fun `cancel onHeaders - should fail with cancellation`() { + val error = assertFailsWith { + val interceptor = interceptor { + this.onHeaders { + cancel("Canceling in headers", IllegalStateException("Cancellation cause")) + } + proceed(it) + } + runGrpcTest(clientInterceptors = interceptor, test = ::bidiStream) + } + + assertEquals(StatusCode.CANCELLED, error.getStatus().statusCode) + assertContains(error.message!!, "Canceling in headers") + assertIs(error.cause) + assertEquals("Cancellation cause", error.cause?.message) + } + + @Test + fun `cancel onClose - should fail with cancellation`() { + val error = assertFailsWith { + val interceptor = interceptor { + this.onClose { _, _ -> + cancel("Canceling in onClose", IllegalStateException("Cancellation cause")) + } + proceed(it) + } + runGrpcTest(clientInterceptors = interceptor, test = ::bidiStream) + } + assertEquals(StatusCode.CANCELLED, error.getStatus().statusCode) + assertContains(error.message!!, "Canceling in onClose") + assertIs(error.cause) + assertEquals("Cancellation cause", error.cause?.message) + } + + @Test + fun `cancel in two interceptors - should fail with cancellation`() { + val error = assertFailsWith { + val interceptor1 = interceptor { + onClose { _, _ -> cancel("[1] Canceling in onClose", IllegalStateException("Cancellation cause")) } + proceed(it) + } + val interceptor2 = interceptor { + onClose { _, _ -> cancel("[2] Canceling in onClose", IllegalStateException("Cancellation cause")) } + proceed(it) + } + runGrpcTest(clientInterceptors = interceptor1 + interceptor2, test = ::unaryCall) + } + + assertEquals(StatusCode.CANCELLED, error.getStatus().statusCode) + assertContains(error.message!!, "[1] Canceling in onClose") + assertIs(error.cause) + assertEquals("Cancellation cause", error.cause?.message) + } + + @Test + fun `cancel in two interceptors withing response stream - should fail with cancellation`() { + val error = assertFailsWith { + val interceptor1 = interceptor { + proceed(it).map { + val msg = it as EchoResponse + if (msg.message == "Echo-3") { + cancel("[1] Canceling in response flow", IllegalStateException("Cancellation cause")) + } + it + } + } + val interceptor2 = interceptor { + proceed(it).map { + val msg = it as EchoResponse + // this is cancelled before the first one + if (msg.message == "Echo-2") { + cancel("[2] Canceling in response flow", IllegalStateException("Cancellation cause")) + } + it + } + } + runGrpcTest(clientInterceptors = interceptor1 + interceptor2, test = ::bidiStream) + } + + assertEquals(StatusCode.CANCELLED, error.getStatus().statusCode) + assertContains(error.message!!, "[2] Canceling in response flow") + assertIs(error.cause) + assertEquals("Cancellation cause", error.cause?.message) + } + + @Test + fun `modify request message - should return modified message`() { + val interceptor = interceptor { + val modified = it.map { EchoRequest { message = "Modified" } } + proceed(modified) + } + runGrpcTest(clientInterceptors = interceptor) { + val service = it.withService() + val response = service.UnaryEcho(EchoRequest { message = "Hello" }) + assertEquals("Modified", response.message) + } + } + + @Test + fun `modify response message - should return modified message`() { + val interceptor = interceptor { + proceed(it).map { EchoResponse { message = "Modified" } } + } + runGrpcTest(clientInterceptors = interceptor) { + val service = it.withService() + val response = service.UnaryEcho(EchoRequest { message = "Hello" }) + assertEquals("Modified", response.message) + } + } + + @Test + fun `append a response message once closed`() { + val interceptor = interceptor { + channelFlow { + proceed(it).collect { + trySend(it) + } + onClose { status, _ -> + trySend(EchoResponse { message = "Appended-after-close-with-${status.statusCode}" }) + } + } + } + + runGrpcTest( + clientInterceptors = interceptor + ) { client -> + val svc = client.withService() + val responses = svc.BidirectionalStreamingEcho(flow { + repeat(5) { + emit(EchoRequest { message = "Eccchhooo" }) + } + }).toList() + assertEquals(6, responses.size) + assertTrue(responses.any { it.message == "Appended-after-close-with-OK" }) + } + } + + @Test + fun `test exact order of interceptor execution`() { + val order = mutableListOf() + val interceptor1 = interceptor { request -> + order.add(1) + flow { + order.add(2) + val req = request.map { order.add(5); it } + proceed(req).collect { + order.add(8) + emit(it) + } + order.add(10) + } + } + val interceptor2 = interceptor { request -> + order.add(3) + flow { + order.add(4) + val req = request.map { order.add(6); it } + proceed(req).collect { + order.add(7) + emit(it) + } + order.add(9) + } + } + + val both = interceptor1 + interceptor2 + runGrpcTest(clientInterceptors = both) { unaryCall(it) } + + assertEquals( + listOf(1, 2, 3, 4, 5, 6, 7, 8, 9, 10), + order + ) + } + + private suspend fun unaryCall(grpcClient: GrpcClient) { + val service = grpcClient.withService() + val response = service.UnaryEcho(EchoRequest { message = "Hello" }) + assertEquals("Hello", response.message) + } + + private suspend fun bidiStream(grpcClient: GrpcClient, count: Int = 5) { + val service = grpcClient.withService() + val responses = service.BidirectionalStreamingEcho(flow { + repeat(count) { + emit(EchoRequest { message = "Echo-$it" }) + } + }).toList() + assertEquals(count, responses.size) + repeat(count) { + assertEquals("Echo-$it", responses[it].message) + } + } + +} + +private fun interceptor( + block: ClientCallScope.(Flow) -> Flow, +): List { + return listOf(object : ClientInterceptor { + @Suppress("UNCHECKED_CAST") + override fun ClientCallScope.intercept( + request: Flow, + ): Flow { + with(this as ClientCallScope) { + return block(request as Flow) as Flow + } + } + }) +} \ No newline at end of file diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/proto/GrpcProtoTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/proto/GrpcProtoTest.kt index 92050dd19..367a71b40 100644 --- a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/proto/GrpcProtoTest.kt +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/proto/GrpcProtoTest.kt @@ -9,9 +9,11 @@ import kotlinx.coroutines.sync.withLock import kotlinx.coroutines.test.runTest import kotlinx.rpc.RpcServer import kotlinx.rpc.grpc.ClientCredentials +import kotlinx.rpc.grpc.ClientInterceptor import kotlinx.rpc.grpc.GrpcClient import kotlinx.rpc.grpc.GrpcServer import kotlinx.rpc.grpc.ServerCredentials +import kotlinx.rpc.grpc.ServerInterceptor abstract class GrpcProtoTest { private val serverMutex = Mutex() @@ -22,22 +24,24 @@ abstract class GrpcProtoTest { serverCreds: ServerCredentials? = null, clientCreds: ClientCredentials? = null, overrideAuthority: String? = null, + clientInterceptors: List = emptyList(), + serverInterceptors: List = emptyList(), test: suspend (GrpcClient) -> Unit, ) = runTest { serverMutex.withLock { - val grpcClient = GrpcClient("localhost", PORT, credentials = clientCreds) { - if (overrideAuthority != null) overrideAuthority(overrideAuthority) - if (clientCreds == null) { - usePlaintext() - } + val grpcClient = GrpcClient("localhost", PORT) { + credentials = clientCreds ?: plaintext() + if (overrideAuthority != null) this.overrideAuthority = overrideAuthority + clientInterceptors.forEach { intercept(it) } } val grpcServer = GrpcServer( PORT, - credentials = serverCreds, - builder = { - registerServices() - }) + ) { + credentials = serverCreds + serverInterceptors.forEach { intercept(it) } + services { registerServices() } + } grpcServer.start() try { diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/proto/JavaPackageOptionTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/proto/JavaPackageOptionTest.kt index 1c7dd3017..c0775bbdf 100644 --- a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/proto/JavaPackageOptionTest.kt +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/proto/JavaPackageOptionTest.kt @@ -8,8 +8,6 @@ import com.google.protobuf.kotlin.Empty import com.google.protobuf.kotlin.EmptyInternal import com.google.protobuf.kotlin.invoke import kotlinx.rpc.RpcServer -import kotlinx.rpc.grpc.ManagedChannelBuilder -import kotlinx.rpc.grpc.buildChannel import kotlinx.rpc.grpc.internal.MethodType import kotlinx.rpc.grpc.internal.methodDescriptor import kotlinx.rpc.grpc.internal.unaryRpc @@ -35,11 +33,7 @@ class JavaPackageOptionTest : GrpcProtoTest() { * Tests that the generated service descriptor uses the `package` name. */ @Test - fun testJavaPackageOptionRaw() = runGrpcTest { _ -> - val channel = ManagedChannelBuilder("localhost", PORT) - .usePlaintext() - .buildChannel() - + fun testJavaPackageOptionRaw() = runGrpcTest { client -> val descriptor = methodDescriptor( fullMethodName = "protopackage.TheService/TheMethod", requestCodec = EmptyInternal.CODEC, @@ -51,7 +45,7 @@ class JavaPackageOptionTest : GrpcProtoTest() { sampledToLocalTracing = true, ) - unaryRpc(channel.platformApi, descriptor, Empty {}) + client.unaryRpc(descriptor, Empty {}) // just reach this without an error } diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/proto/ServerInterceptorTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/proto/ServerInterceptorTest.kt new file mode 100644 index 000000000..287d40e82 --- /dev/null +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/test/proto/ServerInterceptorTest.kt @@ -0,0 +1,290 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc.test.proto + +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.flowOf +import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.toList +import kotlinx.rpc.RpcServer +import kotlinx.rpc.grpc.GrpcClient +import kotlinx.rpc.grpc.GrpcMetadata +import kotlinx.rpc.grpc.ServerCallScope +import kotlinx.rpc.grpc.ServerInterceptor +import kotlinx.rpc.grpc.Status +import kotlinx.rpc.grpc.StatusCode +import kotlinx.rpc.grpc.StatusException +import kotlinx.rpc.grpc.statusCode +import kotlinx.rpc.grpc.test.EchoRequest +import kotlinx.rpc.grpc.test.EchoResponse +import kotlinx.rpc.grpc.test.EchoService +import kotlinx.rpc.grpc.test.EchoServiceImpl +import kotlinx.rpc.grpc.test.invoke +import kotlinx.rpc.registerService +import kotlinx.rpc.withService +import kotlin.test.Test +import kotlin.test.assertContains +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertIs + +class ServerInterceptorTest : GrpcProtoTest() { + + override fun RpcServer.registerServices() { + registerService { EchoServiceImpl() } + } + + @Test + fun `throw during onClosing - should fail propagate the exception to the server root`() { + val error = assertFailsWith { + val interceptor = interceptor { + onClose { _, _ -> throw IllegalStateException("Illegal failing in onClose") } + proceed(it) + } + runGrpcTest(serverInterceptors = interceptor, test = ::unaryCall) + } + + assertContains(error.message!!, "Illegal failing in onClose") + // check that the error is indeed causing a server crash + assertContains(error.stackTraceToString(), "suspendServerCall") + } + + @Test + fun `throw during intercept - should fail with unknown status on client`() { + var cause: Throwable? = null + val error = assertFailsWith { + val interceptor = interceptor { + onClose { status, _ -> cause = status.getCause() } + // this exception is not propagated to the client (only as UNKNOWN status code) + throw IllegalStateException("Failing in interceptor") + } + runGrpcTest(serverInterceptors = interceptor, test = ::unaryCall) + } + + assertEquals(StatusCode.UNKNOWN, error.getStatus().statusCode) + assertIs(cause) + assertEquals("Failing in interceptor", cause?.message) + } + + + @Test + fun `close during intercept - should fail with correct status on client`() { + val error = assertFailsWith { + val interceptor = interceptor { + close(Status(StatusCode.UNAUTHENTICATED, "Close in interceptor"), GrpcMetadata()) + } + runGrpcTest(serverInterceptors = interceptor, test = ::unaryCall) + } + + assertEquals(StatusCode.UNAUTHENTICATED, error.getStatus().statusCode) + assertContains(error.getStatus().getDescription()!!, "Close in interceptor") + } + + @Test + fun `close during request flow - should fail with correct status on client`() { + val error = assertFailsWith { + val interceptor = interceptor { + proceed( + it.map { + close(Status(StatusCode.UNAUTHENTICATED, "Close in request flow"), GrpcMetadata()) + } + ) + } + runGrpcTest(serverInterceptors = interceptor, test = ::unaryCall) + } + + assertEquals(StatusCode.UNAUTHENTICATED, error.getStatus().statusCode) + assertContains(error.message!!, "Close in request flow") + } + + @Test + fun `close during response flow - should fail with correct status on client`() { + val error = assertFailsWith { + val interceptor = interceptor { + proceed(it).map { + close(Status(StatusCode.UNAUTHENTICATED, "Close in response flow"), GrpcMetadata()) + } + } + runGrpcTest(serverInterceptors = interceptor, test = ::unaryCall) + } + + assertEquals(StatusCode.UNAUTHENTICATED, error.getStatus().statusCode) + assertContains(error.message!!, "Close in response flow") + } + + @Test + fun `close during onClose - should fail with correct status on client`() { + val error = assertFailsWith { + val interceptor = interceptor { + onClose { _, _ -> close(Status(StatusCode.UNAUTHENTICATED, "Close in onClose"), GrpcMetadata()) } + proceed(it) + } + runGrpcTest(serverInterceptors = interceptor, test = ::unaryCall) + } + + assertEquals(StatusCode.UNAUTHENTICATED, error.getStatus().statusCode) + assertContains(error.message!!, "Close in onClose") + } + + @Test + fun `close in two interceptors - should fail with correct status on client`() { + val error = assertFailsWith { + val interceptor1 = interceptor { + onClose { _, _ -> close(Status(StatusCode.UNAUTHENTICATED, "[1] Close in onClose"), GrpcMetadata()) } + proceed(it) + } + val interceptor2 = interceptor { + onClose { _, _ -> close(Status(StatusCode.UNAUTHENTICATED, "[2] Close in onClose"), GrpcMetadata()) } + proceed(it) + } + runGrpcTest(serverInterceptors = interceptor1 + interceptor2, test = ::unaryCall) + } + + assertEquals(StatusCode.UNAUTHENTICATED, error.getStatus().statusCode) + assertContains(error.message!!, "[1] Close in onClose") + } + + @Test + fun `dont proceed and return custom message - should succeed on client`() { + val interceptor = interceptor { + flowOf(EchoResponse { message = "Custom message" }) + } + runGrpcTest(serverInterceptors = interceptor) { + val service = it.withService() + val response = service.UnaryEcho(EchoRequest { message = "Hello" }) + assertEquals("Custom message", response.message) + } + } + + @Test + fun `manipulate request - should succeed on client`() { + val interceptor = interceptor { + proceed(it.map { EchoRequest { message = "Modified" } }) + } + runGrpcTest(serverInterceptors = interceptor) { + val service = it.withService() + val response = service.UnaryEcho(EchoRequest { message = "Hello" }) + assertEquals("Modified", response.message) + } + } + + @Test + fun `manipulate response - should succeed on client`() { + val interceptor = interceptor { + proceed(it).map { EchoResponse { message = "Modified" } } + } + runGrpcTest(serverInterceptors = interceptor) { + val service = it.withService() + val response = service.UnaryEcho(EchoRequest { message = "Hello" }) + assertEquals("Modified", response.message) + } + } + + @Test + fun `proceedFlow - should succeed on client`() { + val interceptor = interceptor { + flow { + proceedUnmodified(it) + } + } + runGrpcTest(serverInterceptors = interceptor, test = ::unaryCall) + } + + @Test + fun `test exact order of interceptor execution`() { + val order = mutableListOf() + val interceptor1 = interceptor { request -> + flow { + order.add(1) + var i1 = 0 + val ids = listOf(3, 7) + val req = request.map { order.add(ids[i1++]); it } + + var i2 = 0 + val respIds = listOf(6, 10) + proceed(req).collect { + order.add(respIds[i2++]) + emit(it) + } + + order.add(12) + } + } + + val interceptor2 = interceptor { request -> + flow { + order.add(2) + var i1 = 0 + val reqIds = listOf(4, 8) + val req = request.map { order.add(reqIds[i1++]); it } + + var i2 = 0 + val respIds = listOf(5, 9) + proceed(req).collect { + order.add(respIds[i2++]) + emit(it) + } + + order.add(11) + } + } + val both = interceptor1 + interceptor2 + + runGrpcTest(serverInterceptors = both) { bidiStream(it, 2) } + + assertEquals( + listOf(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12), + order + ) + } + + @Test + fun `method descriptor - full method name is exposed`() { + var methodName: String? = null + val interceptor = interceptor { + methodName = method.getFullMethodName() + proceed(it) + } + runGrpcTest(serverInterceptors = interceptor, test = ::unaryCall) + assertContains(methodName!!, "EchoService/UnaryEcho") + } + + private suspend fun unaryCall(grpcClient: GrpcClient) { + val service = grpcClient.withService() + val response = service.UnaryEcho(EchoRequest { message = "Hello" }) + assertEquals("Hello", response.message) + } + + private suspend fun bidiStream(grpcClient: GrpcClient, count: Int = 5) { + val service = grpcClient.withService() + val responses = service.BidirectionalStreamingEcho(flow { + repeat(count) { + emit(EchoRequest { message = "Echo-$it" }) + } + }).toList() + assertEquals(count, responses.size) + repeat(count) { + assertEquals("Echo-$it", responses[it].message) + } + } + +} + + +private fun interceptor( + block: ServerCallScope.(Flow) -> Flow, +): List { + return listOf(object : ServerInterceptor { + @Suppress("UNCHECKED_CAST") + override fun ServerCallScope.intercept( + request: Flow, + ): Flow { + with(this as ServerCallScope) { + return block(request as Flow) as Flow + } + } + }) +} \ No newline at end of file diff --git a/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/GrpcTrailers.jvm.kt b/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/GrpcMetadata.jvm.kt similarity index 79% rename from grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/GrpcTrailers.jvm.kt rename to grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/GrpcMetadata.jvm.kt index 090e3c718..fbb286a8c 100644 --- a/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/GrpcTrailers.jvm.kt +++ b/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/GrpcMetadata.jvm.kt @@ -7,4 +7,4 @@ package kotlinx.rpc.grpc import kotlinx.rpc.internal.utils.InternalRpcApi @InternalRpcApi -public actual typealias GrpcTrailers = io.grpc.Metadata +public actual typealias GrpcMetadata = io.grpc.Metadata diff --git a/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/credentials.jvm.kt b/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/credentials.jvm.kt index 3ba8e4ce4..12d6213cd 100644 --- a/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/credentials.jvm.kt +++ b/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/credentials.jvm.kt @@ -8,14 +8,23 @@ public actual typealias ClientCredentials = io.grpc.ChannelCredentials public actual typealias ServerCredentials = io.grpc.ServerCredentials -// we need a wrapper for InsecureChannelCredentials as our constructor would conflict with the private -// java constructor. public actual typealias InsecureClientCredentials = io.grpc.InsecureChannelCredentials public actual typealias InsecureServerCredentials = io.grpc.InsecureServerCredentials public actual typealias TlsClientCredentials = io.grpc.TlsChannelCredentials public actual typealias TlsServerCredentials = io.grpc.TlsServerCredentials + +// we need a wrapper for InsecureChannelCredentials as our constructor would conflict with the private +// java constructor. +internal actual fun createInsecureClientCredentials(): ClientCredentials { + return InsecureClientCredentials.create() +} + +internal actual fun createInsecureServerCredentials(): ServerCredentials { + return InsecureServerCredentials.create() +} + internal actual fun TlsClientCredentialsBuilder(): TlsClientCredentialsBuilder = JvmTlsCLientCredentialBuilder() internal actual fun TlsServerCredentialsBuilder( certChain: String, @@ -80,6 +89,4 @@ private fun TlsClientAuth.toJava(): io.grpc.TlsServerCredentials.ClientAuth = wh TlsClientAuth.NONE -> io.grpc.TlsServerCredentials.ClientAuth.NONE TlsClientAuth.OPTIONAL -> io.grpc.TlsServerCredentials.ClientAuth.OPTIONAL TlsClientAuth.REQUIRE -> io.grpc.TlsServerCredentials.ClientAuth.REQUIRE -} - - +} \ No newline at end of file diff --git a/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/internal/ClientCall.jvm.kt b/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/internal/ClientCall.jvm.kt index bd3dde915..010091179 100644 --- a/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/internal/ClientCall.jvm.kt +++ b/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/internal/ClientCall.jvm.kt @@ -6,7 +6,7 @@ package kotlinx.rpc.grpc.internal import io.grpc.Metadata import io.grpc.ClientCall -import kotlinx.rpc.grpc.GrpcTrailers +import kotlinx.rpc.grpc.GrpcMetadata import kotlinx.rpc.grpc.Status import kotlinx.rpc.internal.utils.InternalRpcApi @@ -14,9 +14,9 @@ internal actual typealias ClientCall = ClientCall clientCallListener( - crossinline onHeaders: (headers: GrpcTrailers) -> Unit, + crossinline onHeaders: (headers: GrpcMetadata) -> Unit, crossinline onMessage: (message: Message) -> Unit, - crossinline onClose: (status: Status, trailers: GrpcTrailers) -> Unit, + crossinline onClose: (status: Status, trailers: GrpcMetadata) -> Unit, crossinline onReady: () -> Unit, ): ClientCall.Listener { return object : ClientCall.Listener() { diff --git a/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/internal/GrpcContext.kt b/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/internal/GrpcContext.kt index 4309b1c6f..1a65baf6e 100644 --- a/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/internal/GrpcContext.kt +++ b/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/internal/GrpcContext.kt @@ -8,12 +8,12 @@ import io.grpc.Context import kotlinx.coroutines.ThreadContextElement import kotlin.coroutines.CoroutineContext -internal actual typealias GrpcContext = Context +public actual typealias GrpcContext = Context internal actual val CurrentGrpcContext: GrpcContext get() = GrpcContext.current() -internal actual class GrpcContextElement(private val grpcContext: GrpcContext) : ThreadContextElement { +internal actual class GrpcContextElement(actual val grpcContext: GrpcContext) : ThreadContextElement { actual companion object Key : CoroutineContext.Key { actual fun current(): GrpcContextElement = GrpcContextElement(CurrentGrpcContext) } diff --git a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/GrpcTrailers.native.kt b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/GrpcMetadata.native.kt similarity index 65% rename from grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/GrpcTrailers.native.kt rename to grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/GrpcMetadata.native.kt index 9cfb2249d..91d154c77 100644 --- a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/GrpcTrailers.native.kt +++ b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/GrpcMetadata.native.kt @@ -5,6 +5,6 @@ package kotlinx.rpc.grpc @Suppress(names = ["RedundantConstructorKeyword"]) -public actual class GrpcTrailers actual constructor() { - public actual fun merge(trailers: GrpcTrailers) {} +public actual class GrpcMetadata actual constructor() { + public actual fun merge(trailers: GrpcMetadata) {} } diff --git a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/ManagedChannel.native.kt b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/ManagedChannel.native.kt index 85a5439e6..4f8a27de7 100644 --- a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/ManagedChannel.native.kt +++ b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/ManagedChannel.native.kt @@ -19,10 +19,6 @@ public actual abstract class ManagedChannelPlatform : GrpcChannel() * Builder class for [ManagedChannel]. */ public actual abstract class ManagedChannelBuilder> { - public actual open fun usePlaintext(): T { - error("Builder does not support usePlaintext()") - } - public actual abstract fun overrideAuthority(authority: String): T } @@ -33,11 +29,6 @@ internal class NativeManagedChannelBuilder( private var authority: String? = null - override fun usePlaintext(): NativeManagedChannelBuilder { - credentials = lazy { InsecureChannelCredentials() } - return this - } - override fun overrideAuthority(authority: String): NativeManagedChannelBuilder { this.authority = authority return this diff --git a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/Server.native.kt b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/Server.native.kt index 476214555..05af2c8cf 100644 --- a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/Server.native.kt +++ b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/Server.native.kt @@ -43,7 +43,7 @@ private class NativeServerBuilder( } internal actual fun ServerBuilder(port: Int, credentials: ServerCredentials?): ServerBuilder<*> { - return NativeServerBuilder(port, credentials ?: InsecureServerCredentials()) + return NativeServerBuilder(port, credentials ?: createInsecureServerCredentials()) } internal actual fun Server(builder: ServerBuilder<*>): Server { diff --git a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/Status.native.kt b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/Status.native.kt index 1f14bb606..b99808810 100644 --- a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/Status.native.kt +++ b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/Status.native.kt @@ -12,6 +12,10 @@ public actual class Status internal constructor( public actual fun getDescription(): String? = description public actual fun getCause(): Throwable? = cause + + override fun toString(): String { + return "Status(description=$description, statusCode=$statusCode, cause=$cause)" + } } public actual val Status.statusCode: StatusCode diff --git a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/StatusException.native.kt b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/StatusException.native.kt index e319178e4..f5fe62187 100644 --- a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/StatusException.native.kt +++ b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/StatusException.native.kt @@ -6,11 +6,11 @@ package kotlinx.rpc.grpc public actual class StatusException : Exception { private val status: Status - private val trailers: GrpcTrailers? + private val trailers: GrpcMetadata? public actual constructor(status: Status) : this(status, null) - public actual constructor(status: Status, trailers: GrpcTrailers?) : super( + public actual constructor(status: Status, trailers: GrpcMetadata?) : super( "${status.statusCode}: ${status.getDescription()}", status.getCause() ) { @@ -20,16 +20,16 @@ public actual class StatusException : Exception { public actual fun getStatus(): Status = status - public actual fun getTrailers(): GrpcTrailers? = trailers + public actual fun getTrailers(): GrpcMetadata? = trailers } public actual class StatusRuntimeException : RuntimeException { private val status: Status - private val trailers: GrpcTrailers? + private val trailers: GrpcMetadata? public actual constructor(status: Status) : this(status, null) - public actual constructor(status: Status, trailers: GrpcTrailers?) : super( + public actual constructor(status: Status, trailers: GrpcMetadata?) : super( "${status.statusCode}: ${status.getDescription()}", status.getCause() ) { @@ -39,5 +39,5 @@ public actual class StatusRuntimeException : RuntimeException { public actual fun getStatus(): Status = status - public actual fun getTrailers(): GrpcTrailers? = trailers + public actual fun getTrailers(): GrpcMetadata? = trailers } diff --git a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/credentials.native.kt b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/credentials.native.kt index b10e7f65b..77f9e3082 100644 --- a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/credentials.native.kt +++ b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/credentials.native.kt @@ -61,18 +61,17 @@ public actual class TlsClientCredentials internal constructor( raw: CPointer, ) : ClientCredentials(raw) -public actual class TlsServerCredentials( +public actual class TlsServerCredentials internal constructor( raw: CPointer, ) : ServerCredentials(raw) - -public fun InsecureChannelCredentials(): ClientCredentials { +internal actual fun createInsecureClientCredentials(): ClientCredentials { return InsecureClientCredentials( grpc_insecure_credentials_create() ?: error("grpc_insecure_credentials_create() returned null") ) } -public fun InsecureServerCredentials(): ServerCredentials { +internal actual fun createInsecureServerCredentials(): ServerCredentials { return InsecureServerCredentials( grpc_insecure_server_credentials_create() ?: error("grpc_insecure_server_credentials_create() returned null") ) diff --git a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/ClientCall.native.kt b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/ClientCall.native.kt index f650f9e17..ea8c3ed46 100644 --- a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/ClientCall.native.kt +++ b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/ClientCall.native.kt @@ -7,7 +7,7 @@ package kotlinx.rpc.grpc.internal import kotlinx.cinterop.ExperimentalForeignApi -import kotlinx.rpc.grpc.GrpcTrailers +import kotlinx.rpc.grpc.GrpcMetadata import kotlinx.rpc.grpc.Status import kotlinx.rpc.internal.utils.InternalRpcApi import kotlin.experimental.ExperimentalNativeApi @@ -16,7 +16,7 @@ import kotlin.experimental.ExperimentalNativeApi public actual abstract class ClientCall { public actual abstract fun start( responseListener: Listener, - headers: GrpcTrailers, + headers: GrpcMetadata, ) public actual abstract fun request(numMessages: Int) @@ -30,10 +30,10 @@ public actual abstract class ClientCall { @InternalRpcApi public actual abstract class Listener { - public actual open fun onHeaders(headers: GrpcTrailers) { + public actual open fun onHeaders(headers: GrpcMetadata) { } - public actual open fun onClose(status: Status, trailers: GrpcTrailers) { + public actual open fun onClose(status: Status, trailers: GrpcMetadata) { } public actual open fun onMessage(message: Message) { @@ -46,13 +46,13 @@ public actual abstract class ClientCall { @InternalRpcApi public actual fun clientCallListener( - onHeaders: (headers: GrpcTrailers) -> Unit, + onHeaders: (headers: GrpcMetadata) -> Unit, onMessage: (message: Message) -> Unit, - onClose: (status: Status, trailers: GrpcTrailers) -> Unit, + onClose: (status: Status, trailers: GrpcMetadata) -> Unit, onReady: () -> Unit, ): ClientCall.Listener { return object : ClientCall.Listener() { - override fun onHeaders(headers: GrpcTrailers) { + override fun onHeaders(headers: GrpcMetadata) { onHeaders(headers) } @@ -60,7 +60,7 @@ public actual fun clientCallListener( onMessage(message) } - override fun onClose(status: Status, trailers: GrpcTrailers) { + override fun onClose(status: Status, trailers: GrpcMetadata) { onClose(status, trailers) } diff --git a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/CompletionQueue.kt b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/CompletionQueue.kt index 329c45ae1..59b6cbcd3 100644 --- a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/CompletionQueue.kt +++ b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/CompletionQueue.kt @@ -119,6 +119,8 @@ internal class CompletionQueue { * See [BatchResult] for possible outcomes. */ fun runBatch(call: CPointer, ops: CPointer, nOps: ULong): BatchResult { + if (_shutdownDone.isCompleted) return BatchResult.CQShutdown + val completion = CallbackFuture() val tag = newCbTag(completion, OPS_COMPLETE_CB) @@ -194,8 +196,8 @@ private fun opsCompleteCb(functor: CPointer?, ok: private fun shutdownCb(functor: CPointer?, ok: Int) { val tag = functor!!.reinterpret() val cq = tag.pointed.user_data!!.asStableRef().get() - cq._shutdownDone.complete(Unit) cq._state.value = CompletionQueue.State.CLOSED + cq._shutdownDone.complete(Unit) grpc_completion_queue_destroy(cq.raw) } diff --git a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/GrpcContext.native.kt b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/GrpcContext.native.kt index dd60f03e1..9b275708c 100644 --- a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/GrpcContext.native.kt +++ b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/GrpcContext.native.kt @@ -6,20 +6,20 @@ package kotlinx.rpc.grpc.internal import kotlin.coroutines.CoroutineContext -internal actual class GrpcContext +public actual class GrpcContext private val currentGrpcContext = GrpcContext() internal actual val CurrentGrpcContext: GrpcContext get() = currentGrpcContext -internal actual class GrpcContextElement : CoroutineContext.Element { +internal actual class GrpcContextElement(actual val grpcContext: GrpcContext) : CoroutineContext.Element { actual override val key: CoroutineContext.Key get() = Key actual companion object Key : CoroutineContext.Key { actual fun current(): GrpcContextElement { - return GrpcContextElement() + return GrpcContextElement(currentGrpcContext) } } } diff --git a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/NativeClientCall.kt b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/NativeClientCall.kt index 69c637b32..9a02dd6e3 100644 --- a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/NativeClientCall.kt +++ b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/NativeClientCall.kt @@ -23,7 +23,7 @@ import kotlinx.cinterop.toKString import kotlinx.cinterop.value import kotlinx.coroutines.CancellationException import kotlinx.coroutines.CompletableJob -import kotlinx.rpc.grpc.GrpcTrailers +import kotlinx.rpc.grpc.GrpcMetadata import kotlinx.rpc.grpc.Status import kotlinx.rpc.grpc.StatusCode import kotlinx.rpc.protobuf.input.stream.asInputStream @@ -81,7 +81,7 @@ internal class NativeClientCall( private var listener: Listener? = null private var halfClosed = false private var cancelled = false - private var closed = atomic(false) + private val closed = atomic(false) // tracks how many operations are in flight (not yet completed by the listener). // if 0 and we got a closeInfo (containing the status), there are no more ongoing operations. @@ -92,7 +92,7 @@ internal class NativeClientCall( // holds the received status information returned by the RECV_STATUS_ON_CLIENT batch. // if null, the call is still in progress. otherwise, the call can be closed as soon as inFlight is 0. - private val closeInfo = atomic?>(null) + private val closeInfo = atomic?>(null) // we currently don't buffer messages, so after one `sendMessage` call, ready turns false. (KRPC-192) private val ready = atomic(true) @@ -133,7 +133,9 @@ internal class NativeClientCall( val lst = checkNotNull(listener) { internalError("Not yet started") } // allows the managed channel to join for the call to finish. callJob.complete() - lst.onClose(info.first, info.second) + safeUserCode("Failed to call onClose.") { + lst.onClose(info.first, info.second) + } } } @@ -141,10 +143,9 @@ internal class NativeClientCall( * Sets the [closeInfo] and calls [tryToCloseCall]. * This is called as soon as the RECV_STATUS_ON_CLIENT batch (started with [startRecvStatus]) finished. */ - private fun markClosePending(status: Status, trailers: GrpcTrailers) { - if (closeInfo.compareAndSet(null, Pair(status, trailers))) { - tryToCloseCall() - } + private fun markClosePending(status: Status, trailers: GrpcMetadata) { + closeInfo.compareAndSet(null, Pair(status, trailers)) + tryToCloseCall() } /** @@ -153,17 +154,18 @@ internal class NativeClientCall( */ private fun turnReady() { if (ready.compareAndSet(expect = false, update = true)) { - listener?.onReady() + safeUserCode("Failed to call onReady.") { + listener?.onReady() + } } } override fun start( responseListener: Listener, - headers: GrpcTrailers, + headers: GrpcMetadata, ) { check(listener == null) { internalError("Already started") } - check(!cancelled) { internalError("Already cancelled.") } listener = responseListener @@ -198,6 +200,7 @@ internal class NativeClientCall( callResult.future.onComplete { success -> try { if (success) { + // if the batch doesn't succeed, this is reflected in the recv status op batch. onSuccess() } } finally { @@ -254,8 +257,9 @@ internal class NativeClientCall( is BatchResult.Submitted -> { callResult.future.onComplete { val details = statusDetails.toByteArray().toKString() - val status = Status(statusCode.value.toKotlin(), details, null) - val trailers = GrpcTrailers() + val kStatusCode = statusCode.value.toKotlin() + val status = Status(kStatusCode, details, null) + val trailers = GrpcMetadata() // cleanup grpc_slice_unref(statusDetails.readValue()) @@ -270,7 +274,7 @@ internal class NativeClientCall( BatchResult.CQShutdown -> { arena.clear() - markClosePending(Status(StatusCode.UNAVAILABLE, "Channel shutdown"), GrpcTrailers()) + markClosePending(Status(StatusCode.UNAVAILABLE, "Channel shutdown"), GrpcMetadata()) return false } @@ -278,7 +282,7 @@ internal class NativeClientCall( arena.clear() markClosePending( Status(StatusCode.INTERNAL, "Failed to start call: ${callResult.error}"), - GrpcTrailers() + GrpcMetadata() ) return false } @@ -306,7 +310,9 @@ internal class NativeClientCall( grpc_metadata_array_destroy(meta.ptr) arena.clear() }) { - // TODO: Send headers to listener + safeUserCode("Failed to call onHeaders.") { + listener?.onHeaders(GrpcMetadata()) + } } } @@ -319,7 +325,10 @@ internal class NativeClientCall( // limit numMessages to prevent potential stack overflows check(numMessages <= 16) { internalError("numMessages must be <= 16") } val listener = checkNotNull(listener) { internalError("Not yet started") } - check(!cancelled) { internalError("Already cancelled") } + if (cancelled) { + // no need to send message if the call got already cancelled. + return + } var remainingMessages = numMessages @@ -342,7 +351,9 @@ internal class NativeClientCall( val buf = recvPtr.value ?: return@runBatch val msg = methodDescriptor.getResponseMarshaller() .parse(buf.toKotlin().asInputStream()) - listener.onMessage(msg) + safeUserCode("Failed to call onClose.") { + listener.onMessage(msg) + } post() } } @@ -353,20 +364,26 @@ internal class NativeClientCall( override fun cancel(message: String?, cause: Throwable?) { cancelled = true - val message = if (cause != null) "$message: ${cause.message}" else message - cancelInternal(grpc_status_code.GRPC_STATUS_CANCELLED, message ?: "Call cancelled") + val status = Status(StatusCode.CANCELLED, message ?: "Call cancelled", cause) + // user side cancellation must always win over any other status (even if the call is already completed). + // this will also preserve the cancellation cause, which cannot be passed to the grpc-core. + closeInfo.value = Pair(status, GrpcMetadata()) + cancelInternal( + grpc_status_code.GRPC_STATUS_CANCELLED, + message ?: "Call cancelled with cause: ${cause?.message}" + ) } private fun cancelInternal(statusCode: grpc_status_code, message: String) { val cancelResult = grpc_call_cancel_with_status(raw, statusCode, message, null) if (cancelResult != grpc_call_error.GRPC_CALL_OK) { - markClosePending(Status(StatusCode.INTERNAL, "Failed to cancel call: $cancelResult"), GrpcTrailers()) + markClosePending(Status(StatusCode.INTERNAL, "Failed to cancel call: $cancelResult"), GrpcMetadata()) } } override fun halfClose() { check(!halfClosed) { internalError("Already half closed.") } - check(!cancelled) { internalError("Already cancelled.") } + if (cancelled) return halfClosed = true val arena = Arena() @@ -384,9 +401,10 @@ internal class NativeClientCall( override fun sendMessage(message: Request) { checkNotNull(listener) { internalError("Not yet started") } check(!halfClosed) { internalError("Already half closed.") } - check(!cancelled) { internalError("Already cancelled.") } check(isReady()) { internalError("Not yet ready.") } + if (cancelled) return + // set ready false, as only one message can be sent at a time. ready.value = false @@ -408,6 +426,18 @@ internal class NativeClientCall( turnReady() } } + + /** + * Safely executes the provided block of user code, catching any thrown exceptions or errors. + * If an exception is caught, it cancels the operation with the specified message and cause. + */ + private fun safeUserCode(cancelMsg: String, block: () -> Unit) { + try { + block() + } catch (e: Throwable) { + cancel(cancelMsg, e) + } + } } diff --git a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/NativeServerCall.kt b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/NativeServerCall.kt index c9d419f1f..00c58b72e 100644 --- a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/NativeServerCall.kt +++ b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/NativeServerCall.kt @@ -16,11 +16,15 @@ import kotlinx.cinterop.CPointerVar import kotlinx.cinterop.ExperimentalForeignApi import kotlinx.cinterop.IntVar import kotlinx.cinterop.alloc +import kotlinx.cinterop.allocArray +import kotlinx.cinterop.convert +import kotlinx.cinterop.get import kotlinx.cinterop.ptr -import kotlinx.cinterop.readValue import kotlinx.cinterop.value -import kotlinx.rpc.grpc.GrpcTrailers +import kotlinx.rpc.grpc.GrpcMetadata import kotlinx.rpc.grpc.Status +import kotlinx.rpc.grpc.StatusCode +import kotlinx.rpc.grpc.StatusException import kotlinx.rpc.protobuf.input.stream.asInputStream import kotlinx.rpc.protobuf.input.stream.asSource import libkgrpc.GRPC_OP_RECV_CLOSE_ON_SERVER @@ -33,7 +37,6 @@ import libkgrpc.grpc_byte_buffer_destroy import libkgrpc.grpc_call_cancel_with_status import libkgrpc.grpc_call_unref import libkgrpc.grpc_op -import libkgrpc.grpc_slice import libkgrpc.grpc_slice_unref import libkgrpc.grpc_status_code import kotlin.concurrent.Volatile @@ -64,9 +67,15 @@ internal class NativeServerCall( private val callbackMutex = ReentrantLock() private var initialized = false private var cancelled = false + private var closed = false private val finalized = atomic(false) - // Tracks whether at least one request message has been received on this call. + // tracks whether the initial metadata has been sent. + // this is used to determine if we have to send the initial metadata + // when we try to close the call. + private var sentInitialMetadata = false + + // tracks whether at least one request message has been received on this call. private var receivedFirstMessage = false // we currently don't buffer messages, so after one `sendMessage` call, ready turns false. (KRPC-192) @@ -135,6 +144,7 @@ internal class NativeServerCall( } fun cancel(status: grpc_status_code, message: String) { + cancelled = true grpc_call_cancel_with_status(raw, status, message, null) } @@ -160,6 +170,9 @@ internal class NativeServerCall( cleanup: () -> Unit = {}, onSuccess: () -> Unit = {}, ) { + // if we are already closed, we cannot run any more batches. + if (closed || cancelled) return cleanup() + when (val result = cq.runBatch(raw, ops, nOps)) { is BatchResult.Submitted -> { result.future.onComplete { @@ -235,7 +248,7 @@ internal class NativeServerCall( } } - override fun sendHeaders(headers: GrpcTrailers) { + override fun sendHeaders(headers: GrpcMetadata) { check(initialized) { internalError("Call not initialized") } val arena = Arena() // TODO: Implement header metadata operation @@ -245,6 +258,7 @@ internal class NativeServerCall( data.send_initial_metadata.metadata = null } + sentInitialMetadata = true runBatch(op.ptr, 1u, cleanup = { arena.clear() }) { // nothing to do here } @@ -256,44 +270,56 @@ internal class NativeServerCall( val methodDescriptor = checkNotNull(methodDescriptor) { internalError("Method descriptor not set") } val arena = Arena() - val inputStream = methodDescriptor.getResponseMarshaller().stream(message) - val byteBuffer = inputStream.asSource().toGrpcByteBuffer() - ready.value = false - - val op = arena.alloc { - op = GRPC_OP_SEND_MESSAGE - data.send_message.send_message = byteBuffer - } + tryRun { + val inputStream = methodDescriptor.getResponseMarshaller().stream(message) + val byteBuffer = inputStream.asSource().toGrpcByteBuffer() + ready.value = false + + val op = arena.alloc { + op = GRPC_OP_SEND_MESSAGE + data.send_message.send_message = byteBuffer + } - runBatch(op.ptr, 1u, cleanup = { - arena.clear() - grpc_byte_buffer_destroy(byteBuffer) - }) { - turnReady() + runBatch(op.ptr, 1u, cleanup = { + arena.clear() + grpc_byte_buffer_destroy(byteBuffer) + }) { + turnReady() + } } } - override fun close(status: Status, trailers: GrpcTrailers) { + override fun close(status: Status, trailers: GrpcMetadata) { check(initialized) { internalError("Call not initialized") } + val arena = Arena() - val details = status.getDescription()?.let { - arena.alloc { - it.toGrpcSlice() - } - } - val op = arena.alloc { - op = GRPC_OP_SEND_STATUS_FROM_SERVER - data.send_status_from_server.status = status.statusCode.toRawCallAllocation() - data.send_status_from_server.status_details = details?.ptr - data.send_status_from_server.trailing_metadata_count = 0u - data.send_status_from_server.trailing_metadata = null + val details = status.getDescription()?.toGrpcSlice() + val detailsPtr = details?.getPointer(arena) + + val nOps = if (sentInitialMetadata) 1uL else 2uL + + val ops = arena.allocArray(nOps.convert()) + + ops[0].op = GRPC_OP_SEND_STATUS_FROM_SERVER + ops[0].data.send_status_from_server.status = status.statusCode.toRaw() + ops[0].data.send_status_from_server.status_details = detailsPtr + ops[0].data.send_status_from_server.trailing_metadata_count = 0u + ops[0].data.send_status_from_server.trailing_metadata = null + + if (!sentInitialMetadata) { + // if we haven't sent GRPC_OP_SEND_INITIAL_METADATA yet, + // so we must do it together with the close operation. + ops[1].op = GRPC_OP_SEND_INITIAL_METADATA + ops[1].data.send_initial_metadata.count = 0u + ops[1].data.send_initial_metadata.metadata = null } - runBatch(op.ptr, 1u, cleanup = { - if (details != null) grpc_slice_unref(details.readValue()) + runBatch(ops, nOps, cleanup = { + if (details != null) grpc_slice_unref(details) arena.clear() }) { + closed = true // nothing to do here } } @@ -306,6 +332,25 @@ internal class NativeServerCall( val methodDescriptor = checkNotNull(methodDescriptor) { internalError("Method descriptor not set") } return methodDescriptor } + + + private inline fun tryRun(crossinline block: () -> T): T { + try { + return block() + } catch (e: Throwable) { + // TODO: Log internal error as warning + val status = when (e) { + is StatusException -> e.getStatus() + else -> Status( + StatusCode.INTERNAL, + description = "Internal error, so canceling the stream", + cause = e + ) + } + cancel(status.statusCode.toRaw(), status.getDescription() ?: "Unknown error") + throw StatusException(status, trailers = null) + } + } } /** diff --git a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/ServerCall.native.kt b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/ServerCall.native.kt index bda5a17be..e57f2eb96 100644 --- a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/ServerCall.native.kt +++ b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/ServerCall.native.kt @@ -4,7 +4,7 @@ package kotlinx.rpc.grpc.internal -import kotlinx.rpc.grpc.GrpcTrailers +import kotlinx.rpc.grpc.GrpcMetadata import kotlinx.rpc.grpc.Status import kotlinx.rpc.internal.utils.InternalRpcApi @@ -12,16 +12,16 @@ import kotlinx.rpc.internal.utils.InternalRpcApi public actual fun interface ServerCallHandler { public actual fun startCall( call: ServerCall, - headers: GrpcTrailers, + headers: GrpcMetadata, ): ServerCall.Listener } @InternalRpcApi public actual abstract class ServerCall { public actual abstract fun request(numMessages: Int) - public actual abstract fun sendHeaders(headers: GrpcTrailers) + public actual abstract fun sendHeaders(headers: GrpcMetadata) public actual abstract fun sendMessage(message: Response) - public actual abstract fun close(status: Status, trailers: GrpcTrailers) + public actual abstract fun close(status: Status, trailers: GrpcMetadata) public actual open fun isReady(): Boolean { // Default implementation returns true - subclasses can override if they need flow control diff --git a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/serverCallTags.kt b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/serverCallTags.kt index f22bd1f0e..ef24ff71e 100644 --- a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/serverCallTags.kt +++ b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/serverCallTags.kt @@ -15,7 +15,7 @@ import kotlinx.cinterop.alloc import kotlinx.cinterop.cValue import kotlinx.cinterop.ptr import kotlinx.cinterop.value -import kotlinx.rpc.grpc.GrpcTrailers +import kotlinx.rpc.grpc.GrpcMetadata import kotlinx.rpc.grpc.HandlerRegistry import libkgrpc.gpr_timespec import libkgrpc.grpc_call_details @@ -62,7 +62,7 @@ internal class RegisteredServerCallTag( // ownership of the core call is transferred to the NativeServerCall. val call = NativeServerCall(rawCall.value!!, cq, method.getMethodDescriptor()) // TODO: Turn metadata into a kotlin GrpcTrailers. - val trailers = GrpcTrailers() + val trailers = GrpcMetadata() // start the actual call. val listener = method.getServerCallHandler().startCall(call, trailers) call.setListener(listener) @@ -141,7 +141,7 @@ internal class LookupServerCallTag( definition.getMethodDescriptor() as MethodDescriptor ) // TODO: Turn metadata into a kotlin GrpcTrailers. - val metadata = GrpcTrailers() + val metadata = GrpcMetadata() val listener = callHandler.startCall(call, metadata) call.setListener(listener) } diff --git a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/utils.kt b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/utils.kt index c9f8ebc1e..289b5b2e5 100644 --- a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/utils.kt +++ b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/internal/utils.kt @@ -193,7 +193,7 @@ internal fun grpc_status_code.toKotlin(): StatusCode = when (this) { else -> error("Invalid status code: $this") } -internal fun StatusCode.toRawCallAllocation(): grpc_status_code = when (this) { +internal fun StatusCode.toRaw(): grpc_status_code = when (this) { StatusCode.OK -> grpc_status_code.GRPC_STATUS_OK StatusCode.CANCELLED -> grpc_status_code.GRPC_STATUS_CANCELLED StatusCode.UNKNOWN -> grpc_status_code.GRPC_STATUS_UNKNOWN diff --git a/grpc/grpc-ktor-server/src/commonMain/kotlin/kotlinx/rpc/grpc/ktor/server/Server.kt b/grpc/grpc-ktor-server/src/commonMain/kotlin/kotlinx/rpc/grpc/ktor/server/Server.kt index 6642885dc..339d79af6 100644 --- a/grpc/grpc-ktor-server/src/commonMain/kotlin/kotlinx/rpc/grpc/ktor/server/Server.kt +++ b/grpc/grpc-ktor-server/src/commonMain/kotlin/kotlinx/rpc/grpc/ktor/server/Server.kt @@ -4,17 +4,13 @@ package kotlinx.rpc.grpc.ktor.server -import io.ktor.server.application.Application -import io.ktor.server.application.ApplicationStopped -import io.ktor.server.application.ApplicationStopping -import io.ktor.server.application.log -import io.ktor.server.config.getAs -import io.ktor.util.AttributeKey +import io.ktor.server.application.* +import io.ktor.server.config.* +import io.ktor.util.* import kotlinx.rpc.RpcServer import kotlinx.rpc.grpc.GrpcServer +import kotlinx.rpc.grpc.GrpcServerConfiguration import kotlinx.rpc.grpc.ServerBuilder -import kotlinx.rpc.grpc.codec.EmptyMessageCodecResolver -import kotlinx.rpc.grpc.codec.MessageCodecResolver @Suppress("ConstPropertyName") public object GrpcConfigKeys { @@ -51,8 +47,7 @@ public val GrpcServerKey: AttributeKey = AttributeKey("G */ public fun Application.grpc( port: Int = environment.config.propertyOrNull(GrpcConfigKeys.grpcHostPortPath)?.getAs() ?: 8001, - messageCodecResolver: MessageCodecResolver = EmptyMessageCodecResolver, - configure: ServerBuilder<*>.() -> Unit = {}, + configure: GrpcServerConfiguration.() -> Unit = {}, builder: RpcServer.() -> Unit, ): GrpcServer { if (attributes.contains(GrpcServerKey)) { @@ -64,7 +59,6 @@ public fun Application.grpc( newServer = true GrpcServer( port = port, - messageCodecResolver = messageCodecResolver, parentContext = coroutineContext, configure = configure, builder = builder, diff --git a/grpc/grpc-ktor-server/src/jvmTest/kotlin/kotlinx/rpc/grpc/ktor/server/test/TestServer.kt b/grpc/grpc-ktor-server/src/jvmTest/kotlin/kotlinx/rpc/grpc/ktor/server/test/TestServer.kt index f299fd649..6482d41de 100644 --- a/grpc/grpc-ktor-server/src/jvmTest/kotlin/kotlinx/rpc/grpc/ktor/server/test/TestServer.kt +++ b/grpc/grpc-ktor-server/src/jvmTest/kotlin/kotlinx/rpc/grpc/ktor/server/test/TestServer.kt @@ -4,12 +4,11 @@ package kotlinx.rpc.grpc.ktor.server.test -import io.ktor.server.testing.testApplication +import io.ktor.server.testing.* import kotlinx.rpc.grpc.GrpcClient -import kotlin.test.Test import kotlinx.rpc.grpc.ktor.server.grpc -import kotlinx.rpc.registerService import kotlinx.rpc.withService +import kotlin.test.Test import kotlin.test.assertEquals import kotlin.time.Duration.Companion.minutes @@ -33,7 +32,7 @@ class TestServer { startApplication() val client = GrpcClient("localhost", PORT) { - usePlaintext() + credentials = plaintext() } val response = client.withService().sayHello(Hello { message = "Hello" })