Skip to content

Commit ae9a5f2

Browse files
committed
Convert errors thrown from interceptor inbound or outbound stream
1 parent b0dda37 commit ae9a5f2

File tree

8 files changed

+225
-3
lines changed

8 files changed

+225
-3
lines changed

Sources/GRPCCore/Call/Client/Internal/ClientResponse+Convenience.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ extension ClientResponse {
7171
} catch let error as RPCError {
7272
// Known error type.
7373
self.accepted = .success(Contents(metadata: contents.metadata, error: error))
74+
} catch let error as any RPCErrorConvertible {
75+
self.accepted = .success(Contents(metadata: contents.metadata, error: RPCError(error)))
7476
} catch {
7577
// Unexpected, but should be handled nonetheless.
7678
self.accepted = .failure(RPCError(code: .unknown, message: String(describing: error)))

Sources/GRPCCore/Call/Client/Internal/ClientStreamExecutor.swift

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,11 @@ internal enum ClientStreamExecutor {
9696
try await stream.write(.metadata(request.metadata))
9797
try await request.producer(.map(into: stream) { .message(try serializer.serialize($0)) })
9898
}.castError(to: RPCError.self) { other in
99-
RPCError(code: .unknown, message: "Write failed.", cause: other)
99+
if let convertible = other as? any RPCErrorConvertible {
100+
RPCError(convertible)
101+
} else {
102+
RPCError(code: .unknown, message: "Write failed.", cause: other)
103+
}
100104
}
101105

102106
switch result {

Sources/GRPCCore/Call/Server/Internal/ServerRPCExecutor.swift

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,11 @@ struct ServerRPCExecutor {
214214
.serializingToRPCResponsePart(into: outbound, with: serializer)
215215
)
216216
}.castError(to: RPCError.self) { error in
217-
RPCError(code: .unknown, message: "", cause: error)
217+
if let convertible = error as? (any RPCErrorConvertible) {
218+
return RPCError(convertible)
219+
} else {
220+
return RPCError(code: .unknown, message: "", cause: error)
221+
}
218222
}
219223

220224
switch result {

Tests/GRPCCoreTests/Call/Client/Internal/ClientRPCExecutorTestSupport/ClientRPCExecutorTestHarness.swift

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,11 @@ struct ClientRPCExecutorTestHarness {
132132
try await withThrowingTaskGroup(of: Void.self) { group in
133133
group.addTask {
134134
try await self.serverTransport.listen { stream, context in
135-
try? await self.server.handle(stream: stream)
135+
do {
136+
try await self.server.handle(stream: stream)
137+
} catch {
138+
await stream.outbound.finish(throwing: error)
139+
}
136140
}
137141
}
138142

Tests/GRPCCoreTests/Call/Client/Internal/ClientRPCExecutorTests.swift

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,4 +290,46 @@ final class ClientRPCExecutorTests: XCTestCase {
290290
}
291291
}
292292
}
293+
294+
func testInterceptorProducerErrorConversion() async throws {
295+
struct CustomError: RPCErrorConvertible, Error {
296+
var rpcErrorCode: RPCError.Code { .alreadyExists }
297+
var rpcErrorMessage: String { "foobar" }
298+
var rpcErrorMetadata: Metadata { ["error": "yes"] }
299+
}
300+
301+
let tester = ClientRPCExecutorTestHarness(
302+
server: .echo,
303+
interceptors: [.throwInProducer(CustomError())]
304+
)
305+
306+
try await tester.unary(request: ClientRequest(message: [])) { response in
307+
XCTAssertThrowsError(ofType: RPCError.self, try response.message) { error in
308+
XCTAssertEqual(error.code, .alreadyExists)
309+
XCTAssertEqual(error.message, "foobar")
310+
XCTAssertEqual(error.metadata, ["error": "yes"])
311+
}
312+
}
313+
}
314+
315+
func testInterceptorBodyPartsErrorConversion() async throws {
316+
struct CustomError: RPCErrorConvertible, Error {
317+
var rpcErrorCode: RPCError.Code { .alreadyExists }
318+
var rpcErrorMessage: String { "foobar" }
319+
var rpcErrorMetadata: Metadata { ["error": "yes"] }
320+
}
321+
322+
let tester = ClientRPCExecutorTestHarness(
323+
server: .echo,
324+
interceptors: [.throwInBodyParts(CustomError())]
325+
)
326+
327+
try await tester.unary(request: ClientRequest(message: [])) { response in
328+
XCTAssertThrowsError(ofType: RPCError.self, try response.message) { error in
329+
XCTAssertEqual(error.code, .alreadyExists)
330+
XCTAssertEqual(error.message, "foobar")
331+
XCTAssertEqual(error.metadata, ["error": "yes"])
332+
}
333+
}
334+
}
293335
}

Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTests.swift

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,4 +394,48 @@ final class ServerRPCExecutorTests: XCTestCase {
394394
XCTAssertEqual(parts, [.status(status, metadata)])
395395
}
396396
}
397+
398+
func testInterceptorProducerErrorConversion() async throws {
399+
struct CustomError: RPCErrorConvertible, Error {
400+
var rpcErrorCode: RPCError.Code { .alreadyExists }
401+
var rpcErrorMessage: String { "foobar" }
402+
var rpcErrorMetadata: Metadata { ["error": "yes"] }
403+
}
404+
405+
let harness = ServerRPCExecutorTestHarness(
406+
interceptors: [.throwInProducer(CustomError(), after: .milliseconds(10))]
407+
)
408+
try await harness.execute(handler: .echo) { inbound in
409+
try await inbound.write(.metadata(["foo": "bar"]))
410+
try await inbound.write(.message([0]))
411+
try await Task.sleep(for: .milliseconds(50))
412+
try await inbound.write(.message([1]))
413+
await inbound.finish()
414+
} consumer: { outbound in
415+
let parts = try await outbound.collect()
416+
let status = Status(code: .alreadyExists, message: "foobar")
417+
let metadata: Metadata = ["error": "yes"]
418+
XCTAssertEqual(parts, [.metadata(["foo": "bar"]), .message([0]), .status(status, metadata)])
419+
}
420+
}
421+
422+
func testInterceptorMessagesErrorConversion() async throws {
423+
struct CustomError: RPCErrorConvertible, Error {
424+
var rpcErrorCode: RPCError.Code { .alreadyExists }
425+
var rpcErrorMessage: String { "foobar" }
426+
var rpcErrorMetadata: Metadata { ["error": "yes"] }
427+
}
428+
429+
let harness = ServerRPCExecutorTestHarness(interceptors: [.throwInMessageSequence(CustomError())])
430+
try await harness.execute(handler: .echo) { inbound in
431+
try await inbound.write(.metadata(["foo": "bar"]))
432+
try await inbound.write(.message([0])) // the sequence throws instantly, this should not arrive
433+
await inbound.finish()
434+
} consumer: { outbound in
435+
let parts = try await outbound.collect()
436+
let status = Status(code: .alreadyExists, message: "foobar")
437+
let metadata: Metadata = ["error": "yes"]
438+
XCTAssertEqual(parts, [.metadata(["foo": "bar"]), .status(status, metadata)])
439+
}
440+
}
397441
}

Tests/GRPCCoreTests/Test Utilities/Call/Client/ClientInterceptors.swift

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,13 @@ extension ClientInterceptor where Self == RejectAllClientInterceptor {
2626
return RejectAllClientInterceptor(throw: error)
2727
}
2828

29+
static func throwInBodyParts(_ error: any Error) -> Self {
30+
return RejectAllClientInterceptor(throwInBodyParts: error)
31+
}
32+
33+
static func throwInProducer(_ error: any Error) -> Self {
34+
return RejectAllClientInterceptor(throwInProducer: error)
35+
}
2936
}
3037

3138
@available(gRPCSwift 2.0, *)
@@ -43,6 +50,10 @@ struct RejectAllClientInterceptor: ClientInterceptor {
4350
case `throw`(any Error)
4451
/// Reject the RPC with a given error.
4552
case reject(RPCError)
53+
/// Throw an error in the body parts sequence.
54+
case throwInBodyParts(any Error)
55+
/// Throw an error in the message producer closure.
56+
case throwInProducer(any Error)
4657
}
4758

4859
let mode: Mode
@@ -55,6 +66,14 @@ struct RejectAllClientInterceptor: ClientInterceptor {
5566
self.mode = .reject(error)
5667
}
5768

69+
init(throwInBodyParts error: any Error) {
70+
self.mode = .throwInBodyParts(error)
71+
}
72+
73+
init(throwInProducer error: any Error) {
74+
self.mode = .throwInProducer(error)
75+
}
76+
5877
func intercept<Input: Sendable, Output: Sendable>(
5978
request: StreamingClientRequest<Input>,
6079
context: ClientContext,
@@ -68,6 +87,29 @@ struct RejectAllClientInterceptor: ClientInterceptor {
6887
throw error
6988
case .reject(let error):
7089
return StreamingClientResponse(error: error)
90+
case .throwInBodyParts(let error):
91+
var response = try await next(request, context)
92+
switch response.accepted {
93+
case .success(var success):
94+
let stream = AsyncThrowingStream<StreamingClientResponse<Output>.Contents.BodyPart, any Error>.makeStream()
95+
stream.continuation.finish(throwing: error)
96+
97+
success.bodyParts = RPCAsyncSequence(wrapping: stream.stream)
98+
response.accepted = .success(success)
99+
return response
100+
case .failure:
101+
return response
102+
}
103+
case .throwInProducer(let error):
104+
let wrappedProducer = request.producer
105+
106+
var request = request
107+
request.producer = { writer in
108+
try await wrappedProducer(writer)
109+
throw error
110+
}
111+
112+
return try await next(request, context)
71113
}
72114
}
73115
}

Tests/GRPCCoreTests/Test Utilities/Call/Server/ServerInterceptors.swift

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,14 @@ extension ServerInterceptor where Self == RejectAllServerInterceptor {
2525
static func throwError(_ error: any Error) -> Self {
2626
RejectAllServerInterceptor(throw: error)
2727
}
28+
29+
static func throwInProducer(_ error: any Error, after duration: Duration) -> Self {
30+
RejectAllServerInterceptor(throwInProducer: error, after: duration)
31+
}
32+
33+
static func throwInMessageSequence(_ error: any Error) -> Self {
34+
RejectAllServerInterceptor(throwInMessageSequence: error)
35+
}
2836
}
2937

3038
@available(gRPCSwift 2.0, *)
@@ -42,6 +50,16 @@ struct RejectAllServerInterceptor: ServerInterceptor {
4250
case `throw`(any Error)
4351
/// Reject the RPC with a given error.
4452
case reject(RPCError)
53+
/// Throw in the producer closure returned.
54+
case throwInProducer(any Error, after: Duration)
55+
/// Throw in the async sequence that stream inbound messages.
56+
case throwInMessageSequence(any Error)
57+
}
58+
59+
private enum TimeoutResult {
60+
case `throw`(any Error)
61+
case cancelled
62+
case result(Metadata)
4563
}
4664

4765
let mode: Mode
@@ -54,6 +72,14 @@ struct RejectAllServerInterceptor: ServerInterceptor {
5472
self.mode = .reject(error)
5573
}
5674

75+
init(throwInProducer error: any Error, after duration: Duration) {
76+
self.mode = .throwInProducer(error, after: duration)
77+
}
78+
79+
init(throwInMessageSequence error: any Error) {
80+
self.mode = .throwInMessageSequence(error)
81+
}
82+
5783
func intercept<Input: Sendable, Output: Sendable>(
5884
request: StreamingServerRequest<Input>,
5985
context: ServerContext,
@@ -67,6 +93,60 @@ struct RejectAllServerInterceptor: ServerInterceptor {
6793
throw error
6894
case .reject(let error):
6995
return StreamingServerResponse(error: error)
96+
case .throwInProducer(let error, let duration):
97+
var response = try await next(request, context)
98+
switch response.accepted {
99+
case .success(var success):
100+
let wrappedProducer = success.producer
101+
success.producer = { writer in
102+
let result: Result<Metadata, any Error> = await withTaskGroup(of: TimeoutResult.self) { group in
103+
group.addTask {
104+
do {
105+
try await Task.sleep(for: duration, tolerance: .nanoseconds(1))
106+
} catch {
107+
return .cancelled
108+
}
109+
return .throw(error)
110+
}
111+
112+
group.addTask {
113+
do {
114+
return .result(try await wrappedProducer(writer))
115+
} catch {
116+
return .throw(error)
117+
}
118+
}
119+
120+
let first = await group.next()!
121+
group.cancelAll()
122+
let second = await group.next()!
123+
124+
switch (first, second) {
125+
case (.throw(let error), _):
126+
return .failure(error)
127+
case (.result(let metadata), _):
128+
return .success(metadata)
129+
case (.cancelled, _):
130+
return .failure(CancellationError())
131+
}
132+
}
133+
134+
return try result.get()
135+
}
136+
137+
response.accepted = .success(success)
138+
return response
139+
case .failure:
140+
return response
141+
}
142+
case .throwInMessageSequence(let error):
143+
let stream = AsyncThrowingStream<Input, any Error>.makeStream()
144+
stream.continuation.finish(throwing: error)
145+
146+
var request = request
147+
request.messages = RPCAsyncSequence(wrapping: stream.stream)
148+
149+
return try await next(request, context)
70150
}
71151
}
72152
}

0 commit comments

Comments
 (0)