diff --git a/Sources/GRPCCore/Call/Client/Internal/ClientRPCExecutor.swift b/Sources/GRPCCore/Call/Client/Internal/ClientRPCExecutor.swift index 9a1da1f7d..ff0b30cdb 100644 --- a/Sources/GRPCCore/Call/Client/Internal/ClientRPCExecutor.swift +++ b/Sources/GRPCCore/Call/Client/Internal/ClientRPCExecutor.swift @@ -186,6 +186,8 @@ extension ClientRPCExecutor { } } catch let error as RPCError { return StreamingClientResponse(error: error) + } catch let error as RPCErrorConvertible { + return StreamingClientResponse(error: RPCError(error)) } catch let other { let error = RPCError(code: .unknown, message: "", cause: other) return StreamingClientResponse(error: error) diff --git a/Sources/GRPCCore/Call/Server/Internal/ServerRPCExecutor.swift b/Sources/GRPCCore/Call/Server/Internal/ServerRPCExecutor.swift index e2184de14..dee3303d5 100644 --- a/Sources/GRPCCore/Call/Server/Internal/ServerRPCExecutor.swift +++ b/Sources/GRPCCore/Call/Server/Internal/ServerRPCExecutor.swift @@ -330,6 +330,8 @@ extension ServerRPCExecutor { } } catch let error as RPCError { return StreamingServerResponse(error: error) + } catch let error as RPCErrorConvertible { + return StreamingServerResponse(error: RPCError(error)) } catch let other { let error = RPCError(code: .unknown, message: "", cause: other) return StreamingServerResponse(error: error) diff --git a/Tests/GRPCCoreTests/Call/Client/Internal/ClientRPCExecutorTestSupport/ClientRPCExecutorTestHarness.swift b/Tests/GRPCCoreTests/Call/Client/Internal/ClientRPCExecutorTestSupport/ClientRPCExecutorTestHarness.swift index 0a74ba96f..2b09af76b 100644 --- a/Tests/GRPCCoreTests/Call/Client/Internal/ClientRPCExecutorTestSupport/ClientRPCExecutorTestHarness.swift +++ b/Tests/GRPCCoreTests/Call/Client/Internal/ClientRPCExecutorTestSupport/ClientRPCExecutorTestHarness.swift @@ -29,6 +29,7 @@ struct ClientRPCExecutorTestHarness { private let server: ServerStreamHandler private let clientTransport: StreamCountingClientTransport private let serverTransport: StreamCountingServerTransport + private let interceptors: [any ClientInterceptor] var clientStreamsOpened: Int { self.clientTransport.streamsOpened @@ -42,8 +43,13 @@ struct ClientRPCExecutorTestHarness { self.serverTransport.acceptedStreamsCount } - init(transport: Transport = .inProcess, server: ServerStreamHandler) { + init( + transport: Transport = .inProcess, + server: ServerStreamHandler, + interceptors: [any ClientInterceptor] = [] + ) { self.server = server + self.interceptors = interceptors switch transport { case .inProcess: @@ -141,7 +147,7 @@ struct ClientRPCExecutorTestHarness { serializer: serializer, deserializer: deserializer, transport: self.clientTransport, - interceptors: [], + interceptors: self.interceptors, handler: handler ) diff --git a/Tests/GRPCCoreTests/Call/Client/Internal/ClientRPCExecutorTests.swift b/Tests/GRPCCoreTests/Call/Client/Internal/ClientRPCExecutorTests.swift index 1d3e78d69..4639dedb0 100644 --- a/Tests/GRPCCoreTests/Call/Client/Internal/ClientRPCExecutorTests.swift +++ b/Tests/GRPCCoreTests/Call/Client/Internal/ClientRPCExecutorTests.swift @@ -268,4 +268,25 @@ final class ClientRPCExecutorTests: XCTestCase { } } } + + func testInterceptorErrorConversion() async throws { + struct CustomError: RPCErrorConvertible, Error { + var rpcErrorCode: RPCError.Code { .alreadyExists } + var rpcErrorMessage: String { "foobar" } + var rpcErrorMetadata: Metadata { ["error": "yes"] } + } + + let tester = ClientRPCExecutorTestHarness( + server: .echo, + interceptors: [.throwError(CustomError())] + ) + + try await tester.unary(request: ClientRequest(message: [])) { response in + XCTAssertThrowsError(ofType: RPCError.self, try response.message) { error in + XCTAssertEqual(error.code, .alreadyExists) + XCTAssertEqual(error.message, "foobar") + XCTAssertEqual(error.metadata, ["error": "yes"]) + } + } + } } diff --git a/Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTests.swift b/Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTests.swift index 09982b20d..215584ebf 100644 --- a/Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTests.swift +++ b/Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTests.swift @@ -374,4 +374,23 @@ final class ServerRPCExecutorTests: XCTestCase { ) } } + + func testInterceptorErrorConversion() async throws { + struct CustomError: RPCErrorConvertible, Error { + var rpcErrorCode: RPCError.Code { .alreadyExists } + var rpcErrorMessage: String { "foobar" } + var rpcErrorMetadata: Metadata { ["error": "yes"] } + } + + let harness = ServerRPCExecutorTestHarness(interceptors: [.throwError(CustomError())]) + try await harness.execute(handler: .throwing(CustomError())) { inbound in + try await inbound.write(.metadata(["foo": "bar"])) + await inbound.finish() + } consumer: { outbound in + let parts = try await outbound.collect() + let status = Status(code: .alreadyExists, message: "foobar") + let metadata: Metadata = ["error": "yes"] + XCTAssertEqual(parts, [.status(status, metadata)]) + } + } } diff --git a/Tests/GRPCCoreTests/Test Utilities/Call/Client/ClientInterceptors.swift b/Tests/GRPCCoreTests/Test Utilities/Call/Client/ClientInterceptors.swift index e13228b3e..ba6c1abf1 100644 --- a/Tests/GRPCCoreTests/Test Utilities/Call/Client/ClientInterceptors.swift +++ b/Tests/GRPCCoreTests/Test Utilities/Call/Client/ClientInterceptors.swift @@ -18,11 +18,11 @@ import GRPCCore extension ClientInterceptor where Self == RejectAllClientInterceptor { static func rejectAll(with error: RPCError) -> Self { - return RejectAllClientInterceptor(error: error, throw: false) + return RejectAllClientInterceptor(reject: error) } - static func throwError(_ error: RPCError) -> Self { - return RejectAllClientInterceptor(error: error, throw: true) + static func throwError(_ error: any Error) -> Self { + return RejectAllClientInterceptor(throw: error) } } @@ -35,15 +35,21 @@ extension ClientInterceptor where Self == RequestCountingClientInterceptor { /// Rejects all RPCs with the provided error. struct RejectAllClientInterceptor: ClientInterceptor { - /// The error to reject all RPCs with. - let error: RPCError - /// Whether the error should be thrown. If `false` then the request is rejected with the error - /// instead. - let `throw`: Bool + enum Mode: Sendable { + /// Throw the error rather. + case `throw`(any Error) + /// Reject the RPC with a given error. + case reject(RPCError) + } + + let mode: Mode + + init(throw error: any Error) { + self.mode = .throw(error) + } - init(error: RPCError, throw: Bool = false) { - self.error = error - self.`throw` = `throw` + init(reject error: RPCError) { + self.mode = .reject(error) } func intercept( @@ -54,10 +60,11 @@ struct RejectAllClientInterceptor: ClientInterceptor { ClientContext ) async throws -> StreamingClientResponse ) async throws -> StreamingClientResponse { - if self.throw { - throw self.error - } else { - return StreamingClientResponse(error: self.error) + switch self.mode { + case .throw(let error): + throw error + case .reject(let error): + return StreamingClientResponse(error: error) } } } diff --git a/Tests/GRPCCoreTests/Test Utilities/Call/Server/ServerInterceptors.swift b/Tests/GRPCCoreTests/Test Utilities/Call/Server/ServerInterceptors.swift index fdb869d1c..8340aa130 100644 --- a/Tests/GRPCCoreTests/Test Utilities/Call/Server/ServerInterceptors.swift +++ b/Tests/GRPCCoreTests/Test Utilities/Call/Server/ServerInterceptors.swift @@ -18,11 +18,11 @@ import GRPCCore extension ServerInterceptor where Self == RejectAllServerInterceptor { static func rejectAll(with error: RPCError) -> Self { - return RejectAllServerInterceptor(error: error, throw: false) + return RejectAllServerInterceptor(reject: error) } - static func throwError(_ error: RPCError) -> Self { - RejectAllServerInterceptor(error: error, throw: true) + static func throwError(_ error: any Error) -> Self { + RejectAllServerInterceptor(throw: error) } } @@ -34,15 +34,21 @@ extension ServerInterceptor where Self == RequestCountingServerInterceptor { /// Rejects all RPCs with the provided error. struct RejectAllServerInterceptor: ServerInterceptor { - /// The error to reject all RPCs with. - let error: RPCError - /// Whether the error should be thrown. If `false` then the request is rejected with the error - /// instead. - let `throw`: Bool + enum Mode: Sendable { + /// Throw the error rather. + case `throw`(any Error) + /// Reject the RPC with a given error. + case reject(RPCError) + } + + let mode: Mode + + init(throw error: any Error) { + self.mode = .throw(error) + } - init(error: RPCError, throw: Bool = false) { - self.error = error - self.`throw` = `throw` + init(reject error: RPCError) { + self.mode = .reject(error) } func intercept( @@ -53,10 +59,11 @@ struct RejectAllServerInterceptor: ServerInterceptor { ServerContext ) async throws -> StreamingServerResponse ) async throws -> StreamingServerResponse { - if self.throw { - throw self.error - } else { - return StreamingServerResponse(error: self.error) + switch self.mode { + case .throw(let error): + throw error + case .reject(let error): + return StreamingServerResponse(error: error) } } }