Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,9 @@ internal enum ClientStreamExecutor {

switch result {
case .success:
stream.finish()
await stream.finish()
case .failure(let error):
stream.finish(throwing: error)
await stream.finish(throwing: error)
}
}

Expand Down
4 changes: 2 additions & 2 deletions Sources/GRPCCore/Call/Server/Internal/ServerRPCExecutor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ struct ServerRPCExecutor {
// Stream can't be handled; write an error status and close.
let status = Status(code: Status.Code(error.code), message: error.message)
try? await stream.outbound.write(.status(status, error.metadata))
stream.outbound.finish()
await stream.outbound.finish()
}
}

Expand Down Expand Up @@ -231,7 +231,7 @@ struct ServerRPCExecutor {
}

try? await outbound.write(.status(status, metadata))
outbound.finish()
await outbound.finish()
}

@inlinable
Expand Down
2 changes: 1 addition & 1 deletion Sources/GRPCCore/Call/Server/RPCRouter.swift
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ extension RPCRouter {
// If this throws then the stream must be closed which we can't do anything about, so ignore
// any error.
try? await stream.outbound.write(.status(.rpcNotImplemented, [:]))
stream.outbound.finish()
await stream.outbound.finish()
}
}
}
Expand Down
8 changes: 4 additions & 4 deletions Sources/GRPCCore/Streaming/RPCWriter+Closable.swift
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,17 @@ extension RPCWriter {
/// All writes after ``finish()`` has been called should result in an error
/// being thrown.
@inlinable
public func finish() {
self.writer.finish()
public func finish() async {
await self.writer.finish()
}

/// Indicate to the writer that no more writes are to be accepted because an error occurred.
///
/// All writes after ``finish(throwing:)`` has been called should result in an error
/// being thrown.
@inlinable
public func finish(throwing error: any Error) {
self.writer.finish(throwing: error)
public func finish(throwing error: any Error) async {
await self.writer.finish(throwing: error)
}
}
}
4 changes: 2 additions & 2 deletions Sources/GRPCCore/Streaming/RPCWriterProtocol.swift
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,13 @@ public protocol ClosableRPCWriterProtocol<Element>: RPCWriterProtocol {
///
/// All writes after ``finish()`` has been called should result in an error
/// being thrown.
func finish()
func finish() async

/// Indicate to the writer that no more writes are to be accepted because an error occurred.
///
/// All writes after ``finish(throwing:)`` has been called should result in an error
/// being thrown.
func finish(throwing error: any Error)
func finish(throwing error: any Error) async
}

@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
Expand Down
92 changes: 50 additions & 42 deletions Sources/GRPCInProcessTransport/InProcessClientTransport.swift
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,8 @@ public final class InProcessClientTransport: ClientTransport {
}

for (clientStream, serverStream) in openStreams {
clientStream.outbound.finish(throwing: CancellationError())
serverStream.outbound.finish(throwing: CancellationError())
await clientStream.outbound.finish(throwing: CancellationError())
await serverStream.outbound.finish(throwing: CancellationError())
}
}

Expand Down Expand Up @@ -265,7 +265,7 @@ public final class InProcessClientTransport: ClientTransport {
try Task.checkCancellation()
}

let streamID = try self.state.withLock { state in
let acceptStream: Result<Int, RPCError> = self.state.withLock { state in
switch state {
case .unconnected:
// The state cannot be unconnected because if it was, then the above
Expand All @@ -281,56 +281,64 @@ public final class InProcessClientTransport: ClientTransport {
connectedState.openStreams[streamID] = (clientStream, serverStream)
connectedState.nextStreamID += 1
state = .connected(connectedState)
return .success(streamID)
} catch let acceptStreamError as RPCError {
serverStream.outbound.finish(throwing: acceptStreamError)
clientStream.outbound.finish(throwing: acceptStreamError)
throw acceptStreamError
return .failure(acceptStreamError)
} catch {
serverStream.outbound.finish(throwing: error)
clientStream.outbound.finish(throwing: error)
throw RPCError(code: .unknown, message: "Unknown error: \(error).")
return .failure(RPCError(code: .unknown, message: "Unknown error: \(error)."))
}
return streamID

case .closed:
let error = RPCError(
code: .failedPrecondition,
message: "The client transport is closed."
)
serverStream.outbound.finish(throwing: error)
clientStream.outbound.finish(throwing: error)
throw error
let error = RPCError(code: .failedPrecondition, message: "The client transport is closed.")
return .failure(error)
}
}

defer {
clientStream.outbound.finish()

let maybeEndContinuation = self.state.withLock { state in
switch state {
case .unconnected:
// The state cannot be unconnected at this point, because if we made
// it this far, it's because the transport was connected.
// Once connected, it's impossible to transition back to unconnected,
// so this is an invalid state.
fatalError("Invalid state")
case .connected(var connectedState):
connectedState.openStreams.removeValue(forKey: streamID)
state = .connected(connectedState)
case .closed(var closedState):
closedState.openStreams.removeValue(forKey: streamID)
state = .closed(closedState)
if closedState.openStreams.isEmpty {
// This was the last open stream: signal the closure of the client.
return closedState.signalEndContinuation
}
}
return nil
switch acceptStream {
case .success(let streamID):
let streamHandlingResult: Result<T, any Error>
do {
let result = try await closure(clientStream)
streamHandlingResult = .success(result)
} catch {
streamHandlingResult = .failure(error)
}
maybeEndContinuation?.finish()

await clientStream.outbound.finish()
self.removeStream(id: streamID)

return try streamHandlingResult.get()

case .failure(let error):
await serverStream.outbound.finish(throwing: error)
await clientStream.outbound.finish(throwing: error)
throw error
}
}

return try await closure(clientStream)
private func removeStream(id streamID: Int) {
let maybeEndContinuation = self.state.withLock { state in
switch state {
case .unconnected:
// The state cannot be unconnected at this point, because if we made
// it this far, it's because the transport was connected.
// Once connected, it's impossible to transition back to unconnected,
// so this is an invalid state.
fatalError("Invalid state")
case .connected(var connectedState):
connectedState.openStreams.removeValue(forKey: streamID)
state = .connected(connectedState)
case .closed(var closedState):
closedState.openStreams.removeValue(forKey: streamID)
state = .closed(closedState)
if closedState.openStreams.isEmpty {
// This was the last open stream: signal the closure of the client.
return closedState.signalEndContinuation
}
}
return nil
}
maybeEndContinuation?.finish()
}

/// Returns the execution configuration for a given method.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ extension ClientRPCExecutorTestHarness.ServerStreamHandler {

try await stream.outbound.write(contentsOf: response)
try await stream.outbound.write(.status(Status(code: .ok, message: ""), [:]))
stream.outbound.finish()
await stream.outbound.finish()
}
}

Expand All @@ -90,7 +90,7 @@ extension ClientRPCExecutorTestHarness.ServerStreamHandler {
// All error codes are valid status codes, '!' is safe.
let status = Status(code: Status.Code(error.code), message: error.message)
try await stream.outbound.write(.status(status, error.metadata))
stream.outbound.finish()
await stream.outbound.finish()
}
}

Expand All @@ -99,7 +99,7 @@ extension ClientRPCExecutorTestHarness.ServerStreamHandler {
XCTFail("Server accepted unexpected stream")
let status = Status(code: .unknown, message: "Unexpected stream")
try await stream.outbound.write(.status(status, [:]))
stream.outbound.finish()
await stream.outbound.finish()
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ final class ServerRPCExecutorTests: XCTestCase {
let harness = ServerRPCExecutorTestHarness()
try await harness.execute(handler: .echo) { inbound in
try await inbound.write(.metadata(["foo": "bar"]))
inbound.finish()
await inbound.finish()
} consumer: { outbound in
let parts = try await outbound.collect()
XCTAssertEqual(
Expand All @@ -43,7 +43,7 @@ final class ServerRPCExecutorTests: XCTestCase {
try await harness.execute(handler: .echo) { inbound in
try await inbound.write(.metadata(["foo": "bar"]))
try await inbound.write(.message([0]))
inbound.finish()
await inbound.finish()
} consumer: { outbound in
let parts = try await outbound.collect()
XCTAssertEqual(
Expand All @@ -64,7 +64,7 @@ final class ServerRPCExecutorTests: XCTestCase {
try await inbound.write(.message([0]))
try await inbound.write(.message([1]))
try await inbound.write(.message([2]))
inbound.finish()
await inbound.finish()
} consumer: { outbound in
let parts = try await outbound.collect()
XCTAssertEqual(
Expand Down Expand Up @@ -95,7 +95,7 @@ final class ServerRPCExecutorTests: XCTestCase {
} producer: { inbound in
try await inbound.write(.metadata(["foo": "bar"]))
try await inbound.write(.message(Array("\"hello\"".utf8)))
inbound.finish()
await inbound.finish()
} consumer: { outbound in
let parts = try await outbound.collect()
XCTAssertEqual(
Expand Down Expand Up @@ -126,7 +126,7 @@ final class ServerRPCExecutorTests: XCTestCase {
try await inbound.write(.metadata(["foo": "bar"]))
try await inbound.write(.message(Array("\"hello\"".utf8)))
try await inbound.write(.message(Array("\"world\"".utf8)))
inbound.finish()
await inbound.finish()
} consumer: { outbound in
let parts = try await outbound.collect()
XCTAssertEqual(
Expand All @@ -152,7 +152,7 @@ final class ServerRPCExecutorTests: XCTestCase {
}
} producer: { inbound in
try await inbound.write(.metadata(["foo": "bar"]))
inbound.finish()
await inbound.finish()
} consumer: { outbound in
let parts = try await outbound.collect()
XCTAssertEqual(
Expand All @@ -168,7 +168,7 @@ final class ServerRPCExecutorTests: XCTestCase {
func testEmptyInbound() async throws {
let harness = ServerRPCExecutorTestHarness()
try await harness.execute(handler: .echo) { inbound in
inbound.finish()
await inbound.finish()
} consumer: { outbound in
let part = try await outbound.collect().first
XCTAssertStatus(part) { status, _ in
Expand All @@ -181,7 +181,7 @@ final class ServerRPCExecutorTests: XCTestCase {
let harness = ServerRPCExecutorTestHarness()
try await harness.execute(handler: .echo) { inbound in
try await inbound.write(.message([0]))
inbound.finish()
await inbound.finish()
} consumer: { outbound in
let part = try await outbound.collect().first
XCTAssertStatus(part) { status, _ in
Expand All @@ -193,7 +193,7 @@ final class ServerRPCExecutorTests: XCTestCase {
func testInboundStreamThrows() async throws {
let harness = ServerRPCExecutorTestHarness()
try await harness.execute(handler: .echo) { inbound in
inbound.finish(throwing: RPCError(code: .aborted, message: ""))
await inbound.finish(throwing: RPCError(code: .aborted, message: ""))
} consumer: { outbound in
let part = try await outbound.collect().first
XCTAssertStatus(part) { status, _ in
Expand All @@ -207,7 +207,7 @@ final class ServerRPCExecutorTests: XCTestCase {
let harness = ServerRPCExecutorTestHarness()
try await harness.execute(handler: .throwing(SomeError())) { inbound in
try await inbound.write(.metadata([:]))
inbound.finish()
await inbound.finish()
} consumer: { outbound in
let part = try await outbound.collect().first
XCTAssertStatus(part) { status, _ in
Expand All @@ -221,7 +221,7 @@ final class ServerRPCExecutorTests: XCTestCase {
let harness = ServerRPCExecutorTestHarness()
try await harness.execute(handler: .throwing(error)) { inbound in
try await inbound.write(.metadata([:]))
inbound.finish()
await inbound.finish()
} consumer: { outbound in
let part = try await outbound.collect().first
XCTAssertStatus(part) { status, metadata in
Expand All @@ -248,7 +248,7 @@ final class ServerRPCExecutorTests: XCTestCase {
return ServerResponse.Stream(error: RPCError(code: .failedPrecondition, message: ""))
} producer: { inbound in
try await inbound.write(.metadata(["grpc-timeout": "1000n"]))
inbound.finish()
await inbound.finish()
} consumer: { outbound in
let part = try await outbound.collect().first
XCTAssertStatus(part) { status, _ in
Expand Down Expand Up @@ -278,7 +278,7 @@ final class ServerRPCExecutorTests: XCTestCase {
)
} producer: { inbound in
try await inbound.write(.metadata([:]))
inbound.finish()
await inbound.finish()
} consumer: { outbound in
let part = try await outbound.collect().first
XCTAssertStatus(part) { status, metadata in
Expand All @@ -303,7 +303,7 @@ final class ServerRPCExecutorTests: XCTestCase {

try await harness.execute(handler: .echo) { inbound in
try await inbound.write(.metadata([:]))
inbound.finish()
await inbound.finish()
} consumer: { outbound in
let parts = try await outbound.collect()
XCTAssertEqual(parts, [.metadata([:]), .status(.ok, [:])])
Expand All @@ -328,7 +328,7 @@ final class ServerRPCExecutorTests: XCTestCase {

try await harness.execute(handler: .echo) { inbound in
try await inbound.write(.metadata([:]))
inbound.finish()
await inbound.finish()
} consumer: { outbound in
let parts = try await outbound.collect()
XCTAssertEqual(parts, [.status(Status(code: .unavailable, message: ""), [:])])
Expand All @@ -346,7 +346,7 @@ final class ServerRPCExecutorTests: XCTestCase {

try await harness.execute(handler: .echo) { inbound in
try await inbound.write(.metadata([:]))
inbound.finish()
await inbound.finish()
} consumer: { outbound in
let parts = try await outbound.collect()
XCTAssertEqual(parts, [.status(Status(code: .unavailable, message: "Unavailable"), [:])])
Expand Down
Loading
Loading