diff --git a/Sources/GRPCCore/Call/Server/RPCRouter.swift b/Sources/GRPCCore/Call/Server/RPCRouter.swift index a574a1186..d40bd71c4 100644 --- a/Sources/GRPCCore/Call/Server/RPCRouter.swift +++ b/Sources/GRPCCore/Call/Server/RPCRouter.swift @@ -22,6 +22,8 @@ /// the router has a handler for a method with ``hasHandler(forMethod:)`` or get a list of all /// methods with handlers registered by calling ``methods``. You can also remove the handler for a /// given method by calling ``removeHandler(forMethod:)``. +/// You can also register any interceptors that you want applied to registered handlers via the +/// ``registerInterceptors(pipeline:)`` method. /// /// In most cases you won't need to interact with the router directly. Instead you should register /// your services with ``GRPCServer/init(transport:services:interceptors:)`` which will in turn @@ -82,7 +84,8 @@ public struct RPCRouter: Sendable { } @usableFromInline - private(set) var handlers: [MethodDescriptor: RPCHandler] + private(set) var handlers: + [MethodDescriptor: (handler: RPCHandler, interceptors: [any ServerInterceptor])] /// Creates a new router with no methods registered. public init() { @@ -126,12 +129,13 @@ public struct RPCRouter: Sendable { _ context: ServerContext ) async throws -> StreamingServerResponse ) { - self.handlers[descriptor] = RPCHandler( + let handler = RPCHandler( method: descriptor, deserializer: deserializer, serializer: serializer, handler: handler ) + self.handlers[descriptor] = (handler, []) } /// Removes any handler registered for the specified method. @@ -142,6 +146,25 @@ public struct RPCRouter: Sendable { public mutating func removeHandler(forMethod descriptor: MethodDescriptor) -> Bool { return self.handlers.removeValue(forKey: descriptor) != nil } + + /// Registers applicable interceptors to all currently-registered handlers. + /// + /// - Important: Calling this method will apply the interceptors only to existing handlers. Any handlers registered via + /// ``registerHandler(forMethod:deserializer:serializer:handler:)`` _after_ calling this method will not have + /// any interceptors applied to them. If you want to make sure all registered methods have any applicable interceptors applied, + /// only call this method _after_ you have registered all handlers. + /// - Parameter pipeline: The interceptor pipeline operations to register to all currently-registered handlers. The order of the + /// interceptors matters. + /// - SeeAlso: ``ServerInterceptorPipelineOperation``. + @inlinable + public mutating func registerInterceptors(pipeline: [ServerInterceptorPipelineOperation]) { + for descriptor in self.handlers.keys { + let applicableOperations = pipeline.filter { $0.applies(to: descriptor) } + if !applicableOperations.isEmpty { + self.handlers[descriptor]?.interceptors = applicableOperations.map { $0.interceptor } + } + } + } } extension RPCRouter { @@ -150,10 +173,9 @@ extension RPCRouter { RPCAsyncSequence, RPCWriter.Closable >, - context: ServerContext, - interceptors: [any ServerInterceptor] + context: ServerContext ) async { - if let handler = self.handlers[stream.descriptor] { + if let (handler, interceptors) = self.handlers[stream.descriptor] { await handler.handle(stream: stream, context: context, interceptors: interceptors) } else { // If this throws then the stream must be closed which we can't do anything about, so ignore diff --git a/Sources/GRPCCore/Call/Server/ServerInterceptor.swift b/Sources/GRPCCore/Call/Server/ServerInterceptor.swift index 3c71cd3e5..e90266862 100644 --- a/Sources/GRPCCore/Call/Server/ServerInterceptor.swift +++ b/Sources/GRPCCore/Call/Server/ServerInterceptor.swift @@ -21,10 +21,11 @@ /// been returned from a service. They are typically used for cross-cutting concerns like filtering /// requests, validating messages, logging additional data, and tracing. /// -/// Interceptors are registered with the server apply to all RPCs. If you need to modify the -/// behavior of an interceptor on a per-RPC basis then you can use the -/// ``ServerContext/descriptor`` to determine which RPC is being called and -/// conditionalise behavior accordingly. +/// Interceptors can be registered with the server either directly or via ``ServerInterceptorPipelineOperation``s. +/// You may register them for all services registered with a server, for RPCs directed to specific services, or +/// for RPCs directed to specific methods. If you need to modify the behavior of an interceptor on a +/// per-RPC basis in more detail, then you can use the ``ServerContext/descriptor`` to determine +/// which RPC is being called and conditionalise behavior accordingly. /// /// ## RPC filtering /// @@ -33,19 +34,19 @@ /// demonstrates this. /// /// ```swift -/// struct AuthServerInterceptor: Sendable { +/// struct AuthServerInterceptor: ServerInterceptor { /// let isAuthorized: @Sendable (String, MethodDescriptor) async throws -> Void /// /// func intercept( /// request: StreamingServerRequest, -/// context: ServerInterceptorContext, +/// context: ServerContext, /// next: @Sendable ( /// _ request: StreamingServerRequest, -/// _ context: ServerInterceptorContext +/// _ context: ServerContext /// ) async throws -> StreamingServerResponse /// ) async throws -> StreamingServerResponse { /// // Extract the auth token. -/// guard let token = request.metadata["authorization"] else { +/// guard let token = request.metadata[stringValues: "authorization"].first(where: { _ in true }) else { /// throw RPCError(code: .unauthenticated, message: "Not authenticated") /// } /// diff --git a/Sources/GRPCCore/Call/Server/ServerInterceptorPipelineOperation.swift b/Sources/GRPCCore/Call/Server/ServerInterceptorPipelineOperation.swift new file mode 100644 index 000000000..3d2731fd4 --- /dev/null +++ b/Sources/GRPCCore/Call/Server/ServerInterceptorPipelineOperation.swift @@ -0,0 +1,98 @@ +/* + * Copyright 2024, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/// A `ServerInterceptorPipelineOperation` describes to which RPCs a server interceptor should be applied. +/// +/// You can configure a server interceptor to be applied to: +/// - all RPCs and services; +/// - requests directed only to specific services registered with your server; or +/// - requests directed only to specific methods (of a specific service). +/// +/// - SeeAlso: ``ServerInterceptor`` for more information on server interceptors. +public struct ServerInterceptorPipelineOperation: Sendable { + /// The subject of a ``ServerInterceptorPipelineOperation``. + /// The subject of an interceptor can either be all services and methods, only specific services, or only specific methods. + public struct Subject: Sendable { + internal enum Wrapped: Sendable { + case all + case services(Set) + case methods(Set) + } + + private let wrapped: Wrapped + + /// An operation subject specifying an interceptor that applies to all RPCs across all services will be registered with this server. + public static var all: Self { .init(wrapped: .all) } + + /// An operation subject specifying an interceptor that will be applied only to RPCs directed to the specified services. + /// - Parameters: + /// - services: The list of service names for which this interceptor should intercept RPCs. + /// - Returns: A ``ServerInterceptorPipelineOperation``. + public static func services(_ services: Set) -> Self { + Self(wrapped: .services(services)) + } + + /// An operation subject specifying an interceptor that will be applied only to RPCs directed to the specified service methods. + /// - Parameters: + /// - methods: The list of method descriptors for which this interceptor should intercept RPCs. + /// - Returns: A ``ServerInterceptorPipelineOperation``. + public static func methods(_ methods: Set) -> Self { + Self(wrapped: .methods(methods)) + } + + @usableFromInline + internal func applies(to descriptor: MethodDescriptor) -> Bool { + switch self.wrapped { + case .all: + return true + + case .services(let services): + return services.map({ $0.fullyQualifiedService }).contains(descriptor.service) + + case .methods(let methods): + return methods.contains(descriptor) + } + } + } + + /// The interceptor specified for this operation. + public let interceptor: any ServerInterceptor + + @usableFromInline + internal let subject: Subject + + private init(interceptor: any ServerInterceptor, appliesTo: Subject) { + self.interceptor = interceptor + self.subject = appliesTo + } + + /// Create an operation, specifying which ``ServerInterceptor`` to apply and to which ``Subject``. + /// - Parameters: + /// - interceptor: The ``ServerInterceptor`` to register with the server. + /// - subject: The ``Subject`` to which the `interceptor` applies. + /// - Returns: A ``ServerInterceptorPipelineOperation``. + public static func apply(_ interceptor: any ServerInterceptor, to subject: Subject) -> Self { + Self(interceptor: interceptor, appliesTo: subject) + } + + /// Returns whether this ``ServerInterceptorPipelineOperation`` applies to the given `descriptor`. + /// - Parameter descriptor: A ``MethodDescriptor`` for which to test whether this interceptor applies. + /// - Returns: `true` if this interceptor applies to the given `descriptor`, or `false` otherwise. + @inlinable + internal func applies(to descriptor: MethodDescriptor) -> Bool { + self.subject.applies(to: descriptor) + } +} diff --git a/Sources/GRPCCore/GRPCServer.swift b/Sources/GRPCCore/GRPCServer.swift index 8ba69985d..6ff82b9dd 100644 --- a/Sources/GRPCCore/GRPCServer.swift +++ b/Sources/GRPCCore/GRPCServer.swift @@ -78,14 +78,6 @@ public final class GRPCServer: Sendable { /// The services registered which the server is serving. private let router: RPCRouter - /// A collection of ``ServerInterceptor`` implementations which are applied to all accepted - /// RPCs. - /// - /// RPCs are intercepted in the order that interceptors are added. That is, a request received - /// from the client will first be intercepted by the first added interceptor followed by the - /// second, and so on. - private let interceptors: [any ServerInterceptor] - /// The state of the server. private let state: Mutex @@ -154,33 +146,46 @@ public final class GRPCServer: Sendable { services: [any RegistrableRPCService], interceptors: [any ServerInterceptor] = [] ) { - var router = RPCRouter() - for service in services { - service.registerMethods(with: &router) - } - - self.init(transport: transport, router: router, interceptors: interceptors) + self.init( + transport: transport, + services: services, + interceptorPipeline: interceptors.map { .apply($0, to: .all) } + ) } /// Creates a new server with no resources. /// /// - Parameters: /// - transport: The transport the server should listen on. - /// - router: A ``RPCRouter`` used by the server to route accepted streams to method handlers. - /// - interceptors: A collection of interceptors providing cross-cutting functionality to each + /// - services: Services offered by the server. + /// - interceptorPipeline: A collection of interceptors providing cross-cutting functionality to each /// accepted RPC. The order in which interceptors are added reflects the order in which they /// are called. The first interceptor added will be the first interceptor to intercept each /// request. The last interceptor added will be the final interceptor to intercept each /// request before calling the appropriate handler. - public init( + public convenience init( transport: any ServerTransport, - router: RPCRouter, - interceptors: [any ServerInterceptor] = [] + services: [any RegistrableRPCService], + interceptorPipeline: [ServerInterceptorPipelineOperation] ) { + var router = RPCRouter() + for service in services { + service.registerMethods(with: &router) + } + router.registerInterceptors(pipeline: interceptorPipeline) + + self.init(transport: transport, router: router) + } + + /// Creates a new server with no resources. + /// + /// - Parameters: + /// - transport: The transport the server should listen on. + /// - router: A ``RPCRouter`` used by the server to route accepted streams to method handlers. + public init(transport: any ServerTransport, router: RPCRouter) { self.state = Mutex(.notStarted) self.transport = transport self.router = router - self.interceptors = interceptors } /// Starts the server and runs until the registered transport has closed. @@ -206,7 +211,7 @@ public final class GRPCServer: Sendable { do { try await transport.listen { stream, context in - await self.router.handle(stream: stream, context: context, interceptors: self.interceptors) + await self.router.handle(stream: stream, context: context) } } catch { throw RuntimeError( diff --git a/Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTests.swift b/Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTests.swift index 0533fe26b..fe8d301aa 100644 --- a/Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTests.swift +++ b/Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTests.swift @@ -333,7 +333,9 @@ final class ServerRPCExecutorTests: XCTestCase { func testThrowingInterceptor() async throws { let harness = ServerRPCExecutorTestHarness( - interceptors: [.throwError(RPCError(code: .unavailable, message: "Unavailable"))] + interceptors: [ + .throwError(RPCError(code: .unavailable, message: "Unavailable")) + ] ) try await harness.execute(handler: .echo) { inbound in diff --git a/Tests/GRPCCoreTests/GRPCServerTests.swift b/Tests/GRPCCoreTests/GRPCServerTests.swift index b7866c80c..9b20785d5 100644 --- a/Tests/GRPCCoreTests/GRPCServerTests.swift +++ b/Tests/GRPCCoreTests/GRPCServerTests.swift @@ -16,19 +16,20 @@ import GRPCCore import GRPCInProcessTransport +import Testing import XCTest final class GRPCServerTests: XCTestCase { func withInProcessClientConnectedToServer( services: [any RegistrableRPCService], - interceptors: [any ServerInterceptor] = [], + interceptorPipeline: [ServerInterceptorPipelineOperation] = [], _ body: (InProcessTransport.Client, GRPCServer) async throws -> Void ) async throws { let inProcess = InProcessTransport() let server = GRPCServer( transport: inProcess.server, services: services, - interceptors: interceptors + interceptorPipeline: interceptorPipeline ) try await withThrowingTaskGroup(of: Void.self) { group in @@ -219,10 +220,10 @@ final class GRPCServerTests: XCTestCase { try await self.withInProcessClientConnectedToServer( services: [BinaryEcho()], - interceptors: [ - .requestCounter(counter1), - .rejectAll(with: RPCError(code: .unavailable, message: "")), - .requestCounter(counter2), + interceptorPipeline: [ + .apply(.requestCounter(counter1), to: .all), + .apply(.rejectAll(with: RPCError(code: .unavailable, message: "")), to: .all), + .apply(.requestCounter(counter2), to: .all), ] ) { client, _ in try await client.withStream( @@ -248,7 +249,7 @@ final class GRPCServerTests: XCTestCase { try await self.withInProcessClientConnectedToServer( services: [BinaryEcho()], - interceptors: [.requestCounter(counter)] + interceptorPipeline: [.apply(.requestCounter(counter), to: .all)] ) { client, _ in try await client.withStream( descriptor: MethodDescriptor(service: "not", method: "implemented"), @@ -374,3 +375,243 @@ final class GRPCServerTests: XCTestCase { } } } + +@Suite("GRPC Server Tests") +struct ServerTests { + @Test("Interceptors are applied only to specified services") + func testInterceptorsAreAppliedToSpecifiedServices() async throws { + let onlyBinaryEchoCounter = AtomicCounter() + let allServicesCounter = AtomicCounter() + let onlyHelloWorldCounter = AtomicCounter() + let bothServicesCounter = AtomicCounter() + + try await self.withInProcessClientConnectedToServer( + services: [BinaryEcho(), HelloWorld()], + interceptorPipeline: [ + .apply( + .requestCounter(onlyBinaryEchoCounter), + to: .services([BinaryEcho.serviceDescriptor]) + ), + .apply(.requestCounter(allServicesCounter), to: .all), + .apply( + .requestCounter(onlyHelloWorldCounter), + to: .services([HelloWorld.serviceDescriptor]) + ), + .apply( + .requestCounter(bothServicesCounter), + to: .services([BinaryEcho.serviceDescriptor, HelloWorld.serviceDescriptor]) + ), + ] + ) { client, _ in + // Make a request to the `BinaryEcho` service and assert that only + // the counters associated to interceptors that apply to it are incremented. + try await client.withStream( + descriptor: BinaryEcho.Methods.get, + options: .defaults + ) { stream in + try await stream.outbound.write(.metadata([:])) + try await stream.outbound.write(.message(Array("hello".utf8))) + await stream.outbound.finish() + + var responseParts = stream.inbound.makeAsyncIterator() + let metadata = try await responseParts.next() + self.assertMetadata(metadata) + + let message = try await responseParts.next() + self.assertMessage(message) { + #expect($0 == Array("hello".utf8)) + } + + let status = try await responseParts.next() + self.assertStatus(status) { status, _ in + #expect(status.code == .ok, Comment(rawValue: status.description)) + } + } + + #expect(onlyBinaryEchoCounter.value == 1) + #expect(allServicesCounter.value == 1) + #expect(onlyHelloWorldCounter.value == 0) + #expect(bothServicesCounter.value == 1) + + // Now, make a request to the `HelloWorld` service and assert that only + // the counters associated to interceptors that apply to it are incremented. + try await client.withStream( + descriptor: HelloWorld.Methods.sayHello, + options: .defaults + ) { stream in + try await stream.outbound.write(.metadata([:])) + try await stream.outbound.write(.message(Array("Swift".utf8))) + await stream.outbound.finish() + + var responseParts = stream.inbound.makeAsyncIterator() + let metadata = try await responseParts.next() + self.assertMetadata(metadata) + + let message = try await responseParts.next() + self.assertMessage(message) { + #expect($0 == Array("Hello, Swift!".utf8)) + } + + let status = try await responseParts.next() + self.assertStatus(status) { status, _ in + #expect(status.code == .ok, Comment(rawValue: status.description)) + } + } + + #expect(onlyBinaryEchoCounter.value == 1) + #expect(allServicesCounter.value == 2) + #expect(onlyHelloWorldCounter.value == 1) + #expect(bothServicesCounter.value == 2) + } + } + + @Test("Interceptors are applied only to specified methods") + func testInterceptorsAreAppliedToSpecifiedMethods() async throws { + let onlyBinaryEchoGetCounter = AtomicCounter() + let onlyBinaryEchoCollectCounter = AtomicCounter() + let bothBinaryEchoMethodsCounter = AtomicCounter() + let allMethodsCounter = AtomicCounter() + + try await self.withInProcessClientConnectedToServer( + services: [BinaryEcho()], + interceptorPipeline: [ + .apply( + .requestCounter(onlyBinaryEchoGetCounter), + to: .methods([BinaryEcho.Methods.get]) + ), + .apply(.requestCounter(allMethodsCounter), to: .all), + .apply( + .requestCounter(onlyBinaryEchoCollectCounter), + to: .methods([BinaryEcho.Methods.collect]) + ), + .apply( + .requestCounter(bothBinaryEchoMethodsCounter), + to: .methods([BinaryEcho.Methods.get, BinaryEcho.Methods.collect]) + ), + ] + ) { client, _ in + // Make a request to the `BinaryEcho/get` method and assert that only + // the counters associated to interceptors that apply to it are incremented. + try await client.withStream( + descriptor: BinaryEcho.Methods.get, + options: .defaults + ) { stream in + try await stream.outbound.write(.metadata([:])) + try await stream.outbound.write(.message(Array("hello".utf8))) + await stream.outbound.finish() + + var responseParts = stream.inbound.makeAsyncIterator() + let metadata = try await responseParts.next() + self.assertMetadata(metadata) + + let message = try await responseParts.next() + self.assertMessage(message) { + #expect($0 == Array("hello".utf8)) + } + + let status = try await responseParts.next() + self.assertStatus(status) { status, _ in + #expect(status.code == .ok, Comment(rawValue: status.description)) + } + } + + #expect(onlyBinaryEchoGetCounter.value == 1) + #expect(allMethodsCounter.value == 1) + #expect(onlyBinaryEchoCollectCounter.value == 0) + #expect(bothBinaryEchoMethodsCounter.value == 1) + + // Now, make a request to the `BinaryEcho/collect` method and assert that only + // the counters associated to interceptors that apply to it are incremented. + try await client.withStream( + descriptor: BinaryEcho.Methods.collect, + options: .defaults + ) { stream in + try await stream.outbound.write(.metadata([:])) + try await stream.outbound.write(.message(Array("hello".utf8))) + await stream.outbound.finish() + + var responseParts = stream.inbound.makeAsyncIterator() + let metadata = try await responseParts.next() + self.assertMetadata(metadata) + + let message = try await responseParts.next() + self.assertMessage(message) { + #expect($0 == Array("hello".utf8)) + } + + let status = try await responseParts.next() + self.assertStatus(status) { status, _ in + #expect(status.code == .ok, Comment(rawValue: status.description)) + } + } + + #expect(onlyBinaryEchoGetCounter.value == 1) + #expect(allMethodsCounter.value == 2) + #expect(onlyBinaryEchoCollectCounter.value == 1) + #expect(bothBinaryEchoMethodsCounter.value == 2) + } + } + + func withInProcessClientConnectedToServer( + services: [any RegistrableRPCService], + interceptorPipeline: [ServerInterceptorPipelineOperation] = [], + _ body: (InProcessTransport.Client, GRPCServer) async throws -> Void + ) async throws { + let inProcess = InProcessTransport() + let server = GRPCServer( + transport: inProcess.server, + services: services, + interceptorPipeline: interceptorPipeline + ) + + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + try await server.serve() + } + + group.addTask { + try await inProcess.client.connect() + } + + try await body(inProcess.client, server) + inProcess.client.beginGracefulShutdown() + server.beginGracefulShutdown() + } + } + + func assertMetadata( + _ part: RPCResponsePart?, + metadataHandler: (Metadata) -> Void = { _ in } + ) { + switch part { + case .some(.metadata(let metadata)): + metadataHandler(metadata) + default: + Issue.record("Expected '.metadata' but found '\(String(describing: part))'") + } + } + + func assertMessage( + _ part: RPCResponsePart?, + messageHandler: ([UInt8]) -> Void = { _ in } + ) { + switch part { + case .some(.message(let message)): + messageHandler(message) + default: + Issue.record("Expected '.message' but found '\(String(describing: part))'") + } + } + + func assertStatus( + _ part: RPCResponsePart?, + statusHandler: (Status, Metadata) -> Void = { _, _ in } + ) { + switch part { + case .some(.status(let status, let metadata)): + statusHandler(status, metadata) + default: + Issue.record("Expected '.status' but found '\(String(describing: part))'") + } + } +} diff --git a/Tests/GRPCCoreTests/Test Utilities/Call/Server/ServerInterceptors.swift b/Tests/GRPCCoreTests/Test Utilities/Call/Server/ServerInterceptors.swift index 4356fdc1f..fdb869d1c 100644 --- a/Tests/GRPCCoreTests/Test Utilities/Call/Server/ServerInterceptors.swift +++ b/Tests/GRPCCoreTests/Test Utilities/Call/Server/ServerInterceptors.swift @@ -22,13 +22,13 @@ extension ServerInterceptor where Self == RejectAllServerInterceptor { } static func throwError(_ error: RPCError) -> Self { - return RejectAllServerInterceptor(error: error, throw: true) + RejectAllServerInterceptor(error: error, throw: true) } } extension ServerInterceptor where Self == RequestCountingServerInterceptor { static func requestCounter(_ counter: AtomicCounter) -> Self { - return RequestCountingServerInterceptor(counter: counter) + RequestCountingServerInterceptor(counter: counter) } } diff --git a/Tests/GRPCCoreTests/Test Utilities/Services/BinaryEcho.swift b/Tests/GRPCCoreTests/Test Utilities/Services/BinaryEcho.swift index 3859eec24..8d0ece3c7 100644 --- a/Tests/GRPCCoreTests/Test Utilities/Services/BinaryEcho.swift +++ b/Tests/GRPCCoreTests/Test Utilities/Services/BinaryEcho.swift @@ -14,9 +14,10 @@ * limitations under the License. */ import GRPCCore -import XCTest struct BinaryEcho: RegistrableRPCService { + static let serviceDescriptor = ServiceDescriptor(package: "echo", service: "Echo") + func get( _ request: ServerRequest<[UInt8]> ) async throws -> ServerResponse<[UInt8]> { @@ -96,9 +97,21 @@ struct BinaryEcho: RegistrableRPCService { } enum Methods { - static let get = MethodDescriptor(service: "echo.Echo", method: "Get") - static let collect = MethodDescriptor(service: "echo.Echo", method: "Collect") - static let expand = MethodDescriptor(service: "echo.Echo", method: "Expand") - static let update = MethodDescriptor(service: "echo.Echo", method: "Update") + static let get = MethodDescriptor( + service: BinaryEcho.serviceDescriptor.fullyQualifiedService, + method: "Get" + ) + static let collect = MethodDescriptor( + service: BinaryEcho.serviceDescriptor.fullyQualifiedService, + method: "Collect" + ) + static let expand = MethodDescriptor( + service: BinaryEcho.serviceDescriptor.fullyQualifiedService, + method: "Expand" + ) + static let update = MethodDescriptor( + service: BinaryEcho.serviceDescriptor.fullyQualifiedService, + method: "Update" + ) } } diff --git a/Tests/GRPCCoreTests/Test Utilities/Services/HelloWorld.swift b/Tests/GRPCCoreTests/Test Utilities/Services/HelloWorld.swift new file mode 100644 index 000000000..01501e0bb --- /dev/null +++ b/Tests/GRPCCoreTests/Test Utilities/Services/HelloWorld.swift @@ -0,0 +1,49 @@ +/* + * Copyright 2024, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import GRPCCore + +struct HelloWorld: RegistrableRPCService { + static let serviceDescriptor = ServiceDescriptor(package: "helloworld", service: "HelloWorld") + + func sayHello( + _ request: ServerRequest<[UInt8]> + ) async throws -> ServerResponse<[UInt8]> { + let name = String(bytes: request.message, encoding: .utf8) ?? "world" + return ServerResponse(message: Array("Hello, \(name)!".utf8), metadata: []) + } + + func registerMethods(with router: inout RPCRouter) { + let serializer = IdentitySerializer() + let deserializer = IdentityDeserializer() + + router.registerHandler( + forMethod: Methods.sayHello, + deserializer: deserializer, + serializer: serializer + ) { streamRequest, context in + let singleRequest = try await ServerRequest(stream: streamRequest) + let singleResponse = try await self.sayHello(singleRequest) + return StreamingServerResponse(single: singleResponse) + } + } + + enum Methods { + static let sayHello = MethodDescriptor( + service: HelloWorld.serviceDescriptor.fullyQualifiedService, + method: "SayHello" + ) + } +}