Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ extension ClientResponse {
} catch let error as RPCError {
// Known error type.
self.accepted = .success(Contents(metadata: contents.metadata, error: error))
} catch let error as any RPCErrorConvertible {
self.accepted = .success(Contents(metadata: contents.metadata, error: RPCError(error)))
} catch {
// Unexpected, but should be handled nonetheless.
self.accepted = .failure(RPCError(code: .unknown, message: String(describing: error)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ internal enum ClientStreamExecutor {
/// - attempt: The attempt number for the RPC that will be executed.
/// - serializer: A request serializer.
/// - deserializer: A response deserializer.
/// - stream: The stream to excecute the RPC on.
/// - stream: The stream to execute the RPC on.
/// - Returns: A streamed response.
@inlinable
static func execute<Input: Sendable, Output: Sendable, Bytes: GRPCContiguousBytes>(
Expand Down Expand Up @@ -95,7 +95,7 @@ internal enum ClientStreamExecutor {
let result = await Result {
try await stream.write(.metadata(request.metadata))
try await request.producer(.map(into: stream) { .message(try serializer.serialize($0)) })
}.castError(to: RPCError.self) { other in
}.castOrConvertRPCError { other in
RPCError(code: .unknown, message: "Write failed.", cause: other)
}

Expand Down
18 changes: 7 additions & 11 deletions Sources/GRPCCore/Call/Server/Internal/ServerRPCExecutor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -188,16 +188,12 @@ struct ServerRPCExecutor {
) { request, context in
try await handler(request, context)
}
}.castError(to: RPCError.self) { error in
if let convertible = error as? (any RPCErrorConvertible) {
return RPCError(convertible)
} else {
return RPCError(
code: .unknown,
message: "Service method threw an unknown error.",
cause: error
)
}
}.castOrConvertRPCError { error in
RPCError(
code: .unknown,
message: "Service method threw an unknown error.",
cause: error
)
}.flatMap { response in
response.accepted
}
Expand All @@ -213,7 +209,7 @@ struct ServerRPCExecutor {
return try await contents.producer(
.serializingToRPCResponsePart(into: outbound, with: serializer)
)
}.castError(to: RPCError.self) { error in
}.castOrConvertRPCError { error in
RPCError(code: .unknown, message: "", cause: error)
}

Expand Down
19 changes: 19 additions & 0 deletions Sources/GRPCCore/Internal/Result+Catching.swift
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,23 @@ extension Result {
return (error as? NewError) ?? buildError(error)
}
}

/// Attempt to map or convert the error to an `RPCError`.
///
/// If the cast or conversion is not possible then the provided closure is used to create an error of the given type.
///
/// - Parameter buildError: A closure which constructs the desired error if conversion is not possible.
@inlinable
@available(gRPCSwift 2.0, *)
func castOrConvertRPCError(
or buildError: (any Error) -> RPCError
) -> Result<Success, RPCError> {
return self.castError(to: RPCError.self) { error in
if let convertible = error as? any RPCErrorConvertible {
return RPCError(convertible)
} else {
return buildError(error)
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,11 @@ struct ClientRPCExecutorTestHarness {
try await withThrowingTaskGroup(of: Void.self) { group in
group.addTask {
try await self.serverTransport.listen { stream, context in
try? await self.server.handle(stream: stream)
do {
try await self.server.handle(stream: stream)
} catch {
await stream.outbound.finish(throwing: error)
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,4 +290,46 @@ final class ClientRPCExecutorTests: XCTestCase {
}
}
}

func testInterceptorProducerErrorConversion() async throws {
struct CustomError: RPCErrorConvertible, Error {
var rpcErrorCode: RPCError.Code { .alreadyExists }
var rpcErrorMessage: String { "foobar" }
var rpcErrorMetadata: Metadata { ["error": "yes"] }
}

let tester = ClientRPCExecutorTestHarness(
server: .echo,
interceptors: [.throwInProducer(CustomError())]
)

try await tester.unary(request: ClientRequest(message: [])) { response in
XCTAssertThrowsError(ofType: RPCError.self, try response.message) { error in
XCTAssertEqual(error.code, .alreadyExists)
XCTAssertEqual(error.message, "foobar")
XCTAssertEqual(error.metadata, ["error": "yes"])
}
}
}

func testInterceptorBodyPartsErrorConversion() async throws {
struct CustomError: RPCErrorConvertible, Error {
var rpcErrorCode: RPCError.Code { .alreadyExists }
var rpcErrorMessage: String { "foobar" }
var rpcErrorMetadata: Metadata { ["error": "yes"] }
}

let tester = ClientRPCExecutorTestHarness(
server: .echo,
interceptors: [.throwInBodyParts(CustomError())]
)

try await tester.unary(request: ClientRequest(message: [])) { response in
XCTAssertThrowsError(ofType: RPCError.self, try response.message) { error in
XCTAssertEqual(error.code, .alreadyExists)
XCTAssertEqual(error.message, "foobar")
XCTAssertEqual(error.metadata, ["error": "yes"])
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -394,4 +394,48 @@ final class ServerRPCExecutorTests: XCTestCase {
XCTAssertEqual(parts, [.status(status, metadata)])
}
}

func testInterceptorProducerErrorConversion() async throws {
struct CustomError: RPCErrorConvertible, Error {
var rpcErrorCode: RPCError.Code { .alreadyExists }
var rpcErrorMessage: String { "foobar" }
var rpcErrorMetadata: Metadata { ["error": "yes"] }
}

let harness = ServerRPCExecutorTestHarness(
interceptors: [.throwInProducer(CustomError())]
)
try await harness.execute(handler: .echo) { inbound in
try await inbound.write(.metadata(["foo": "bar"]))
try await inbound.write(.message([0]))
} consumer: { outbound in
let parts = try await outbound.collect()
let status = Status(code: .alreadyExists, message: "foobar")
let metadata: Metadata = ["error": "yes"]
XCTAssertEqual(parts, [.metadata(["foo": "bar"]), .message([0]), .status(status, metadata)])
}
}

func testInterceptorMessagesErrorConversion() async throws {
struct CustomError: RPCErrorConvertible, Error {
var rpcErrorCode: RPCError.Code { .alreadyExists }
var rpcErrorMessage: String { "foobar" }
var rpcErrorMetadata: Metadata { ["error": "yes"] }
}

let harness = ServerRPCExecutorTestHarness(interceptors: [
.throwInMessageSequence(CustomError())
])
try await harness.execute(handler: .echo) { inbound in
try await inbound.write(.metadata(["foo": "bar"]))
// the sequence throws instantly, this should not arrive
try await inbound.write(.message([0]))
await inbound.finish()
} consumer: { outbound in
let parts = try await outbound.collect()
let status = Status(code: .alreadyExists, message: "foobar")
let metadata: Metadata = ["error": "yes"]
XCTAssertEqual(parts, [.metadata(["foo": "bar"]), .status(status, metadata)])
}
}
}
35 changes: 35 additions & 0 deletions Tests/GRPCCoreTests/Internal/Result+CatchingTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,39 @@ final class ResultCatchingTests: XCTestCase {
XCTAssertEqual(error, RPCError(code: .invalidArgument, message: "fallback"))
}
}

func testCastOrConvertRPCErrorConvertible() {
struct ConvertibleError: Error, RPCErrorConvertible {
let rpcErrorCode: RPCError.Code = .unknown
let rpcErrorMessage = "foo"
}

let result = Result<Void, any Error>.failure(ConvertibleError())
let typedFailure = result.castOrConvertRPCError { _ in
XCTFail("buildError(_:) was called")
return RPCError(code: .failedPrecondition, message: "shouldn't happen")
}

switch typedFailure {
case .success:
XCTFail()
case .failure(let error):
XCTAssertEqual(error, RPCError(code: .unknown, message: "foo"))
}
}

func testCastOrConvertToErrorOfIncorrectType() async {
struct WrongError: Error {}
let result = Result<Void, any Error>.failure(WrongError())
let typedFailure = result.castOrConvertRPCError { _ in
return RPCError(code: .invalidArgument, message: "fallback")
}

switch typedFailure {
case .success:
XCTFail()
case .failure(let error):
XCTAssertEqual(error, RPCError(code: .invalidArgument, message: "fallback"))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@ extension ClientInterceptor where Self == RejectAllClientInterceptor {
return RejectAllClientInterceptor(throw: error)
}

static func throwInBodyParts(_ error: any Error) -> Self {
return RejectAllClientInterceptor(throwInBodyParts: error)
}

static func throwInProducer(_ error: any Error) -> Self {
return RejectAllClientInterceptor(throwInProducer: error)
}
}

@available(gRPCSwift 2.0, *)
Expand All @@ -43,6 +50,10 @@ struct RejectAllClientInterceptor: ClientInterceptor {
case `throw`(any Error)
/// Reject the RPC with a given error.
case reject(RPCError)
/// Throw an error in the body parts sequence.
case throwInBodyParts(any Error)
/// Throw an error in the message producer closure.
case throwInProducer(any Error)
}

let mode: Mode
Expand All @@ -55,6 +66,14 @@ struct RejectAllClientInterceptor: ClientInterceptor {
self.mode = .reject(error)
}

init(throwInBodyParts error: any Error) {
self.mode = .throwInBodyParts(error)
}

init(throwInProducer error: any Error) {
self.mode = .throwInProducer(error)
}

func intercept<Input: Sendable, Output: Sendable>(
request: StreamingClientRequest<Input>,
context: ClientContext,
Expand All @@ -68,6 +87,31 @@ struct RejectAllClientInterceptor: ClientInterceptor {
throw error
case .reject(let error):
return StreamingClientResponse(error: error)
case .throwInBodyParts(let error):
var response = try await next(request, context)
switch response.accepted {
case .success(var success):
let stream = AsyncThrowingStream<
StreamingClientResponse<Output>.Contents.BodyPart, any Error
>.makeStream()
stream.continuation.finish(throwing: error)

success.bodyParts = RPCAsyncSequence(wrapping: stream.stream)
response.accepted = .success(success)
return response
case .failure:
return response
}
case .throwInProducer(let error):
let wrappedProducer = request.producer

var request = request
request.producer = { writer in
try await wrappedProducer(writer)
throw error
}

return try await next(request, context)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@ extension ServerInterceptor where Self == RejectAllServerInterceptor {
static func throwError(_ error: any Error) -> Self {
RejectAllServerInterceptor(throw: error)
}

static func throwInProducer(_ error: any Error) -> Self {
RejectAllServerInterceptor(throwInProducer: error)
}

static func throwInMessageSequence(_ error: any Error) -> Self {
RejectAllServerInterceptor(throwInMessageSequence: error)
}
}

@available(gRPCSwift 2.0, *)
Expand All @@ -42,6 +50,16 @@ struct RejectAllServerInterceptor: ServerInterceptor {
case `throw`(any Error)
/// Reject the RPC with a given error.
case reject(RPCError)
/// Throw in the producer closure returned.
case throwInProducer(any Error)
/// Throw in the async sequence that stream inbound messages.
case throwInMessageSequence(any Error)
}

private enum TimeoutResult {
case `throw`(any Error)
case cancelled
case result(Metadata)
}

let mode: Mode
Expand All @@ -54,6 +72,14 @@ struct RejectAllServerInterceptor: ServerInterceptor {
self.mode = .reject(error)
}

init(throwInProducer error: any Error) {
self.mode = .throwInProducer(error)
}

init(throwInMessageSequence error: any Error) {
self.mode = .throwInMessageSequence(error)
}

func intercept<Input: Sendable, Output: Sendable>(
request: StreamingServerRequest<Input>,
context: ServerContext,
Expand All @@ -67,6 +93,36 @@ struct RejectAllServerInterceptor: ServerInterceptor {
throw error
case .reject(let error):
return StreamingServerResponse(error: error)
case .throwInProducer(let error):
var response = try await next(request, context)
switch response.accepted {
case .success(var success):
let wrappedProducer = success.producer
success.producer = { writer in
try await withThrowingTaskGroup(of: Metadata.self) { group in
group.addTask {
try await wrappedProducer(writer)
}

group.cancelAll()
_ = try await group.next()!
throw error
}
}

response.accepted = .success(success)
return response
case .failure:
return response
}
case .throwInMessageSequence(let error):
let stream = AsyncThrowingStream<Input, any Error>.makeStream()
stream.continuation.finish(throwing: error)

var request = request
request.messages = RPCAsyncSequence(wrapping: stream.stream)

return try await next(request, context)
}
}
}
Expand Down