diff --git a/Sources/GRPCCore/Call/Client/Internal/ClientStreamExecutor.swift b/Sources/GRPCCore/Call/Client/Internal/ClientStreamExecutor.swift index 749969b2d..472639835 100644 --- a/Sources/GRPCCore/Call/Client/Internal/ClientStreamExecutor.swift +++ b/Sources/GRPCCore/Call/Client/Internal/ClientStreamExecutor.swift @@ -95,9 +95,9 @@ internal enum ClientStreamExecutor { switch result { case .success: - stream.finish() + await stream.finish() case .failure(let error): - stream.finish(throwing: error) + await stream.finish(throwing: error) } } diff --git a/Sources/GRPCCore/Call/Server/Internal/ServerRPCExecutor.swift b/Sources/GRPCCore/Call/Server/Internal/ServerRPCExecutor.swift index f2261f1fc..49e3b713e 100644 --- a/Sources/GRPCCore/Call/Server/Internal/ServerRPCExecutor.swift +++ b/Sources/GRPCCore/Call/Server/Internal/ServerRPCExecutor.swift @@ -58,7 +58,7 @@ struct ServerRPCExecutor { // Stream can't be handled; write an error status and close. let status = Status(code: Status.Code(error.code), message: error.message) try? await stream.outbound.write(.status(status, error.metadata)) - stream.outbound.finish() + await stream.outbound.finish() } } @@ -231,7 +231,7 @@ struct ServerRPCExecutor { } try? await outbound.write(.status(status, metadata)) - outbound.finish() + await outbound.finish() } @inlinable diff --git a/Sources/GRPCCore/Call/Server/RPCRouter.swift b/Sources/GRPCCore/Call/Server/RPCRouter.swift index e6670d523..b6ae3f074 100644 --- a/Sources/GRPCCore/Call/Server/RPCRouter.swift +++ b/Sources/GRPCCore/Call/Server/RPCRouter.swift @@ -155,7 +155,7 @@ extension RPCRouter { // If this throws then the stream must be closed which we can't do anything about, so ignore // any error. try? await stream.outbound.write(.status(.rpcNotImplemented, [:])) - stream.outbound.finish() + await stream.outbound.finish() } } } diff --git a/Sources/GRPCCore/Streaming/RPCWriter+Closable.swift b/Sources/GRPCCore/Streaming/RPCWriter+Closable.swift index 7462766b7..dda689464 100644 --- a/Sources/GRPCCore/Streaming/RPCWriter+Closable.swift +++ b/Sources/GRPCCore/Streaming/RPCWriter+Closable.swift @@ -55,8 +55,8 @@ extension RPCWriter { /// All writes after ``finish()`` has been called should result in an error /// being thrown. @inlinable - public func finish() { - self.writer.finish() + public func finish() async { + await self.writer.finish() } /// Indicate to the writer that no more writes are to be accepted because an error occurred. @@ -64,8 +64,8 @@ extension RPCWriter { /// All writes after ``finish(throwing:)`` has been called should result in an error /// being thrown. @inlinable - public func finish(throwing error: any Error) { - self.writer.finish(throwing: error) + public func finish(throwing error: any Error) async { + await self.writer.finish(throwing: error) } } } diff --git a/Sources/GRPCCore/Streaming/RPCWriterProtocol.swift b/Sources/GRPCCore/Streaming/RPCWriterProtocol.swift index b2607f733..5841f5802 100644 --- a/Sources/GRPCCore/Streaming/RPCWriterProtocol.swift +++ b/Sources/GRPCCore/Streaming/RPCWriterProtocol.swift @@ -57,13 +57,13 @@ public protocol ClosableRPCWriterProtocol: RPCWriterProtocol { /// /// All writes after ``finish()`` has been called should result in an error /// being thrown. - func finish() + func finish() async /// Indicate to the writer that no more writes are to be accepted because an error occurred. /// /// All writes after ``finish(throwing:)`` has been called should result in an error /// being thrown. - func finish(throwing error: any Error) + func finish(throwing error: any Error) async } @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) diff --git a/Sources/GRPCInProcessTransport/InProcessClientTransport.swift b/Sources/GRPCInProcessTransport/InProcessClientTransport.swift index aded232a2..ddbb4a8bd 100644 --- a/Sources/GRPCInProcessTransport/InProcessClientTransport.swift +++ b/Sources/GRPCInProcessTransport/InProcessClientTransport.swift @@ -179,8 +179,8 @@ public final class InProcessClientTransport: ClientTransport { } for (clientStream, serverStream) in openStreams { - clientStream.outbound.finish(throwing: CancellationError()) - serverStream.outbound.finish(throwing: CancellationError()) + await clientStream.outbound.finish(throwing: CancellationError()) + await serverStream.outbound.finish(throwing: CancellationError()) } } @@ -265,7 +265,7 @@ public final class InProcessClientTransport: ClientTransport { try Task.checkCancellation() } - let streamID = try self.state.withLock { state in + let acceptStream: Result = self.state.withLock { state in switch state { case .unconnected: // The state cannot be unconnected because if it was, then the above @@ -281,56 +281,64 @@ public final class InProcessClientTransport: ClientTransport { connectedState.openStreams[streamID] = (clientStream, serverStream) connectedState.nextStreamID += 1 state = .connected(connectedState) + return .success(streamID) } catch let acceptStreamError as RPCError { - serverStream.outbound.finish(throwing: acceptStreamError) - clientStream.outbound.finish(throwing: acceptStreamError) - throw acceptStreamError + return .failure(acceptStreamError) } catch { - serverStream.outbound.finish(throwing: error) - clientStream.outbound.finish(throwing: error) - throw RPCError(code: .unknown, message: "Unknown error: \(error).") + return .failure(RPCError(code: .unknown, message: "Unknown error: \(error).")) } - return streamID case .closed: - let error = RPCError( - code: .failedPrecondition, - message: "The client transport is closed." - ) - serverStream.outbound.finish(throwing: error) - clientStream.outbound.finish(throwing: error) - throw error + let error = RPCError(code: .failedPrecondition, message: "The client transport is closed.") + return .failure(error) } } - defer { - clientStream.outbound.finish() - - let maybeEndContinuation = self.state.withLock { state in - switch state { - case .unconnected: - // The state cannot be unconnected at this point, because if we made - // it this far, it's because the transport was connected. - // Once connected, it's impossible to transition back to unconnected, - // so this is an invalid state. - fatalError("Invalid state") - case .connected(var connectedState): - connectedState.openStreams.removeValue(forKey: streamID) - state = .connected(connectedState) - case .closed(var closedState): - closedState.openStreams.removeValue(forKey: streamID) - state = .closed(closedState) - if closedState.openStreams.isEmpty { - // This was the last open stream: signal the closure of the client. - return closedState.signalEndContinuation - } - } - return nil + switch acceptStream { + case .success(let streamID): + let streamHandlingResult: Result + do { + let result = try await closure(clientStream) + streamHandlingResult = .success(result) + } catch { + streamHandlingResult = .failure(error) } - maybeEndContinuation?.finish() + + await clientStream.outbound.finish() + self.removeStream(id: streamID) + + return try streamHandlingResult.get() + + case .failure(let error): + await serverStream.outbound.finish(throwing: error) + await clientStream.outbound.finish(throwing: error) + throw error } + } - return try await closure(clientStream) + private func removeStream(id streamID: Int) { + let maybeEndContinuation = self.state.withLock { state in + switch state { + case .unconnected: + // The state cannot be unconnected at this point, because if we made + // it this far, it's because the transport was connected. + // Once connected, it's impossible to transition back to unconnected, + // so this is an invalid state. + fatalError("Invalid state") + case .connected(var connectedState): + connectedState.openStreams.removeValue(forKey: streamID) + state = .connected(connectedState) + case .closed(var closedState): + closedState.openStreams.removeValue(forKey: streamID) + state = .closed(closedState) + if closedState.openStreams.isEmpty { + // This was the last open stream: signal the closure of the client. + return closedState.signalEndContinuation + } + } + return nil + } + maybeEndContinuation?.finish() } /// Returns the execution configuration for a given method. diff --git a/Tests/GRPCCoreTests/Call/Client/Internal/ClientRPCExecutorTestSupport/ClientRPCExecutorTestHarness+ServerBehavior.swift b/Tests/GRPCCoreTests/Call/Client/Internal/ClientRPCExecutorTestSupport/ClientRPCExecutorTestHarness+ServerBehavior.swift index 4f0ab1b14..0c2ab936f 100644 --- a/Tests/GRPCCoreTests/Call/Client/Internal/ClientRPCExecutorTestSupport/ClientRPCExecutorTestHarness+ServerBehavior.swift +++ b/Tests/GRPCCoreTests/Call/Client/Internal/ClientRPCExecutorTestSupport/ClientRPCExecutorTestHarness+ServerBehavior.swift @@ -74,7 +74,7 @@ extension ClientRPCExecutorTestHarness.ServerStreamHandler { try await stream.outbound.write(contentsOf: response) try await stream.outbound.write(.status(Status(code: .ok, message: ""), [:])) - stream.outbound.finish() + await stream.outbound.finish() } } @@ -90,7 +90,7 @@ extension ClientRPCExecutorTestHarness.ServerStreamHandler { // All error codes are valid status codes, '!' is safe. let status = Status(code: Status.Code(error.code), message: error.message) try await stream.outbound.write(.status(status, error.metadata)) - stream.outbound.finish() + await stream.outbound.finish() } } @@ -99,7 +99,7 @@ extension ClientRPCExecutorTestHarness.ServerStreamHandler { XCTFail("Server accepted unexpected stream") let status = Status(code: .unknown, message: "Unexpected stream") try await stream.outbound.write(.status(status, [:])) - stream.outbound.finish() + await stream.outbound.finish() } } diff --git a/Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTests.swift b/Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTests.swift index 0393de2c9..5d2aa0029 100644 --- a/Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTests.swift +++ b/Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTests.swift @@ -24,7 +24,7 @@ final class ServerRPCExecutorTests: XCTestCase { let harness = ServerRPCExecutorTestHarness() try await harness.execute(handler: .echo) { inbound in try await inbound.write(.metadata(["foo": "bar"])) - inbound.finish() + await inbound.finish() } consumer: { outbound in let parts = try await outbound.collect() XCTAssertEqual( @@ -42,7 +42,7 @@ final class ServerRPCExecutorTests: XCTestCase { try await harness.execute(handler: .echo) { inbound in try await inbound.write(.metadata(["foo": "bar"])) try await inbound.write(.message([0])) - inbound.finish() + await inbound.finish() } consumer: { outbound in let parts = try await outbound.collect() XCTAssertEqual( @@ -63,7 +63,7 @@ final class ServerRPCExecutorTests: XCTestCase { try await inbound.write(.message([0])) try await inbound.write(.message([1])) try await inbound.write(.message([2])) - inbound.finish() + await inbound.finish() } consumer: { outbound in let parts = try await outbound.collect() XCTAssertEqual( @@ -94,7 +94,7 @@ final class ServerRPCExecutorTests: XCTestCase { } producer: { inbound in try await inbound.write(.metadata(["foo": "bar"])) try await inbound.write(.message(Array("\"hello\"".utf8))) - inbound.finish() + await inbound.finish() } consumer: { outbound in let parts = try await outbound.collect() XCTAssertEqual( @@ -125,7 +125,7 @@ final class ServerRPCExecutorTests: XCTestCase { try await inbound.write(.metadata(["foo": "bar"])) try await inbound.write(.message(Array("\"hello\"".utf8))) try await inbound.write(.message(Array("\"world\"".utf8))) - inbound.finish() + await inbound.finish() } consumer: { outbound in let parts = try await outbound.collect() XCTAssertEqual( @@ -151,7 +151,7 @@ final class ServerRPCExecutorTests: XCTestCase { } } producer: { inbound in try await inbound.write(.metadata(["foo": "bar"])) - inbound.finish() + await inbound.finish() } consumer: { outbound in let parts = try await outbound.collect() XCTAssertEqual( @@ -167,7 +167,7 @@ final class ServerRPCExecutorTests: XCTestCase { func testEmptyInbound() async throws { let harness = ServerRPCExecutorTestHarness() try await harness.execute(handler: .echo) { inbound in - inbound.finish() + await inbound.finish() } consumer: { outbound in let part = try await outbound.collect().first XCTAssertStatus(part) { status, _ in @@ -180,7 +180,7 @@ final class ServerRPCExecutorTests: XCTestCase { let harness = ServerRPCExecutorTestHarness() try await harness.execute(handler: .echo) { inbound in try await inbound.write(.message([0])) - inbound.finish() + await inbound.finish() } consumer: { outbound in let part = try await outbound.collect().first XCTAssertStatus(part) { status, _ in @@ -192,7 +192,7 @@ final class ServerRPCExecutorTests: XCTestCase { func testInboundStreamThrows() async throws { let harness = ServerRPCExecutorTestHarness() try await harness.execute(handler: .echo) { inbound in - inbound.finish(throwing: RPCError(code: .aborted, message: "")) + await inbound.finish(throwing: RPCError(code: .aborted, message: "")) } consumer: { outbound in let part = try await outbound.collect().first XCTAssertStatus(part) { status, _ in @@ -206,7 +206,7 @@ final class ServerRPCExecutorTests: XCTestCase { let harness = ServerRPCExecutorTestHarness() try await harness.execute(handler: .throwing(SomeError())) { inbound in try await inbound.write(.metadata([:])) - inbound.finish() + await inbound.finish() } consumer: { outbound in let part = try await outbound.collect().first XCTAssertStatus(part) { status, _ in @@ -220,7 +220,7 @@ final class ServerRPCExecutorTests: XCTestCase { let harness = ServerRPCExecutorTestHarness() try await harness.execute(handler: .throwing(error)) { inbound in try await inbound.write(.metadata([:])) - inbound.finish() + await inbound.finish() } consumer: { outbound in let part = try await outbound.collect().first XCTAssertStatus(part) { status, metadata in @@ -247,7 +247,7 @@ final class ServerRPCExecutorTests: XCTestCase { return ServerResponse.Stream(error: RPCError(code: .failedPrecondition, message: "")) } producer: { inbound in try await inbound.write(.metadata(["grpc-timeout": "1000n"])) - inbound.finish() + await inbound.finish() } consumer: { outbound in let part = try await outbound.collect().first XCTAssertStatus(part) { status, _ in @@ -277,7 +277,7 @@ final class ServerRPCExecutorTests: XCTestCase { ) } producer: { inbound in try await inbound.write(.metadata([:])) - inbound.finish() + await inbound.finish() } consumer: { outbound in let part = try await outbound.collect().first XCTAssertStatus(part) { status, metadata in @@ -302,7 +302,7 @@ final class ServerRPCExecutorTests: XCTestCase { try await harness.execute(handler: .echo) { inbound in try await inbound.write(.metadata([:])) - inbound.finish() + await inbound.finish() } consumer: { outbound in let parts = try await outbound.collect() XCTAssertEqual(parts, [.metadata([:]), .status(.ok, [:])]) @@ -327,7 +327,7 @@ final class ServerRPCExecutorTests: XCTestCase { try await harness.execute(handler: .echo) { inbound in try await inbound.write(.metadata([:])) - inbound.finish() + await inbound.finish() } consumer: { outbound in let parts = try await outbound.collect() XCTAssertEqual(parts, [.status(Status(code: .unavailable, message: ""), [:])]) @@ -345,7 +345,7 @@ final class ServerRPCExecutorTests: XCTestCase { try await harness.execute(handler: .echo) { inbound in try await inbound.write(.metadata([:])) - inbound.finish() + await inbound.finish() } consumer: { outbound in let parts = try await outbound.collect() XCTAssertEqual(parts, [.status(Status(code: .unavailable, message: "Unavailable"), [:])]) diff --git a/Tests/GRPCCoreTests/GRPCServerTests.swift b/Tests/GRPCCoreTests/GRPCServerTests.swift index f5afcb2a1..d8771d81e 100644 --- a/Tests/GRPCCoreTests/GRPCServerTests.swift +++ b/Tests/GRPCCoreTests/GRPCServerTests.swift @@ -55,7 +55,7 @@ final class GRPCServerTests: XCTestCase { ) { stream in try await stream.outbound.write(.metadata([:])) try await stream.outbound.write(.message([3, 1, 4, 1, 5])) - stream.outbound.finish() + await stream.outbound.finish() var responseParts = stream.inbound.makeAsyncIterator() let metadata = try await responseParts.next() @@ -86,7 +86,7 @@ final class GRPCServerTests: XCTestCase { try await stream.outbound.write(.message([4])) try await stream.outbound.write(.message([1])) try await stream.outbound.write(.message([5])) - stream.outbound.finish() + await stream.outbound.finish() var responseParts = stream.inbound.makeAsyncIterator() let metadata = try await responseParts.next() @@ -113,7 +113,7 @@ final class GRPCServerTests: XCTestCase { ) { stream in try await stream.outbound.write(.metadata([:])) try await stream.outbound.write(.message([3, 1, 4, 1, 5])) - stream.outbound.finish() + await stream.outbound.finish() var responseParts = stream.inbound.makeAsyncIterator() let metadata = try await responseParts.next() @@ -144,7 +144,7 @@ final class GRPCServerTests: XCTestCase { for byte in [3, 1, 4, 1, 5] as [UInt8] { try await stream.outbound.write(.message([byte])) } - stream.outbound.finish() + await stream.outbound.finish() var responseParts = stream.inbound.makeAsyncIterator() let metadata = try await responseParts.next() @@ -172,7 +172,7 @@ final class GRPCServerTests: XCTestCase { options: .defaults ) { stream in try await stream.outbound.write(.metadata([:])) - stream.outbound.finish() + await stream.outbound.finish() var responseParts = stream.inbound.makeAsyncIterator() let status = try await responseParts.next() @@ -194,7 +194,7 @@ final class GRPCServerTests: XCTestCase { ) { stream in try await stream.outbound.write(.metadata([:])) try await stream.outbound.write(.message([i])) - stream.outbound.finish() + await stream.outbound.finish() var responseParts = stream.inbound.makeAsyncIterator() let metadata = try await responseParts.next() @@ -231,7 +231,7 @@ final class GRPCServerTests: XCTestCase { options: .defaults ) { stream in try await stream.outbound.write(.metadata([:])) - stream.outbound.finish() + await stream.outbound.finish() let parts = try await stream.inbound.collect() XCTAssertStatus(parts.first) { status, _ in @@ -256,7 +256,7 @@ final class GRPCServerTests: XCTestCase { options: .defaults ) { stream in try await stream.outbound.write(.metadata([:])) - stream.outbound.finish() + await stream.outbound.finish() let parts = try await stream.inbound.collect() XCTAssertStatus(parts.first) { status, _ in @@ -306,7 +306,7 @@ final class GRPCServerTests: XCTestCase { server.beginGracefulShutdown() try await stream.outbound.write(.message([0])) - stream.outbound.finish() + await stream.outbound.finish() let message = try await iterator.next() XCTAssertMessage(message) { XCTAssertEqual($0, [0]) } @@ -368,7 +368,7 @@ final class GRPCServerTests: XCTestCase { ) { stream in try await stream.outbound.write(.metadata([:])) try await stream.outbound.write(.message([0])) - stream.outbound.finish() + await stream.outbound.finish() // Don't need to validate the response, just that the server is running. let parts = try await stream.inbound.collect() XCTAssertEqual(parts.count, 3) diff --git a/Tests/GRPCHTTP2CoreTests/Client/Connection/GRPCChannelTests.swift b/Tests/GRPCHTTP2CoreTests/Client/Connection/GRPCChannelTests.swift index 391e1e624..96a172311 100644 --- a/Tests/GRPCHTTP2CoreTests/Client/Connection/GRPCChannelTests.swift +++ b/Tests/GRPCHTTP2CoreTests/Client/Connection/GRPCChannelTests.swift @@ -371,7 +371,7 @@ final class GRPCChannelTests: XCTestCase { switch state { case .shutdown: // Happens when shutting-down has been initiated, so finish the RPC. - stream.outbound.finish() + await stream.outbound.finish() let part2 = try await iterator.next() switch part2 { @@ -444,7 +444,7 @@ final class GRPCChannelTests: XCTestCase { group.addTask { try await channel.withStream(descriptor: .echoGet, options: .defaults) { stream in try await stream.outbound.write(.metadata([:])) - stream.outbound.finish() + await stream.outbound.finish() for try await part in stream.inbound { switch part { @@ -824,7 +824,7 @@ extension GRPCChannel { options: .defaults ) { stream in try await stream.outbound.write(.metadata([:])) - stream.outbound.finish() + await stream.outbound.finish() for try await part in stream.inbound { switch part { diff --git a/Tests/GRPCInProcessTransportTests/InProcessClientTransportTests.swift b/Tests/GRPCInProcessTransportTests/InProcessClientTransportTests.swift index 1209bc9d1..d579ca532 100644 --- a/Tests/GRPCInProcessTransportTests/InProcessClientTransportTests.swift +++ b/Tests/GRPCInProcessTransportTests/InProcessClientTransportTests.swift @@ -163,7 +163,7 @@ final class InProcessClientTransportTests: XCTestCase { options: .defaults ) { stream in try await stream.outbound.write(.message([1])) - stream.outbound.finish() + await stream.outbound.finish() let receivedMessages = try await stream.inbound.reduce(into: []) { $0.append($1) } XCTAssertEqual(receivedMessages, [.message([42])]) @@ -174,7 +174,7 @@ final class InProcessClientTransportTests: XCTestCase { try await server.listen { stream in let receivedMessages = try? await stream.inbound.reduce(into: []) { $0.append($1) } try? await stream.outbound.write(RPCResponsePart.message([42])) - stream.outbound.finish() + await stream.outbound.finish() XCTAssertEqual(receivedMessages, [.message([1])]) }