diff --git a/Sources/GRPCCore/Call/Client/Internal/ClientResponse+Convenience.swift b/Sources/GRPCCore/Call/Client/Internal/ClientResponse+Convenience.swift index 41c3d024..2b2abbab 100644 --- a/Sources/GRPCCore/Call/Client/Internal/ClientResponse+Convenience.swift +++ b/Sources/GRPCCore/Call/Client/Internal/ClientResponse+Convenience.swift @@ -71,6 +71,8 @@ extension ClientResponse { } catch let error as RPCError { // Known error type. self.accepted = .success(Contents(metadata: contents.metadata, error: error)) + } catch let error as any RPCErrorConvertible { + self.accepted = .success(Contents(metadata: contents.metadata, error: RPCError(error))) } catch { // Unexpected, but should be handled nonetheless. self.accepted = .failure(RPCError(code: .unknown, message: String(describing: error))) diff --git a/Sources/GRPCCore/Call/Client/Internal/ClientStreamExecutor.swift b/Sources/GRPCCore/Call/Client/Internal/ClientStreamExecutor.swift index 74aac103..67bcc0a5 100644 --- a/Sources/GRPCCore/Call/Client/Internal/ClientStreamExecutor.swift +++ b/Sources/GRPCCore/Call/Client/Internal/ClientStreamExecutor.swift @@ -26,7 +26,7 @@ internal enum ClientStreamExecutor { /// - attempt: The attempt number for the RPC that will be executed. /// - serializer: A request serializer. /// - deserializer: A response deserializer. - /// - stream: The stream to excecute the RPC on. + /// - stream: The stream to execute the RPC on. /// - Returns: A streamed response. @inlinable static func execute( @@ -95,7 +95,7 @@ internal enum ClientStreamExecutor { let result = await Result { try await stream.write(.metadata(request.metadata)) try await request.producer(.map(into: stream) { .message(try serializer.serialize($0)) }) - }.castError(to: RPCError.self) { other in + }.castOrConvertRPCError { other in RPCError(code: .unknown, message: "Write failed.", cause: other) } diff --git a/Sources/GRPCCore/Call/Server/Internal/ServerRPCExecutor.swift b/Sources/GRPCCore/Call/Server/Internal/ServerRPCExecutor.swift index 8df70f86..96c929a4 100644 --- a/Sources/GRPCCore/Call/Server/Internal/ServerRPCExecutor.swift +++ b/Sources/GRPCCore/Call/Server/Internal/ServerRPCExecutor.swift @@ -188,16 +188,12 @@ struct ServerRPCExecutor { ) { request, context in try await handler(request, context) } - }.castError(to: RPCError.self) { error in - if let convertible = error as? (any RPCErrorConvertible) { - return RPCError(convertible) - } else { - return RPCError( - code: .unknown, - message: "Service method threw an unknown error.", - cause: error - ) - } + }.castOrConvertRPCError { error in + RPCError( + code: .unknown, + message: "Service method threw an unknown error.", + cause: error + ) }.flatMap { response in response.accepted } @@ -213,7 +209,7 @@ struct ServerRPCExecutor { return try await contents.producer( .serializingToRPCResponsePart(into: outbound, with: serializer) ) - }.castError(to: RPCError.self) { error in + }.castOrConvertRPCError { error in RPCError(code: .unknown, message: "", cause: error) } diff --git a/Sources/GRPCCore/Internal/Result+Catching.swift b/Sources/GRPCCore/Internal/Result+Catching.swift index 07258ca5..a3a0e358 100644 --- a/Sources/GRPCCore/Internal/Result+Catching.swift +++ b/Sources/GRPCCore/Internal/Result+Catching.swift @@ -45,4 +45,23 @@ extension Result { return (error as? NewError) ?? buildError(error) } } + + /// Attempt to map or convert the error to an `RPCError`. + /// + /// If the cast or conversion is not possible then the provided closure is used to create an error of the given type. + /// + /// - Parameter buildError: A closure which constructs the desired error if conversion is not possible. + @inlinable + @available(gRPCSwift 2.0, *) + func castOrConvertRPCError( + or buildError: (any Error) -> RPCError + ) -> Result { + return self.castError(to: RPCError.self) { error in + if let convertible = error as? any RPCErrorConvertible { + return RPCError(convertible) + } else { + return buildError(error) + } + } + } } diff --git a/Tests/GRPCCoreTests/Call/Client/Internal/ClientRPCExecutorTestSupport/ClientRPCExecutorTestHarness.swift b/Tests/GRPCCoreTests/Call/Client/Internal/ClientRPCExecutorTestSupport/ClientRPCExecutorTestHarness.swift index 18a5d06c..770c92d8 100644 --- a/Tests/GRPCCoreTests/Call/Client/Internal/ClientRPCExecutorTestSupport/ClientRPCExecutorTestHarness.swift +++ b/Tests/GRPCCoreTests/Call/Client/Internal/ClientRPCExecutorTestSupport/ClientRPCExecutorTestHarness.swift @@ -132,7 +132,11 @@ struct ClientRPCExecutorTestHarness { try await withThrowingTaskGroup(of: Void.self) { group in group.addTask { try await self.serverTransport.listen { stream, context in - try? await self.server.handle(stream: stream) + do { + try await self.server.handle(stream: stream) + } catch { + await stream.outbound.finish(throwing: error) + } } } diff --git a/Tests/GRPCCoreTests/Call/Client/Internal/ClientRPCExecutorTests.swift b/Tests/GRPCCoreTests/Call/Client/Internal/ClientRPCExecutorTests.swift index 474640f9..46e2cf75 100644 --- a/Tests/GRPCCoreTests/Call/Client/Internal/ClientRPCExecutorTests.swift +++ b/Tests/GRPCCoreTests/Call/Client/Internal/ClientRPCExecutorTests.swift @@ -290,4 +290,46 @@ final class ClientRPCExecutorTests: XCTestCase { } } } + + func testInterceptorProducerErrorConversion() async throws { + struct CustomError: RPCErrorConvertible, Error { + var rpcErrorCode: RPCError.Code { .alreadyExists } + var rpcErrorMessage: String { "foobar" } + var rpcErrorMetadata: Metadata { ["error": "yes"] } + } + + let tester = ClientRPCExecutorTestHarness( + server: .echo, + interceptors: [.throwInProducer(CustomError())] + ) + + try await tester.unary(request: ClientRequest(message: [])) { response in + XCTAssertThrowsError(ofType: RPCError.self, try response.message) { error in + XCTAssertEqual(error.code, .alreadyExists) + XCTAssertEqual(error.message, "foobar") + XCTAssertEqual(error.metadata, ["error": "yes"]) + } + } + } + + func testInterceptorBodyPartsErrorConversion() async throws { + struct CustomError: RPCErrorConvertible, Error { + var rpcErrorCode: RPCError.Code { .alreadyExists } + var rpcErrorMessage: String { "foobar" } + var rpcErrorMetadata: Metadata { ["error": "yes"] } + } + + let tester = ClientRPCExecutorTestHarness( + server: .echo, + interceptors: [.throwInBodyParts(CustomError())] + ) + + try await tester.unary(request: ClientRequest(message: [])) { response in + XCTAssertThrowsError(ofType: RPCError.self, try response.message) { error in + XCTAssertEqual(error.code, .alreadyExists) + XCTAssertEqual(error.message, "foobar") + XCTAssertEqual(error.metadata, ["error": "yes"]) + } + } + } } diff --git a/Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTests.swift b/Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTests.swift index 73f9ba82..d094445b 100644 --- a/Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTests.swift +++ b/Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTests.swift @@ -394,4 +394,48 @@ final class ServerRPCExecutorTests: XCTestCase { XCTAssertEqual(parts, [.status(status, metadata)]) } } + + func testInterceptorProducerErrorConversion() async throws { + struct CustomError: RPCErrorConvertible, Error { + var rpcErrorCode: RPCError.Code { .alreadyExists } + var rpcErrorMessage: String { "foobar" } + var rpcErrorMetadata: Metadata { ["error": "yes"] } + } + + let harness = ServerRPCExecutorTestHarness( + interceptors: [.throwInProducer(CustomError())] + ) + try await harness.execute(handler: .echo) { inbound in + try await inbound.write(.metadata(["foo": "bar"])) + try await inbound.write(.message([0])) + } consumer: { outbound in + let parts = try await outbound.collect() + let status = Status(code: .alreadyExists, message: "foobar") + let metadata: Metadata = ["error": "yes"] + XCTAssertEqual(parts, [.metadata(["foo": "bar"]), .message([0]), .status(status, metadata)]) + } + } + + func testInterceptorMessagesErrorConversion() async throws { + struct CustomError: RPCErrorConvertible, Error { + var rpcErrorCode: RPCError.Code { .alreadyExists } + var rpcErrorMessage: String { "foobar" } + var rpcErrorMetadata: Metadata { ["error": "yes"] } + } + + let harness = ServerRPCExecutorTestHarness(interceptors: [ + .throwInMessageSequence(CustomError()) + ]) + try await harness.execute(handler: .echo) { inbound in + try await inbound.write(.metadata(["foo": "bar"])) + // the sequence throws instantly, this should not arrive + try await inbound.write(.message([0])) + await inbound.finish() + } consumer: { outbound in + let parts = try await outbound.collect() + let status = Status(code: .alreadyExists, message: "foobar") + let metadata: Metadata = ["error": "yes"] + XCTAssertEqual(parts, [.metadata(["foo": "bar"]), .status(status, metadata)]) + } + } } diff --git a/Tests/GRPCCoreTests/Internal/Result+CatchingTests.swift b/Tests/GRPCCoreTests/Internal/Result+CatchingTests.swift index 644bc72d..ce12f366 100644 --- a/Tests/GRPCCoreTests/Internal/Result+CatchingTests.swift +++ b/Tests/GRPCCoreTests/Internal/Result+CatchingTests.swift @@ -63,4 +63,39 @@ final class ResultCatchingTests: XCTestCase { XCTAssertEqual(error, RPCError(code: .invalidArgument, message: "fallback")) } } + + func testCastOrConvertRPCErrorConvertible() { + struct ConvertibleError: Error, RPCErrorConvertible { + let rpcErrorCode: RPCError.Code = .unknown + let rpcErrorMessage = "foo" + } + + let result = Result.failure(ConvertibleError()) + let typedFailure = result.castOrConvertRPCError { _ in + XCTFail("buildError(_:) was called") + return RPCError(code: .failedPrecondition, message: "shouldn't happen") + } + + switch typedFailure { + case .success: + XCTFail() + case .failure(let error): + XCTAssertEqual(error, RPCError(code: .unknown, message: "foo")) + } + } + + func testCastOrConvertToErrorOfIncorrectType() async { + struct WrongError: Error {} + let result = Result.failure(WrongError()) + let typedFailure = result.castOrConvertRPCError { _ in + return RPCError(code: .invalidArgument, message: "fallback") + } + + switch typedFailure { + case .success: + XCTFail() + case .failure(let error): + XCTAssertEqual(error, RPCError(code: .invalidArgument, message: "fallback")) + } + } } diff --git a/Tests/GRPCCoreTests/Test Utilities/Call/Client/ClientInterceptors.swift b/Tests/GRPCCoreTests/Test Utilities/Call/Client/ClientInterceptors.swift index 3c46a35d..95f18edd 100644 --- a/Tests/GRPCCoreTests/Test Utilities/Call/Client/ClientInterceptors.swift +++ b/Tests/GRPCCoreTests/Test Utilities/Call/Client/ClientInterceptors.swift @@ -26,6 +26,13 @@ extension ClientInterceptor where Self == RejectAllClientInterceptor { return RejectAllClientInterceptor(throw: error) } + static func throwInBodyParts(_ error: any Error) -> Self { + return RejectAllClientInterceptor(throwInBodyParts: error) + } + + static func throwInProducer(_ error: any Error) -> Self { + return RejectAllClientInterceptor(throwInProducer: error) + } } @available(gRPCSwift 2.0, *) @@ -43,6 +50,10 @@ struct RejectAllClientInterceptor: ClientInterceptor { case `throw`(any Error) /// Reject the RPC with a given error. case reject(RPCError) + /// Throw an error in the body parts sequence. + case throwInBodyParts(any Error) + /// Throw an error in the message producer closure. + case throwInProducer(any Error) } let mode: Mode @@ -55,6 +66,14 @@ struct RejectAllClientInterceptor: ClientInterceptor { self.mode = .reject(error) } + init(throwInBodyParts error: any Error) { + self.mode = .throwInBodyParts(error) + } + + init(throwInProducer error: any Error) { + self.mode = .throwInProducer(error) + } + func intercept( request: StreamingClientRequest, context: ClientContext, @@ -68,6 +87,31 @@ struct RejectAllClientInterceptor: ClientInterceptor { throw error case .reject(let error): return StreamingClientResponse(error: error) + case .throwInBodyParts(let error): + var response = try await next(request, context) + switch response.accepted { + case .success(var success): + let stream = AsyncThrowingStream< + StreamingClientResponse.Contents.BodyPart, any Error + >.makeStream() + stream.continuation.finish(throwing: error) + + success.bodyParts = RPCAsyncSequence(wrapping: stream.stream) + response.accepted = .success(success) + return response + case .failure: + return response + } + case .throwInProducer(let error): + let wrappedProducer = request.producer + + var request = request + request.producer = { writer in + try await wrappedProducer(writer) + throw error + } + + return try await next(request, context) } } } diff --git a/Tests/GRPCCoreTests/Test Utilities/Call/Server/ServerInterceptors.swift b/Tests/GRPCCoreTests/Test Utilities/Call/Server/ServerInterceptors.swift index 5918102d..1372c6d8 100644 --- a/Tests/GRPCCoreTests/Test Utilities/Call/Server/ServerInterceptors.swift +++ b/Tests/GRPCCoreTests/Test Utilities/Call/Server/ServerInterceptors.swift @@ -25,6 +25,14 @@ extension ServerInterceptor where Self == RejectAllServerInterceptor { static func throwError(_ error: any Error) -> Self { RejectAllServerInterceptor(throw: error) } + + static func throwInProducer(_ error: any Error) -> Self { + RejectAllServerInterceptor(throwInProducer: error) + } + + static func throwInMessageSequence(_ error: any Error) -> Self { + RejectAllServerInterceptor(throwInMessageSequence: error) + } } @available(gRPCSwift 2.0, *) @@ -42,6 +50,16 @@ struct RejectAllServerInterceptor: ServerInterceptor { case `throw`(any Error) /// Reject the RPC with a given error. case reject(RPCError) + /// Throw in the producer closure returned. + case throwInProducer(any Error) + /// Throw in the async sequence that stream inbound messages. + case throwInMessageSequence(any Error) + } + + private enum TimeoutResult { + case `throw`(any Error) + case cancelled + case result(Metadata) } let mode: Mode @@ -54,6 +72,14 @@ struct RejectAllServerInterceptor: ServerInterceptor { self.mode = .reject(error) } + init(throwInProducer error: any Error) { + self.mode = .throwInProducer(error) + } + + init(throwInMessageSequence error: any Error) { + self.mode = .throwInMessageSequence(error) + } + func intercept( request: StreamingServerRequest, context: ServerContext, @@ -67,6 +93,36 @@ struct RejectAllServerInterceptor: ServerInterceptor { throw error case .reject(let error): return StreamingServerResponse(error: error) + case .throwInProducer(let error): + var response = try await next(request, context) + switch response.accepted { + case .success(var success): + let wrappedProducer = success.producer + success.producer = { writer in + try await withThrowingTaskGroup(of: Metadata.self) { group in + group.addTask { + try await wrappedProducer(writer) + } + + group.cancelAll() + _ = try await group.next()! + throw error + } + } + + response.accepted = .success(success) + return response + case .failure: + return response + } + case .throwInMessageSequence(let error): + let stream = AsyncThrowingStream.makeStream() + stream.continuation.finish(throwing: error) + + var request = request + request.messages = RPCAsyncSequence(wrapping: stream.stream) + + return try await next(request, context) } } }