From 7a6a4b61162f8d77a3a9f5f88511a93b5d568977 Mon Sep 17 00:00:00 2001 From: George Barnett Date: Mon, 6 Oct 2025 10:16:00 +0100 Subject: [PATCH] Add support for starting a server from accepted conns Motivation: Some applications receive a file descriptor for an already accepted TCP connection. There's currently no API to create a server from an fd like this. Modifications: - Add an API to the server builder allowing a server to be created from an accepted socket - Add tests Result: Can create a server from an accepted socket fd --- Sources/GRPC/Server.swift | 225 +++++++++++----- Sources/GRPC/ServerBuilder.swift | 18 ++ Tests/GRPCTests/AcceptedServerTests.swift | 300 ++++++++++++++++++++++ 3 files changed, 475 insertions(+), 68 deletions(-) create mode 100644 Tests/GRPCTests/AcceptedServerTests.swift diff --git a/Sources/GRPC/Server.swift b/Sources/GRPC/Server.swift index 491f05447..56e5b5abf 100644 --- a/Sources/GRPC/Server.swift +++ b/Sources/GRPC/Server.swift @@ -98,25 +98,7 @@ public final class Server: @unchecked Sendable { } #if canImport(NIOSSL) - // Making a `NIOSSLContext` is expensive, we should only do it once per TLS configuration so - // we'll do it now, before accepting connections. Unfortunately our API isn't throwing so we'll - // only surface any error when initializing a child channel. - // - // 'nil' means we're not using TLS, or we're using the Network.framework TLS backend. If we're - // using the Network.framework TLS backend we'll apply the settings just below. - let sslContext: Result? - - if let tlsConfiguration = configuration.tlsConfiguration { - do { - sslContext = try tlsConfiguration.makeNIOSSLContext().map { .success($0) } - } catch { - sslContext = .failure(error) - } - - } else { - // No TLS configuration, no SSL context. - sslContext = nil - } + let sslContext = Self.makeNIOSSLContext(configuration: configuration) #endif // canImport(NIOSSL) #if canImport(Network) @@ -152,53 +134,10 @@ public final class Server: @unchecked Sendable { ) // Set the handlers that are applied to the accepted Channels .childChannelInitializer { channel in - var configuration = configuration - configuration.logger[metadataKey: MetadataKey.connectionID] = "\(UUID().uuidString)" - configuration.logger.addIPAddressMetadata( - local: channel.localAddress, - remote: channel.remoteAddress - ) - - do { - let sync = channel.pipeline.syncOperations + Self.configureAcceptedChannel(channel, configuration: configuration) { sync in #if canImport(NIOSSL) - if let sslContext = try sslContext?.get() { - let sslHandler: NIOSSLServerHandler - if let verify = configuration.tlsConfiguration?.nioSSLCustomVerificationCallback { - sslHandler = NIOSSLServerHandler( - context: sslContext, - customVerificationCallback: verify - ) - } else { - sslHandler = NIOSSLServerHandler(context: sslContext) - } - - try sync.addHandler(sslHandler) - } + try Self.addNIOSSLHandler(sslContext, configuration: configuration, sync: sync) #endif // canImport(NIOSSL) - - // Configures the pipeline based on whether the connection uses TLS or not. - try sync.addHandler(GRPCServerPipelineConfigurator(configuration: configuration)) - - // Work around the zero length write issue, if needed. - let requiresZeroLengthWorkaround = PlatformSupport.requiresZeroLengthWriteWorkaround( - group: configuration.eventLoopGroup, - hasTLS: configuration.tlsConfiguration != nil - ) - if requiresZeroLengthWorkaround, - #available(macOS 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *) - { - try sync.addHandler(NIOFilterEmptyWritesHandler()) - } - } catch { - return channel.eventLoop.makeFailedFuture(error) - } - - // Run the debug initializer, if there is one. - if let debugAcceptedChannelInitializer = configuration.debugChannelInitializer { - return debugAcceptedChannelInitializer(channel) - } else { - return channel.eventLoop.makeSucceededVoidFuture() } } @@ -210,11 +149,108 @@ public final class Server: @unchecked Sendable { ) } + #if canImport(NIOSSL) + private static func makeNIOSSLContext( + configuration: Configuration + ) -> Result? { + // Making a `NIOSSLContext` is expensive, we should only do it once per TLS configuration so + // we'll do it now, before accepting connections. Unfortunately our API isn't throwing so we'll + // only surface any error when initializing a child channel. + // + // 'nil' means we're not using TLS, or we're using the Network.framework TLS backend. If we're + // using the Network.framework TLS backend we'll apply the settings just below. + let sslContext: Result? + + if let tlsConfiguration = configuration.tlsConfiguration { + do { + sslContext = try tlsConfiguration.makeNIOSSLContext().map { .success($0) } + } catch { + sslContext = .failure(error) + } + + } else { + // No TLS configuration, no SSL context. + sslContext = nil + } + + return sslContext + } + + private static func addNIOSSLHandler( + _ sslContext: Result?, + configuration: Configuration, + sync: ChannelPipeline.SynchronousOperations + ) throws { + if let sslContext = try sslContext?.get() { + let sslHandler: NIOSSLServerHandler + if let verify = configuration.tlsConfiguration?.nioSSLCustomVerificationCallback { + sslHandler = NIOSSLServerHandler( + context: sslContext, + customVerificationCallback: verify + ) + } else { + sslHandler = NIOSSLServerHandler(context: sslContext) + } + + try sync.addHandler(sslHandler) + } + } + #endif // canImport(NIOSSL) + + private static func configureAcceptedChannel( + _ channel: Channel, + configuration: Configuration, + addNIOSSLIfNecessary: (ChannelPipeline.SynchronousOperations) throws -> Void + ) -> EventLoopFuture { + var configuration = configuration + configuration.logger[metadataKey: MetadataKey.connectionID] = "\(UUID().uuidString)" + configuration.logger.addIPAddressMetadata( + local: channel.localAddress, + remote: channel.remoteAddress + ) + + do { + let sync = channel.pipeline.syncOperations + try addNIOSSLIfNecessary(sync) + + // Configures the pipeline based on whether the connection uses TLS or not. + try sync.addHandler(GRPCServerPipelineConfigurator(configuration: configuration)) + + // Work around the zero length write issue, if needed. + let requiresZeroLengthWorkaround = PlatformSupport.requiresZeroLengthWriteWorkaround( + group: configuration.eventLoopGroup, + hasTLS: configuration.tlsConfiguration != nil + ) + if requiresZeroLengthWorkaround, + #available(macOS 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *) + { + try sync.addHandler(NIOFilterEmptyWritesHandler()) + } + } catch { + return channel.eventLoop.makeFailedFuture(error) + } + + // Run the debug initializer, if there is one. + if let debugAcceptedChannelInitializer = configuration.debugChannelInitializer { + return debugAcceptedChannelInitializer(channel) + } else { + return channel.eventLoop.makeSucceededVoidFuture() + } + } + /// Starts a server with the given configuration. See `Server.Configuration` for the options /// available to configure the server. public static func start(configuration: Configuration) -> EventLoopFuture { - let quiescingHelper = ServerQuiescingHelper(group: configuration.eventLoopGroup) + switch configuration.target.wrapped { + case .connectedSocket(let handle) where configuration.connectedSocketTargetIsAcceptedConnection: + return Self.startServerFromAcceptedConnection(handle: handle, configuration: configuration) + case .connectedSocket, .hostAndPort, .unixDomainSocket, .socketAddress, .vsockAddress: + return Self.startServer(configuration: configuration) + } + } + private static func startServer(configuration: Configuration) -> EventLoopFuture { + let quiescingHelper = ServerQuiescingHelper(group: configuration.eventLoopGroup) return self.makeBootstrap(configuration: configuration) .serverChannelInitializer { channel in channel.pipeline.addHandler(quiescingHelper.makeServerChannelHandler(channel: channel)) @@ -229,13 +265,53 @@ public final class Server: @unchecked Sendable { } } + private static func startServerFromAcceptedConnection( + handle: NIOBSDSocket.Handle, + configuration: Configuration + ) -> EventLoopFuture { + guard let bootstrap = ClientBootstrap(validatingGroup: configuration.eventLoopGroup) else { + let status = GRPCStatus( + code: .unimplemented, + message: """ + You must use a NIOPosix EventLoopGroup to create a server from an already accepted \ + socket. + """ + ) + return configuration.eventLoopGroup.any().makeFailedFuture(status) + } + + #if canImport(NIOSSL) + let sslContext = Self.makeNIOSSLContext(configuration: configuration) + #endif // canImport(NIOSSL) + + return bootstrap.channelInitializer { channel in + Self.configureAcceptedChannel(channel, configuration: configuration) { sync in + #if canImport(NIOSSL) + try Self.addNIOSSLHandler(sslContext, configuration: configuration, sync: sync) + #endif // canImport(NIOSSL) + } + }.withConnectedSocket(handle).map { channel in + Server( + channel: channel, + quiescingHelper: nil, + errorDelegate: configuration.errorDelegate + ) + } + } + + /// The listening server channel. + /// + /// If the server was created from an already accepted connection then this channel will + /// be for the accepted connection. public let channel: Channel - private let quiescingHelper: ServerQuiescingHelper + + /// Quiescing helper. `nil` if `channel` is for an accepted connection. + private let quiescingHelper: ServerQuiescingHelper? private var errorDelegate: ServerErrorDelegate? private init( channel: Channel, - quiescingHelper: ServerQuiescingHelper, + quiescingHelper: ServerQuiescingHelper?, errorDelegate: ServerErrorDelegate? ) { self.channel = channel @@ -264,7 +340,13 @@ public final class Server: @unchecked Sendable { /// Initiates a graceful shutdown. Existing RPCs may run to completion, any new RPCs or /// connections will be rejected. public func initiateGracefulShutdown(promise: EventLoopPromise?) { - self.quiescingHelper.initiateShutdown(promise: promise) + if let quiescingHelper = self.quiescingHelper { + quiescingHelper.initiateShutdown(promise: promise) + } else { + // No quiescing helper: the channel must be for an already accepted connection. + self.channel.closeFuture.cascade(to: promise) + self.channel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent()) + } } /// Initiates a graceful shutdown. Existing RPCs may run to completion, any new RPCs or @@ -436,6 +518,13 @@ extension Server { /// CORS configuration for gRPC-Web support. public var webCORS = Configuration.CORS() + /// Indicates whether a `connectedSocket` ``target`` is treated as an accepted connection. + /// + /// If ``target`` is a `connectedSocket` then this flag indicates whether that socket is for + /// an already accepted connection. If the value is `false` then the socket is treated as a + /// listener. This value is ignored if ``target`` is any value other than `connectedSocket`. + public var connectedSocketTargetIsAcceptedConnection: Bool = false + #if canImport(NIOSSL) /// Create a `Configuration` with some pre-defined defaults. /// diff --git a/Sources/GRPC/ServerBuilder.swift b/Sources/GRPC/ServerBuilder.swift index c783e0d24..819a60c7b 100644 --- a/Sources/GRPC/ServerBuilder.swift +++ b/Sources/GRPC/ServerBuilder.swift @@ -79,6 +79,24 @@ extension Server { self.configuration.tlsConfiguration = self.maybeTLS return Server.start(configuration: self.configuration) } + + /// Create a gRPC server from the file descriptor of an already accepted TCP connection. + /// + /// - Parameter handle: The handle to the accepted socket. + /// - Important: This is only supported with `NIOPosix` (i.e. when using a + /// `MultiThreadedEventLoopGroup` or one of its loops and TLS configured via `NIOSSL`). + /// - Warning: By calling this function you hand responsibility of the socket to gRPC. + /// Crucially you must **not** close the socket directly after calling this function, gRPC + /// will do it for you. + /// - Returns: A configured gRPC server. + public func fromAcceptedConnection( + takingOwnershipOf handle: NIOBSDSocket.Handle + ) -> EventLoopFuture { + self.configuration.target = .connectedSocket(handle) + self.configuration.connectedSocketTargetIsAcceptedConnection = true + self.configuration.tlsConfiguration = self.maybeTLS + return Server.start(configuration: self.configuration) + } } } diff --git a/Tests/GRPCTests/AcceptedServerTests.swift b/Tests/GRPCTests/AcceptedServerTests.swift new file mode 100644 index 000000000..f9f995d6b --- /dev/null +++ b/Tests/GRPCTests/AcceptedServerTests.swift @@ -0,0 +1,300 @@ +/* + * Copyright 2025, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import EchoImplementation +import EchoModel +import GRPC +import GRPCSampleData +import Logging +import NIOConcurrencyHelpers +import NIOPosix +import XCTest + +#if canImport(Darwin) +import Darwin +private let sys_bind = Darwin.bind +private let sys_listen = Darwin.listen +private let sys_close = Darwin.close +private let sys_accept = Darwin.accept +private let sys_strerror = Darwin.strerror +#elseif canImport(Glibc) +import Glibc +private let sys_bind = Glibc.bind +private let sys_listen = Glibc.listen +private let sys_close = Glibc.close +private let sys_accept = Glibc.accept +private let sys_strerror = Glibc.strerror +#endif + +final class AcceptedServerTests: GRPCTestCase { + private func withListener( + _ handle: (ListeningServer) throws -> Result + ) throws -> Result { + let server = try ListeningServer.bind(logger: self.logger) + + do { + return try handle(server) + } catch { + server.close() + throw error + } + } + + func testBasicCommunication() throws { + try self.withListener { listener in + let client = try GRPCChannelPool.with( + target: .host(listener.host, port: listener.port), + transportSecurity: .plaintext, + eventLoopGroup: .singletonMultiThreadedEventLoopGroup.next() + ) + defer { + try? client.close().wait() + } + + // Start an RPC to trigger a connect. + let echo = Echo_EchoNIOClient(channel: client) + let response = echo.get(.with { $0.text = "Hello!" }) + + // Now accept a connection and start the server. + let acceptedFD = try listener.accept() + + let server = try Server.insecure(group: .singletonMultiThreadedEventLoopGroup) + .withLogger(self.serverLogger) + .withServiceProviders([EchoProvider()]) + .fromAcceptedConnection(takingOwnershipOf: acceptedFD) + .wait() + + defer { try? server.close().wait() } + XCTAssertEqual(try response.response.wait().text, "Swift echo get: Hello!") + } + } + + #if canImport(NIOSSL) + func testBasicCommunicationWithTLS() throws { + try self.withListener { listener in + let client = try GRPCChannelPool.with( + target: .host(listener.host, port: listener.port), + transportSecurity: .tls( + .makeClientConfigurationBackedByNIOSSL( + trustRoots: .certificates([SampleCertificate.ca.certificate]), + hostnameOverride: "localhost" + ) + ), + eventLoopGroup: .singletonMultiThreadedEventLoopGroup.next() + ) + defer { + try? client.close().wait() + } + + // Start an RPC to trigger a connect. + let echo = Echo_EchoNIOClient(channel: client) + let response = echo.get(.with { $0.text = "Hello!" }) + + // Now accept a connection and start the server. + let acceptedFD = try listener.accept() + + let server = try Server.usingTLSBackedByNIOSSL( + on: .singletonMultiThreadedEventLoopGroup, + certificateChain: [SampleCertificate.server.certificate], + privateKey: SamplePrivateKey.server + ) + .withTLS(trustRoots: .certificates([SampleCertificate.ca.certificate])) + .withLogger(self.serverLogger) + .withServiceProviders([EchoProvider()]) + .fromAcceptedConnection(takingOwnershipOf: acceptedFD) + .wait() + + defer { try? server.close().wait() } + XCTAssertEqual(try response.response.wait().text, "Swift echo get: Hello!") + } + } + #endif // canImport(NIOSSL) + + func testGracefulShutdownOfServer() throws { + try self.withListener { listener in + let group = MultiThreadedEventLoopGroup.singleton + + let client = try GRPCChannelPool.with( + target: .host(listener.host, port: listener.port), + transportSecurity: .plaintext, + eventLoopGroup: group.next() + ) + defer { + try? client.close().wait() + } + + // Start an RPC to trigger a connect. + let echo = Echo_EchoNIOClient(channel: client) + + let messages = NIOLockedValueBox<[String]>([]) + let update = echo.update { reply in + messages.withLockedValue({ $0.append(reply.text) }) + } + + // Now accept a connection and start the server. + let acceptedFD = try listener.accept() + + let server = try Server.insecure(group: group) + .withLogger(self.serverLogger) + .withServiceProviders([EchoProvider()]) + .fromAcceptedConnection(takingOwnershipOf: acceptedFD) + .wait() + defer { try? server.close().wait() } + + // Initial metadata indicates both peers know about the RPC + XCTAssertNoThrow(try update.initialMetadata.wait()) + + // Begin graceful shutdown; 'update' can complete, new RPCs should fail. + let shutdown = server.initiateGracefulShutdown() + + // Start a new RPC, it should fail. + let getResponse = echo.get(.with { $0.text = "Bye!" }) + XCTAssertThrowsError(try getResponse.response.wait()) + + // Update should still work. + update.sendMessage(.with { $0.text = "Hello!" }, promise: nil) + update.sendEnd(promise: nil) + XCTAssertEqual(try update.status.wait().code, .ok) + XCTAssertEqual(messages.withLockedValue { $0 }, ["Swift echo update (0): Hello!"]) + + XCTAssertNoThrow(try shutdown.wait()) + } + } +} + +struct ListeningServer { + private var fd: Int32 + + let port: Int + var host: String { "127.0.0.1" } + var logger: Logger + + private init(fd: Int32, port: Int, logger: Logger) { + self.fd = fd + self.port = port + self.logger = logger + } + + func accept() throws(SocketError) -> CInt { + self.logger.debug("Accepting new client connection") + let fd = try Self.acceptConnection(on: self.fd) + self.logger.debug("Accepted new connection", metadata: ["fd": "\(fd)"]) + return fd + } + + func close() { + self.logger.debug("Closing listener socket") + _ = sys_close(self.fd) + } + + static func bind(logger: Logger) throws(SocketError) -> Self { + let fd = try Self.makeListeningSocket() + let port = try Self.getListeningPort(for: fd) + let server = ListeningServer(fd: fd, port: port, logger: logger) + logger.info( + "Opened listening socket", + metadata: ["addr": "\(server.host):\(server.port)", "fd": "\(fd)"] + ) + return server + } + + enum SocketError: Error { + case creationFailed + case bindFailed + case listenFailed + case acceptFailed(String) + case getsocknameFailed + } + + private static func makeListeningSocket() throws(SocketError) -> CInt { + #if canImport(Darwin) + let sockfd = socket(AF_INET, SOCK_STREAM, 0) + #elseif canImport(Glibc) + let sockfd = socket(AF_INET, CInt(SOCK_STREAM.rawValue), 0) + #else + fatalError("Unsupported libc") + #endif + if sockfd == -1 { + throw .creationFailed + } + + // Allow address reuse + var yes = 1 + setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &yes, socklen_t(MemoryLayout.size)) + + var addr = sockaddr_in() + addr.sin_family = sa_family_t(AF_INET) + addr.sin_port = 0 + addr.sin_addr.s_addr = inet_addr("127.0.0.1") + + let bindResult = withUnsafePointer(to: &addr) { + $0.withMemoryRebound(to: sockaddr.self, capacity: 1) { + sys_bind(sockfd, $0, socklen_t(MemoryLayout.size)) + } + } + + if bindResult == -1 { + _ = sys_close(sockfd) + throw .bindFailed + } + + if sys_listen(sockfd, 5) == -1 { + _ = sys_close(sockfd) + throw .listenFailed + } + + return sockfd + } + + private static func getListeningPort( + for listener: CInt, + ) throws(SocketError) -> Int { + var address = sockaddr_in() + var addressLength = socklen_t(MemoryLayout.size) + + let getsocknameResult = withUnsafeMutablePointer(to: &address) { + $0.withMemoryRebound(to: sockaddr.self, capacity: 1) { + getsockname(listener, $0, &addressLength) + } + } + + if getsocknameResult == 0 { + return Int(UInt16(bigEndian: address.sin_port)) + } else { + _ = sys_close(listener) + throw .getsocknameFailed + } + } + + private static func acceptConnection( + on listener: CInt, + ) throws(SocketError) -> CInt { + var clientAddress = sockaddr_in() + var clientAddressLength = socklen_t(MemoryLayout.size) + + let clientSocket = withUnsafeMutablePointer(to: &clientAddress) { + $0.withMemoryRebound(to: sockaddr.self, capacity: 1) { + sys_accept(listener, $0, &clientAddressLength) + } + } + + if clientSocket == -1 { + throw .acceptFailed(String(cString: sys_strerror(errno)!)) + } else { + return clientSocket + } + } +}