diff --git a/Package@swift-6.swift b/Package@swift-6.swift index e7ee68e33..e7b1c268a 100644 --- a/Package@swift-6.swift +++ b/Package@swift-6.swift @@ -251,6 +251,7 @@ extension Target { .nioCore, .nioHTTP2, .nioTLS, + .nioExtras, .cgrpcZlib, .dequeModule, ], diff --git a/Sources/GRPCHTTP2Core/Server/CommonHTTP2ServerTransport.swift b/Sources/GRPCHTTP2Core/Server/CommonHTTP2ServerTransport.swift new file mode 100644 index 000000000..a900b5402 --- /dev/null +++ b/Sources/GRPCHTTP2Core/Server/CommonHTTP2ServerTransport.swift @@ -0,0 +1,254 @@ +/* + * Copyright 2024, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package import GRPCCore +package import NIOCore +package import NIOExtras +private import NIOHTTP2 +private import Synchronization + +/// Provides the common functionality for a `NIO`-based server transport. +/// +/// - SeeAlso: ``HTTP2ListenerFactory``. +@available(macOS 15.0, iOS 18.0, watchOS 11.0, tvOS 18.0, visionOS 2.0, *) +package final class CommonHTTP2ServerTransport< + ListenerFactory: HTTP2ListenerFactory +>: ServerTransport, ListeningServerTransport { + private let eventLoopGroup: any EventLoopGroup + private let address: SocketAddress + private let listeningAddressState: Mutex + private let serverQuiescingHelper: ServerQuiescingHelper + private let factory: ListenerFactory + + private enum State { + case idle(EventLoopPromise) + case listening(EventLoopFuture) + case closedOrInvalidAddress(RuntimeError) + + var listeningAddressFuture: EventLoopFuture { + get throws { + switch self { + case .idle(let eventLoopPromise): + return eventLoopPromise.futureResult + case .listening(let eventLoopFuture): + return eventLoopFuture + case .closedOrInvalidAddress(let runtimeError): + throw runtimeError + } + } + } + + enum OnBound { + case succeedPromise(_ promise: EventLoopPromise, address: SocketAddress) + case failPromise(_ promise: EventLoopPromise, error: RuntimeError) + } + + mutating func addressBound( + _ address: NIOCore.SocketAddress?, + userProvidedAddress: SocketAddress + ) -> OnBound { + switch self { + case .idle(let listeningAddressPromise): + if let address { + self = .listening(listeningAddressPromise.futureResult) + return .succeedPromise(listeningAddressPromise, address: SocketAddress(address)) + } else if userProvidedAddress.virtualSocket != nil { + self = .listening(listeningAddressPromise.futureResult) + return .succeedPromise(listeningAddressPromise, address: userProvidedAddress) + } else { + assertionFailure("Unknown address type") + let invalidAddressError = RuntimeError( + code: .transportError, + message: "Unknown address type returned by transport." + ) + self = .closedOrInvalidAddress(invalidAddressError) + return .failPromise(listeningAddressPromise, error: invalidAddressError) + } + + case .listening, .closedOrInvalidAddress: + fatalError("Invalid state: addressBound should only be called once and when in idle state") + } + } + + enum OnClose { + case failPromise(EventLoopPromise, error: RuntimeError) + case doNothing + } + + mutating func close() -> OnClose { + let serverStoppedError = RuntimeError( + code: .serverIsStopped, + message: """ + There is no listening address bound for this server: there may have been \ + an error which caused the transport to close, or it may have shut down. + """ + ) + + switch self { + case .idle(let listeningAddressPromise): + self = .closedOrInvalidAddress(serverStoppedError) + return .failPromise(listeningAddressPromise, error: serverStoppedError) + + case .listening: + self = .closedOrInvalidAddress(serverStoppedError) + return .doNothing + + case .closedOrInvalidAddress: + return .doNothing + } + } + } + + /// The listening address for this server transport. + /// + /// It is an `async` property because it will only return once the address has been successfully bound. + /// + /// - Throws: A runtime error will be thrown if the address could not be bound or is not bound any + /// longer, because the transport isn't listening anymore. It can also throw if the transport returned an + /// invalid address. + package var listeningAddress: SocketAddress { + get async throws { + try await self.listeningAddressState + .withLock { try $0.listeningAddressFuture } + .get() + } + } + + package init( + address: SocketAddress, + eventLoopGroup: any EventLoopGroup, + quiescingHelper: ServerQuiescingHelper, + listenerFactory: ListenerFactory + ) { + self.eventLoopGroup = eventLoopGroup + self.address = address + + let eventLoop = eventLoopGroup.any() + self.listeningAddressState = Mutex(.idle(eventLoop.makePromise())) + + self.factory = listenerFactory + self.serverQuiescingHelper = quiescingHelper + } + + package func listen( + _ streamHandler: @escaping @Sendable (RPCStream) async -> Void + ) async throws { + defer { + switch self.listeningAddressState.withLock({ $0.close() }) { + case .failPromise(let promise, let error): + promise.fail(error) + case .doNothing: + () + } + } + + let serverChannel = try await self.factory.makeListeningChannel( + eventLoopGroup: self.eventLoopGroup, + address: self.address, + serverQuiescingHelper: self.serverQuiescingHelper + ) + + let action = self.listeningAddressState.withLock { + $0.addressBound( + serverChannel.channel.localAddress, + userProvidedAddress: self.address + ) + } + switch action { + case .succeedPromise(let promise, let address): + promise.succeed(address) + case .failPromise(let promise, let error): + promise.fail(error) + } + + try await serverChannel.executeThenClose { inbound in + try await withThrowingDiscardingTaskGroup { group in + for try await (connectionChannel, streamMultiplexer) in inbound { + group.addTask { + try await self.handleConnection( + connectionChannel, + multiplexer: streamMultiplexer, + streamHandler: streamHandler + ) + } + } + } + } + } + + private func handleConnection( + _ connection: NIOAsyncChannel, + multiplexer: ChannelPipeline.SynchronousOperations.HTTP2StreamMultiplexer, + streamHandler: @escaping @Sendable (RPCStream) async -> Void + ) async throws { + try await connection.executeThenClose { inbound, _ in + await withDiscardingTaskGroup { group in + group.addTask { + do { + for try await _ in inbound {} + } catch { + // We don't want to close the channel if one connection throws. + return + } + } + + do { + for try await (stream, descriptor) in multiplexer.inbound { + group.addTask { + await self.handleStream(stream, handler: streamHandler, descriptor: descriptor) + } + } + } catch { + return + } + } + } + } + + private func handleStream( + _ stream: NIOAsyncChannel, + handler streamHandler: @escaping @Sendable (RPCStream) async -> Void, + descriptor: EventLoopFuture + ) async { + // It's okay to ignore these errors: + // - If we get an error because the http2Stream failed to close, then there's nothing we can do + // - If we get an error because the inner closure threw, then the only possible scenario in which + // that could happen is if methodDescriptor.get() throws - in which case, it means we never got + // the RPC metadata, which means we can't do anything either and it's okay to just kill the stream. + try? await stream.executeThenClose { inbound, outbound in + guard let descriptor = try? await descriptor.get() else { + return + } + + let rpcStream = RPCStream( + descriptor: descriptor, + inbound: RPCAsyncSequence(wrapping: inbound), + outbound: RPCWriter.Closable( + wrapping: ServerConnection.Stream.Outbound( + responseWriter: outbound, + http2Stream: stream + ) + ) + ) + + await streamHandler(rpcStream) + } + } + + package func beginGracefulShutdown() { + self.serverQuiescingHelper.initiateShutdown(promise: nil) + } +} diff --git a/Sources/GRPCHTTP2Core/Server/HTTP2ListenerFactory.swift b/Sources/GRPCHTTP2Core/Server/HTTP2ListenerFactory.swift new file mode 100644 index 000000000..900799a61 --- /dev/null +++ b/Sources/GRPCHTTP2Core/Server/HTTP2ListenerFactory.swift @@ -0,0 +1,35 @@ +/* + * Copyright 2024, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package import NIOCore +package import NIOExtras + +/// A factory to produce `NIOAsyncChannel`s to listen for new HTTP/2 connections. +/// +/// - SeeAlso: ``CommonHTTP2ServerTransport`` +@available(macOS 15.0, iOS 18.0, watchOS 11.0, tvOS 18.0, visionOS 2.0, *) +package protocol HTTP2ListenerFactory: Sendable { + typealias AcceptedChannel = ( + ChannelPipeline.SynchronousOperations.HTTP2ConnectionChannel, + ChannelPipeline.SynchronousOperations.HTTP2StreamMultiplexer + ) + + func makeListeningChannel( + eventLoopGroup: any EventLoopGroup, + address: SocketAddress, + serverQuiescingHelper: ServerQuiescingHelper + ) async throws -> NIOAsyncChannel +} diff --git a/Sources/GRPCHTTP2TransportNIOPosix/HTTP2ServerTransport+Posix.swift b/Sources/GRPCHTTP2TransportNIOPosix/HTTP2ServerTransport+Posix.swift index 43da51ae0..d7bda6350 100644 --- a/Sources/GRPCHTTP2TransportNIOPosix/HTTP2ServerTransport+Posix.swift +++ b/Sources/GRPCHTTP2TransportNIOPosix/HTTP2ServerTransport+Posix.swift @@ -55,108 +55,78 @@ extension HTTP2ServerTransport { /// } /// ``` @available(macOS 15.0, iOS 18.0, watchOS 11.0, tvOS 18.0, visionOS 2.0, *) - public final class Posix: ServerTransport, ListeningServerTransport { - private let address: GRPCHTTP2Core.SocketAddress - private let config: Config - private let eventLoopGroup: MultiThreadedEventLoopGroup - private let serverQuiescingHelper: ServerQuiescingHelper - - private enum State { - case idle(EventLoopPromise) - case listening(EventLoopFuture) - case closedOrInvalidAddress(RuntimeError) - - var listeningAddressFuture: EventLoopFuture { - get throws { - switch self { - case .idle(let eventLoopPromise): - return eventLoopPromise.futureResult - case .listening(let eventLoopFuture): - return eventLoopFuture - case .closedOrInvalidAddress(let runtimeError): - throw runtimeError - } - } - } - - enum OnBound { - case succeedPromise( - _ promise: EventLoopPromise, - address: GRPCHTTP2Core.SocketAddress - ) - case failPromise( - _ promise: EventLoopPromise, - error: RuntimeError - ) - } - - mutating func addressBound( - _ address: NIOCore.SocketAddress?, - userProvidedAddress: GRPCHTTP2Core.SocketAddress - ) -> OnBound { - switch self { - case .idle(let listeningAddressPromise): - if let address { - self = .listening(listeningAddressPromise.futureResult) - return .succeedPromise( - listeningAddressPromise, - address: GRPCHTTP2Core.SocketAddress(address) - ) - - } else if userProvidedAddress.virtualSocket != nil { - self = .listening(listeningAddressPromise.futureResult) - return .succeedPromise(listeningAddressPromise, address: userProvidedAddress) - - } else { - assertionFailure("Unknown address type") - let invalidAddressError = RuntimeError( + public struct Posix: ServerTransport, ListeningServerTransport { + private struct ListenerFactory: HTTP2ListenerFactory { + let config: Config + + func makeListeningChannel( + eventLoopGroup: any EventLoopGroup, + address: GRPCHTTP2Core.SocketAddress, + serverQuiescingHelper: ServerQuiescingHelper + ) async throws -> NIOAsyncChannel { + #if canImport(NIOSSL) + let sslContext: NIOSSLContext? + + switch self.config.transportSecurity.wrapped { + case .plaintext: + sslContext = nil + case .tls(let tlsConfig): + do { + sslContext = try NIOSSLContext(configuration: TLSConfiguration(tlsConfig)) + } catch { + throw RuntimeError( code: .transportError, - message: "Unknown address type returned by transport." + message: "Couldn't create SSL context, check your TLS configuration.", + cause: error ) - self = .closedOrInvalidAddress(invalidAddressError) - return .failPromise(listeningAddressPromise, error: invalidAddressError) } - - case .listening, .closedOrInvalidAddress: - fatalError( - "Invalid state: addressBound should only be called once and when in idle state" - ) } - } - - enum OnClose { - case failPromise( - EventLoopPromise, - error: RuntimeError - ) - case doNothing - } - - mutating func close() -> OnClose { - let serverStoppedError = RuntimeError( - code: .serverIsStopped, - message: """ - There is no listening address bound for this server: there may have been \ - an error which caused the transport to close, or it may have shut down. - """ - ) + #endif - switch self { - case .idle(let listeningAddressPromise): - self = .closedOrInvalidAddress(serverStoppedError) - return .failPromise(listeningAddressPromise, error: serverStoppedError) + let serverChannel = try await ServerBootstrap(group: eventLoopGroup) + .serverChannelOption(.socketOption(.so_reuseaddr), value: 1) + .serverChannelInitializer { channel in + let quiescingHandler = serverQuiescingHelper.makeServerChannelHandler(channel: channel) + return channel.pipeline.addHandler(quiescingHandler) + } + .bind(to: address) { channel in + channel.eventLoop.makeCompletedFuture { + #if canImport(NIOSSL) + if let sslContext { + try channel.pipeline.syncOperations.addHandler( + NIOSSLServerHandler(context: sslContext) + ) + } + #endif + + let requireALPN: Bool + let scheme: Scheme + switch self.config.transportSecurity.wrapped { + case .plaintext: + requireALPN = false + scheme = .http + case .tls(let tlsConfig): + requireALPN = tlsConfig.requireALPN + scheme = .https + } - case .listening: - self = .closedOrInvalidAddress(serverStoppedError) - return .doNothing + return try channel.pipeline.syncOperations.configureGRPCServerPipeline( + channel: channel, + compressionConfig: self.config.compression, + connectionConfig: self.config.connection, + http2Config: self.config.http2, + rpcConfig: self.config.rpc, + requireALPN: requireALPN, + scheme: scheme + ) + } + } - case .closedOrInvalidAddress: - return .doNothing - } + return serverChannel } } - private let listeningAddressState: Mutex + private let underlyingTransport: CommonHTTP2ServerTransport /// The listening address for this server transport. /// @@ -167,9 +137,7 @@ extension HTTP2ServerTransport { /// invalid address. public var listeningAddress: GRPCHTTP2Core.SocketAddress { get async throws { - try await self.listeningAddressState - .withLock { try $0.listeningAddressFuture } - .get() + try await self.underlyingTransport.listeningAddress } } @@ -184,178 +152,24 @@ extension HTTP2ServerTransport { config: Config, eventLoopGroup: MultiThreadedEventLoopGroup = .singletonMultiThreadedEventLoopGroup ) { - self.address = address - self.config = config - self.eventLoopGroup = eventLoopGroup - self.serverQuiescingHelper = ServerQuiescingHelper(group: self.eventLoopGroup) - - let eventLoop = eventLoopGroup.any() - self.listeningAddressState = Mutex(.idle(eventLoop.makePromise())) + let factory = ListenerFactory(config: config) + let helper = ServerQuiescingHelper(group: eventLoopGroup) + self.underlyingTransport = CommonHTTP2ServerTransport( + address: address, + eventLoopGroup: eventLoopGroup, + quiescingHelper: helper, + listenerFactory: factory + ) } public func listen( _ streamHandler: @escaping @Sendable (RPCStream) async -> Void ) async throws { - defer { - switch self.listeningAddressState.withLock({ $0.close() }) { - case .failPromise(let promise, let error): - promise.fail(error) - case .doNothing: - () - } - } - - #if canImport(NIOSSL) - let nioSSLContext: NIOSSLContext? - switch self.config.transportSecurity.wrapped { - case .plaintext: - nioSSLContext = nil - case .tls(let tlsConfig): - do { - nioSSLContext = try NIOSSLContext(configuration: TLSConfiguration(tlsConfig)) - } catch { - throw RuntimeError( - code: .transportError, - message: "Couldn't create SSL context, check your TLS configuration.", - cause: error - ) - } - } - #endif - - let serverChannel = try await ServerBootstrap(group: self.eventLoopGroup) - .serverChannelOption( - ChannelOptions.socketOption(.so_reuseaddr), - value: 1 - ) - .serverChannelInitializer { channel in - let quiescingHandler = self.serverQuiescingHelper.makeServerChannelHandler( - channel: channel - ) - return channel.pipeline.addHandler(quiescingHandler) - } - .bind(to: self.address) { channel in - channel.eventLoop.makeCompletedFuture { - #if canImport(NIOSSL) - if let nioSSLContext { - try channel.pipeline.syncOperations.addHandler( - NIOSSLServerHandler(context: nioSSLContext) - ) - } - #endif - - let requireALPN: Bool - let scheme: Scheme - switch self.config.transportSecurity.wrapped { - case .plaintext: - requireALPN = false - scheme = .http - case .tls(let tlsConfig): - requireALPN = tlsConfig.requireALPN - scheme = .https - } - - return try channel.pipeline.syncOperations.configureGRPCServerPipeline( - channel: channel, - compressionConfig: self.config.compression, - connectionConfig: self.config.connection, - http2Config: self.config.http2, - rpcConfig: self.config.rpc, - requireALPN: requireALPN, - scheme: scheme - ) - } - } - - let action = self.listeningAddressState.withLock { - $0.addressBound( - serverChannel.channel.localAddress, - userProvidedAddress: self.address - ) - } - switch action { - case .succeedPromise(let promise, let address): - promise.succeed(address) - case .failPromise(let promise, let error): - promise.fail(error) - } - - try await serverChannel.executeThenClose { inbound in - try await withThrowingDiscardingTaskGroup { group in - for try await (connectionChannel, streamMultiplexer) in inbound { - group.addTask { - try await self.handleConnection( - connectionChannel, - multiplexer: streamMultiplexer, - streamHandler: streamHandler - ) - } - } - } - } - } - - private func handleConnection( - _ connection: NIOAsyncChannel, - multiplexer: ChannelPipeline.SynchronousOperations.HTTP2StreamMultiplexer, - streamHandler: @escaping @Sendable (RPCStream) async -> Void - ) async throws { - try await connection.executeThenClose { inbound, _ in - await withDiscardingTaskGroup { group in - group.addTask { - do { - for try await _ in inbound {} - } catch { - // We don't want to close the channel if one connection throws. - return - } - } - - do { - for try await (stream, descriptor) in multiplexer.inbound { - group.addTask { - await self.handleStream(stream, handler: streamHandler, descriptor: descriptor) - } - } - } catch { - return - } - } - } - } - - private func handleStream( - _ stream: NIOAsyncChannel, - handler streamHandler: @escaping @Sendable (RPCStream) async -> Void, - descriptor: EventLoopFuture - ) async { - // It's okay to ignore these errors: - // - If we get an error because the http2Stream failed to close, then there's nothing we can do - // - If we get an error because the inner closure threw, then the only possible scenario in which - // that could happen is if methodDescriptor.get() throws - in which case, it means we never got - // the RPC metadata, which means we can't do anything either and it's okay to just kill the stream. - try? await stream.executeThenClose { inbound, outbound in - guard let descriptor = try? await descriptor.get() else { - return - } - - let rpcStream = RPCStream( - descriptor: descriptor, - inbound: RPCAsyncSequence(wrapping: inbound), - outbound: RPCWriter.Closable( - wrapping: ServerConnection.Stream.Outbound( - responseWriter: outbound, - http2Stream: stream - ) - ) - ) - - await streamHandler(rpcStream) - } + try await self.underlyingTransport.listen(streamHandler) } public func beginGracefulShutdown() { - self.serverQuiescingHelper.initiateShutdown(promise: nil) + self.underlyingTransport.beginGracefulShutdown() } } } diff --git a/Sources/GRPCHTTP2TransportNIOTransportServices/HTTP2ServerTransport+TransportServices.swift b/Sources/GRPCHTTP2TransportNIOTransportServices/HTTP2ServerTransport+TransportServices.swift index 12001861d..94628e215 100644 --- a/Sources/GRPCHTTP2TransportNIOTransportServices/HTTP2ServerTransport+TransportServices.swift +++ b/Sources/GRPCHTTP2TransportNIOTransportServices/HTTP2ServerTransport+TransportServices.swift @@ -29,101 +29,58 @@ private import Synchronization extension HTTP2ServerTransport { /// A NIO Transport Services-backed implementation of a server transport. @available(macOS 15.0, iOS 18.0, watchOS 11.0, tvOS 18.0, visionOS 2.0, *) - public final class TransportServices: ServerTransport, ListeningServerTransport { - private let address: GRPCHTTP2Core.SocketAddress - private let config: Config - private let eventLoopGroup: NIOTSEventLoopGroup - private let serverQuiescingHelper: ServerQuiescingHelper - - private enum State { - case idle(EventLoopPromise) - case listening(EventLoopFuture) - case closedOrInvalidAddress(RuntimeError) - - var listeningAddressFuture: EventLoopFuture { - get throws { - switch self { - case .idle(let eventLoopPromise): - return eventLoopPromise.futureResult - case .listening(let eventLoopFuture): - return eventLoopFuture - case .closedOrInvalidAddress(let runtimeError): - throw runtimeError - } + public struct TransportServices: ServerTransport, ListeningServerTransport { + private struct ListenerFactory: HTTP2ListenerFactory { + let config: Config + + func makeListeningChannel( + eventLoopGroup: any EventLoopGroup, + address: GRPCHTTP2Core.SocketAddress, + serverQuiescingHelper: ServerQuiescingHelper + ) async throws -> NIOAsyncChannel { + let bootstrap: NIOTSListenerBootstrap + + let requireALPN: Bool + let scheme: Scheme + switch self.config.transportSecurity.wrapped { + case .plaintext: + requireALPN = false + scheme = .http + bootstrap = NIOTSListenerBootstrap(group: eventLoopGroup) + + case .tls(let tlsConfig): + requireALPN = tlsConfig.requireALPN + scheme = .https + bootstrap = NIOTSListenerBootstrap(group: eventLoopGroup) + .tlsOptions(try NWProtocolTLS.Options(tlsConfig)) } - } - - enum OnBound { - case succeedPromise( - _ promise: EventLoopPromise, - address: GRPCHTTP2Core.SocketAddress - ) - case failPromise( - _ promise: EventLoopPromise, - error: RuntimeError - ) - } - - mutating func addressBound(_ address: NIOCore.SocketAddress?) -> OnBound { - switch self { - case .idle(let listeningAddressPromise): - if let address { - self = .listening(listeningAddressPromise.futureResult) - return .succeedPromise( - listeningAddressPromise, - address: GRPCHTTP2Core.SocketAddress(address) - ) - } else { - assertionFailure("Unknown address type") - let invalidAddressError = RuntimeError( - code: .transportError, - message: "Unknown address type returned by transport." - ) - self = .closedOrInvalidAddress(invalidAddressError) - return .failPromise(listeningAddressPromise, error: invalidAddressError) + let serverChannel = + try await bootstrap + .serverChannelOption(.socketOption(.so_reuseaddr), value: 1) + .serverChannelInitializer { channel in + let quiescingHandler = serverQuiescingHelper.makeServerChannelHandler(channel: channel) + return channel.pipeline.addHandler(quiescingHandler) + } + .bind(to: address) { channel in + channel.eventLoop.makeCompletedFuture { + return try channel.pipeline.syncOperations.configureGRPCServerPipeline( + channel: channel, + compressionConfig: self.config.compression, + connectionConfig: self.config.connection, + http2Config: self.config.http2, + rpcConfig: self.config.rpc, + requireALPN: requireALPN, + scheme: scheme + ) + } } - case .listening, .closedOrInvalidAddress: - fatalError( - "Invalid state: addressBound should only be called once and when in idle state" - ) - } - } - - enum OnClose { - case failPromise( - EventLoopPromise, - error: RuntimeError - ) - case doNothing - } - - mutating func close() -> OnClose { - let serverStoppedError = RuntimeError( - code: .serverIsStopped, - message: """ - There is no listening address bound for this server: there may have been \ - an error which caused the transport to close, or it may have shut down. - """ - ) - - switch self { - case .idle(let listeningAddressPromise): - self = .closedOrInvalidAddress(serverStoppedError) - return .failPromise(listeningAddressPromise, error: serverStoppedError) - - case .listening: - self = .closedOrInvalidAddress(serverStoppedError) - return .doNothing - - case .closedOrInvalidAddress: - return .doNothing - } + return serverChannel } } - private let listeningAddressState: Mutex + private let underlyingTransport: CommonHTTP2ServerTransport /// The listening address for this server transport. /// @@ -134,9 +91,7 @@ extension HTTP2ServerTransport { /// invalid address. public var listeningAddress: GRPCHTTP2Core.SocketAddress { get async throws { - try await self.listeningAddressState - .withLock { try $0.listeningAddressFuture } - .get() + try await self.underlyingTransport.listeningAddress } } @@ -151,156 +106,24 @@ extension HTTP2ServerTransport { config: Config, eventLoopGroup: NIOTSEventLoopGroup = .singletonNIOTSEventLoopGroup ) { - self.address = address - self.config = config - self.eventLoopGroup = eventLoopGroup - self.serverQuiescingHelper = ServerQuiescingHelper(group: self.eventLoopGroup) - - let eventLoop = eventLoopGroup.any() - self.listeningAddressState = Mutex(.idle(eventLoop.makePromise())) + let factory = ListenerFactory(config: config) + let helper = ServerQuiescingHelper(group: eventLoopGroup) + self.underlyingTransport = CommonHTTP2ServerTransport( + address: address, + eventLoopGroup: eventLoopGroup, + quiescingHelper: helper, + listenerFactory: factory + ) } public func listen( _ streamHandler: @escaping @Sendable (RPCStream) async -> Void ) async throws { - defer { - switch self.listeningAddressState.withLock({ $0.close() }) { - case .failPromise(let promise, let error): - promise.fail(error) - case .doNothing: - () - } - } - - let bootstrap: NIOTSListenerBootstrap - - let requireALPN: Bool - let scheme: Scheme - switch self.config.transportSecurity.wrapped { - case .plaintext: - requireALPN = false - scheme = .http - bootstrap = NIOTSListenerBootstrap(group: self.eventLoopGroup) - - case .tls(let tlsConfig): - requireALPN = tlsConfig.requireALPN - scheme = .https - bootstrap = NIOTSListenerBootstrap(group: self.eventLoopGroup) - .tlsOptions(try NWProtocolTLS.Options(tlsConfig)) - } - - let serverChannel = - try await bootstrap - .serverChannelOption( - ChannelOptions.socketOption(.so_reuseaddr), - value: 1 - ) - .serverChannelInitializer { channel in - let quiescingHandler = self.serverQuiescingHelper.makeServerChannelHandler( - channel: channel - ) - return channel.pipeline.addHandler(quiescingHandler) - } - .bind(to: self.address) { channel in - channel.eventLoop.makeCompletedFuture { - return try channel.pipeline.syncOperations.configureGRPCServerPipeline( - channel: channel, - compressionConfig: self.config.compression, - connectionConfig: self.config.connection, - http2Config: self.config.http2, - rpcConfig: self.config.rpc, - requireALPN: requireALPN, - scheme: scheme - ) - } - } - - let action = self.listeningAddressState.withLock { - $0.addressBound(serverChannel.channel.localAddress) - } - switch action { - case .succeedPromise(let promise, let address): - promise.succeed(address) - case .failPromise(let promise, let error): - promise.fail(error) - } - - try await serverChannel.executeThenClose { inbound in - try await withThrowingDiscardingTaskGroup { group in - for try await (connectionChannel, streamMultiplexer) in inbound { - group.addTask { - try await self.handleConnection( - connectionChannel, - multiplexer: streamMultiplexer, - streamHandler: streamHandler - ) - } - } - } - } - } - - private func handleConnection( - _ connection: NIOAsyncChannel, - multiplexer: ChannelPipeline.SynchronousOperations.HTTP2StreamMultiplexer, - streamHandler: @escaping @Sendable (RPCStream) async -> Void - ) async throws { - try await connection.executeThenClose { inbound, _ in - await withDiscardingTaskGroup { group in - group.addTask { - do { - for try await _ in inbound {} - } catch { - // We don't want to close the channel if one connection throws. - return - } - } - - do { - for try await (stream, descriptor) in multiplexer.inbound { - group.addTask { - await self.handleStream(stream, handler: streamHandler, descriptor: descriptor) - } - } - } catch { - return - } - } - } - } - - private func handleStream( - _ stream: NIOAsyncChannel, - handler streamHandler: @escaping @Sendable (RPCStream) async -> Void, - descriptor: EventLoopFuture - ) async { - // It's okay to ignore these errors: - // - If we get an error because the http2Stream failed to close, then there's nothing we can do - // - If we get an error because the inner closure threw, then the only possible scenario in which - // that could happen is if methodDescriptor.get() throws - in which case, it means we never got - // the RPC metadata, which means we can't do anything either and it's okay to just kill the stream. - try? await stream.executeThenClose { inbound, outbound in - guard let descriptor = try? await descriptor.get() else { - return - } - - let rpcStream = RPCStream( - descriptor: descriptor, - inbound: RPCAsyncSequence(wrapping: inbound), - outbound: RPCWriter.Closable( - wrapping: ServerConnection.Stream.Outbound( - responseWriter: outbound, - http2Stream: stream - ) - ) - ) - - await streamHandler(rpcStream) - } + try await self.underlyingTransport.listen(streamHandler) } public func beginGracefulShutdown() { - self.serverQuiescingHelper.initiateShutdown(promise: nil) + self.underlyingTransport.beginGracefulShutdown() } } }