Skip to content

Commit 76d4ec1

Browse files
authored
Increase cancellation support for async calls (#1608)
Motivation: The async client calls have limiteed support for cancellation: they support it for the "wrapped" calls and request/response streams but not for metadata/status on the lower level call objects. Modifications: - Add support for Task cancellation on the async call types Result: Better cancellation support
1 parent dbd94fa commit 76d4ec1

File tree

7 files changed

+202
-15
lines changed

7 files changed

+202
-15
lines changed

Sources/GRPC/AsyncAwaitSupport/GRPCAsyncBidirectionalStreamingCall.swift

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,22 @@ public struct GRPCAsyncBidirectionalStreamingCall<Request: Sendable, Response: S
5252

5353
// MARK: - Response Parts
5454

55+
private func withRPCCancellation<R: Sendable>(_ fn: () async throws -> R) async rethrows -> R {
56+
return try await withTaskCancellationHandler(operation: fn) {
57+
self.cancel()
58+
}
59+
}
60+
5561
/// The initial metadata returned from the server.
5662
///
5763
/// - Important: The initial metadata will only be available when the first response has been
5864
/// received. However, it is not necessary for the response to have been consumed before reading
5965
/// this property.
6066
public var initialMetadata: HPACKHeaders {
6167
get async throws {
62-
try await self.responseParts.initialMetadata.get()
68+
try await self.withRPCCancellation {
69+
try await self.responseParts.initialMetadata.get()
70+
}
6371
}
6472
}
6573

@@ -68,7 +76,9 @@ public struct GRPCAsyncBidirectionalStreamingCall<Request: Sendable, Response: S
6876
/// - Important: Awaiting this property will suspend until the responses have been consumed.
6977
public var trailingMetadata: HPACKHeaders {
7078
get async throws {
71-
try await self.responseParts.trailingMetadata.get()
79+
try await self.withRPCCancellation {
80+
try await self.responseParts.trailingMetadata.get()
81+
}
7282
}
7383
}
7484

@@ -78,7 +88,9 @@ public struct GRPCAsyncBidirectionalStreamingCall<Request: Sendable, Response: S
7888
public var status: GRPCStatus {
7989
get async {
8090
// force-try acceptable because any error is encapsulated in a successful GRPCStatus future.
81-
try! await self.responseParts.status.get()
91+
await self.withRPCCancellation {
92+
try! await self.responseParts.status.get()
93+
}
8294
}
8395
}
8496

Sources/GRPC/AsyncAwaitSupport/GRPCAsyncClientStreamingCall.swift

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,19 +43,29 @@ public struct GRPCAsyncClientStreamingCall<Request: Sendable, Response: Sendable
4343

4444
// MARK: - Response Parts
4545

46+
private func withRPCCancellation<R: Sendable>(_ fn: () async throws -> R) async rethrows -> R {
47+
return try await withTaskCancellationHandler(operation: fn) {
48+
self.cancel()
49+
}
50+
}
51+
4652
/// The initial metadata returned from the server.
4753
///
4854
/// - Important: The initial metadata will only be available when the response has been received.
4955
public var initialMetadata: HPACKHeaders {
5056
get async throws {
51-
try await self.responseParts.initialMetadata.get()
57+
return try await self.withRPCCancellation {
58+
try await self.responseParts.initialMetadata.get()
59+
}
5260
}
5361
}
5462

5563
/// The response returned by the server.
5664
public var response: Response {
5765
get async throws {
58-
try await self.responseParts.response.get()
66+
return try await self.withRPCCancellation {
67+
try await self.responseParts.response.get()
68+
}
5969
}
6070
}
6171

@@ -64,7 +74,9 @@ public struct GRPCAsyncClientStreamingCall<Request: Sendable, Response: Sendable
6474
/// - Important: Awaiting this property will suspend until the responses have been consumed.
6575
public var trailingMetadata: HPACKHeaders {
6676
get async throws {
67-
try await self.responseParts.trailingMetadata.get()
77+
return try await self.withRPCCancellation {
78+
try await self.responseParts.trailingMetadata.get()
79+
}
6880
}
6981
}
7082

@@ -74,7 +86,9 @@ public struct GRPCAsyncClientStreamingCall<Request: Sendable, Response: Sendable
7486
public var status: GRPCStatus {
7587
get async {
7688
// force-try acceptable because any error is encapsulated in a successful GRPCStatus future.
77-
try! await self.responseParts.status.get()
89+
return await self.withRPCCancellation {
90+
try! await self.responseParts.status.get()
91+
}
7892
}
7993
}
8094

Sources/GRPC/AsyncAwaitSupport/GRPCAsyncServerStreamingCall.swift

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,22 @@ public struct GRPCAsyncServerStreamingCall<Request: Sendable, Response: Sendable
4949

5050
// MARK: - Response Parts
5151

52+
private func withRPCCancellation<R: Sendable>(_ fn: () async throws -> R) async rethrows -> R {
53+
return try await withTaskCancellationHandler(operation: fn) {
54+
self.cancel()
55+
}
56+
}
57+
5258
/// The initial metadata returned from the server.
5359
///
5460
/// - Important: The initial metadata will only be available when the first response has been
5561
/// received. However, it is not necessary for the response to have been consumed before reading
5662
/// this property.
5763
public var initialMetadata: HPACKHeaders {
5864
get async throws {
59-
try await self.responseParts.initialMetadata.get()
65+
try await self.withRPCCancellation {
66+
try await self.responseParts.initialMetadata.get()
67+
}
6068
}
6169
}
6270

@@ -65,7 +73,9 @@ public struct GRPCAsyncServerStreamingCall<Request: Sendable, Response: Sendable
6573
/// - Important: Awaiting this property will suspend until the responses have been consumed.
6674
public var trailingMetadata: HPACKHeaders {
6775
get async throws {
68-
try await self.responseParts.trailingMetadata.get()
76+
try await self.withRPCCancellation {
77+
try await self.responseParts.trailingMetadata.get()
78+
}
6979
}
7080
}
7181

@@ -75,7 +85,9 @@ public struct GRPCAsyncServerStreamingCall<Request: Sendable, Response: Sendable
7585
public var status: GRPCStatus {
7686
get async {
7787
// force-try acceptable because any error is encapsulated in a successful GRPCStatus future.
78-
try! await self.responseParts.status.get()
88+
await self.withRPCCancellation {
89+
try! await self.responseParts.status.get()
90+
}
7991
}
8092
}
8193

Sources/GRPC/AsyncAwaitSupport/GRPCAsyncUnaryCall.swift

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,20 @@ public struct GRPCAsyncUnaryCall<Request: Sendable, Response: Sendable>: Sendabl
4141

4242
// MARK: - Response Parts
4343

44+
private func withRPCCancellation<R: Sendable>(_ fn: () async throws -> R) async rethrows -> R {
45+
return try await withTaskCancellationHandler(operation: fn) {
46+
self.cancel()
47+
}
48+
}
49+
4450
/// The initial metadata returned from the server.
4551
///
4652
/// - Important: The initial metadata will only be available when the response has been received.
4753
public var initialMetadata: HPACKHeaders {
4854
get async throws {
49-
try await self.responseParts.initialMetadata.get()
55+
try await self.withRPCCancellation {
56+
try await self.responseParts.initialMetadata.get()
57+
}
5058
}
5159
}
5260

@@ -56,7 +64,9 @@ public struct GRPCAsyncUnaryCall<Request: Sendable, Response: Sendable>: Sendabl
5664
/// Callers should rely on the `status` of the call for the canonical outcome.
5765
public var response: Response {
5866
get async throws {
59-
try await self.responseParts.response.get()
67+
try await self.withRPCCancellation {
68+
try await self.responseParts.response.get()
69+
}
6070
}
6171
}
6272

@@ -65,7 +75,9 @@ public struct GRPCAsyncUnaryCall<Request: Sendable, Response: Sendable>: Sendabl
6575
/// - Important: Awaiting this property will suspend until the responses have been consumed.
6676
public var trailingMetadata: HPACKHeaders {
6777
get async throws {
68-
try await self.responseParts.trailingMetadata.get()
78+
try await self.withRPCCancellation {
79+
try await self.responseParts.trailingMetadata.get()
80+
}
6981
}
7082
}
7183

@@ -75,7 +87,9 @@ public struct GRPCAsyncUnaryCall<Request: Sendable, Response: Sendable>: Sendabl
7587
public var status: GRPCStatus {
7688
get async {
7789
// force-try acceptable because any error is encapsulated in a successful GRPCStatus future.
78-
try! await self.responseParts.status.get()
90+
await self.withRPCCancellation {
91+
try! await self.responseParts.status.get()
92+
}
7993
}
8094
}
8195

Tests/GRPCTests/AsyncAwaitSupport/AsyncClientTests.swift

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,4 +414,126 @@ final class AsyncClientCancellationTests: GRPCTestCase {
414414
XCTAssertFalse(error is CancellationError)
415415
}
416416
}
417+
418+
func testCancelUnary() async throws {
419+
// We don't want the RPC to complete before we cancel it so use the never resolving service.
420+
let echo = try self.startServerAndClient(service: NeverResolvingEchoProvider())
421+
422+
do {
423+
let get = echo.makeGetCall(.with { $0.text = "foo bar baz" })
424+
let task = Task { try await get.initialMetadata }
425+
task.cancel()
426+
await XCTAssertThrowsError(try await task.value)
427+
}
428+
429+
do {
430+
let get = echo.makeGetCall(.with { $0.text = "foo bar baz" })
431+
let task = Task { try await get.response }
432+
task.cancel()
433+
await XCTAssertThrowsError(try await task.value)
434+
}
435+
436+
do {
437+
let get = echo.makeGetCall(.with { $0.text = "foo bar baz" })
438+
let task = Task { try await get.trailingMetadata }
439+
task.cancel()
440+
await XCTAssertNoThrowAsync(try await task.value)
441+
}
442+
443+
do {
444+
let get = echo.makeGetCall(.with { $0.text = "foo bar baz" })
445+
let task = Task { await get.status }
446+
task.cancel()
447+
let status = await task.value
448+
XCTAssertEqual(status.code, .cancelled)
449+
}
450+
}
451+
452+
func testCancelClientStreaming() async throws {
453+
// We don't want the RPC to complete before we cancel it so use the never resolving service.
454+
let echo = try self.startServerAndClient(service: NeverResolvingEchoProvider())
455+
456+
do {
457+
let collect = echo.makeCollectCall()
458+
let task = Task { try await collect.initialMetadata }
459+
task.cancel()
460+
await XCTAssertThrowsError(try await task.value)
461+
}
462+
463+
do {
464+
let collect = echo.makeCollectCall()
465+
let task = Task { try await collect.response }
466+
task.cancel()
467+
await XCTAssertThrowsError(try await task.value)
468+
}
469+
470+
do {
471+
let collect = echo.makeCollectCall()
472+
let task = Task { try await collect.trailingMetadata }
473+
task.cancel()
474+
await XCTAssertNoThrowAsync(try await task.value)
475+
}
476+
477+
do {
478+
let collect = echo.makeCollectCall()
479+
let task = Task { await collect.status }
480+
task.cancel()
481+
let status = await task.value
482+
XCTAssertEqual(status.code, .cancelled)
483+
}
484+
}
485+
486+
func testCancelServerStreaming() async throws {
487+
// We don't want the RPC to complete before we cancel it so use the never resolving service.
488+
let echo = try self.startServerAndClient(service: NeverResolvingEchoProvider())
489+
490+
do {
491+
let expand = echo.makeExpandCall(.with { $0.text = "foo bar baz" })
492+
let task = Task { try await expand.initialMetadata }
493+
task.cancel()
494+
await XCTAssertThrowsError(try await task.value)
495+
}
496+
497+
do {
498+
let expand = echo.makeExpandCall(.with { $0.text = "foo bar baz" })
499+
let task = Task { try await expand.trailingMetadata }
500+
task.cancel()
501+
await XCTAssertNoThrowAsync(try await task.value)
502+
}
503+
504+
do {
505+
let expand = echo.makeExpandCall(.with { $0.text = "foo bar baz" })
506+
let task = Task { await expand.status }
507+
task.cancel()
508+
let status = await task.value
509+
XCTAssertEqual(status.code, .cancelled)
510+
}
511+
}
512+
513+
func testCancelBidirectionalStreaming() async throws {
514+
// We don't want the RPC to complete before we cancel it so use the never resolving service.
515+
let echo = try self.startServerAndClient(service: NeverResolvingEchoProvider())
516+
517+
do {
518+
let update = echo.makeUpdateCall()
519+
let task = Task { try await update.initialMetadata }
520+
task.cancel()
521+
await XCTAssertThrowsError(try await task.value)
522+
}
523+
524+
do {
525+
let update = echo.makeUpdateCall()
526+
let task = Task { try await update.trailingMetadata }
527+
task.cancel()
528+
await XCTAssertNoThrowAsync(try await task.value)
529+
}
530+
531+
do {
532+
let update = echo.makeUpdateCall()
533+
let task = Task { await update.status }
534+
task.cancel()
535+
let status = await task.value
536+
XCTAssertEqual(status.code, .cancelled)
537+
}
538+
}
417539
}

Tests/GRPCTests/AsyncAwaitSupport/XCTest+AsyncAwait.swift

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,19 @@ internal func XCTAssertThrowsError<T>(
3030
}
3131
}
3232

33+
@available(macOS 12, iOS 15, tvOS 15, watchOS 8, *)
34+
internal func XCTAssertNoThrowAsync<T>(
35+
_ expression: @autoclosure () async throws -> T,
36+
file: StaticString = #filePath,
37+
line: UInt = #line
38+
) async {
39+
do {
40+
_ = try await expression()
41+
} catch {
42+
XCTFail("Expression throw error '\(error)'", file: file, line: line)
43+
}
44+
}
45+
3346
private enum TaskResult<Result> {
3447
case operation(Result)
3548
case cancellation

Tests/GRPCTests/InterceptedRPCCancellationTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ final class InterceptedRPCCancellationTests: GRPCTestCase {
3333
}
3434

3535
// Interceptor checks that a "magic" header is present.
36-
let serverInterceptors = EchoServerInterceptors(MagicRequiredServerInterceptor.init)
36+
let serverInterceptors = EchoServerInterceptors({ MagicRequiredServerInterceptor() })
3737
let server = try Server.insecure(group: group)
3838
.withLogger(self.serverLogger)
3939
.withServiceProviders([EchoProvider(interceptors: serverInterceptors)])

0 commit comments

Comments
 (0)