diff --git a/Sources/NIOCore/Channel.swift b/Sources/NIOCore/Channel.swift index d86362f6e9..0585916ec6 100644 --- a/Sources/NIOCore/Channel.swift +++ b/Sources/NIOCore/Channel.swift @@ -304,6 +304,26 @@ extension ChannelCore { data.forceAs() } + /// Unwraps the given `NIOAny` as a specific concrete type. + /// + /// This method is intended for use when writing custom `ChannelCore` implementations. + /// This can safely be called in methods like `write0` to extract data from the `NIOAny` + /// provided in those cases. + /// + /// Note that if the unwrap fails, this will cause a runtime trap. `ChannelCore` + /// implementations should be concrete about what types they support writing. If multiple + /// types are supported, consider using a tagged union to store the type information like + /// NIO's own `IOData`, which will minimise the amount of runtime type checking. + /// + /// - Parameters: + /// - data: The `NIOAny` to unwrap. + /// - as: The type to extract from the `NIOAny`. + /// - Returns: The content of the `NIOAny`. + @inlinable + public static func unwrapData(_ data: NIOAny, as: T.Type = T.self) -> T { + data.forceAs() + } + /// Attempts to unwrap the given `NIOAny` as a specific concrete type. /// /// This method is intended for use when writing custom `ChannelCore` implementations. @@ -326,6 +346,28 @@ extension ChannelCore { data.tryAs() } + /// Attempts to unwrap the given `NIOAny` as a specific concrete type. + /// + /// This method is intended for use when writing custom `ChannelCore` implementations. + /// This can safely be called in methods like `write0` to extract data from the `NIOAny` + /// provided in those cases. + /// + /// If the unwrap fails, this will return `nil`. `ChannelCore` implementations should almost + /// always support only one runtime type, so in general they should avoid using this and prefer + /// using `unwrapData` instead. This method exists for rare use-cases where tolerating type + /// mismatches is acceptable. + /// + /// - Parameters: + /// - data: The `NIOAny` to unwrap. + /// - as: The type to extract from the `NIOAny`. + /// - Returns: The content of the `NIOAny`, or `nil` if the type is incorrect. + /// - warning: If you are implementing a `ChannelCore`, you should use `unwrapData` unless you + /// are doing something _extremely_ unusual. + @inlinable + public static func tryUnwrapData(_ data: NIOAny, as: T.Type = T.self) -> T? { + data.tryAs() + } + /// Removes the `ChannelHandler`s from the `ChannelPipeline` belonging to `channel`, and /// closes that `ChannelPipeline`. /// diff --git a/Sources/NIOPosix/BaseSocketChannel.swift b/Sources/NIOPosix/BaseSocketChannel.swift index d928416451..d473d616e8 100644 --- a/Sources/NIOPosix/BaseSocketChannel.swift +++ b/Sources/NIOPosix/BaseSocketChannel.swift @@ -223,7 +223,7 @@ private struct SocketChannelLifecycleManager { /// For this reason, `BaseSocketChannel` exists to provide a common core implementation of /// the `SelectableChannel` protocol. It uses a number of private functions to provide hooks /// for subclasses to implement the specific logic to handle their writes and reads. -class BaseSocketChannel: SelectableChannel, ChannelCore, @unchecked Sendable { +class BaseSocketChannel: SelectableChannel, ChannelCore, @unchecked Sendable { typealias SelectableType = SocketType.SelectableType struct AddressCache { @@ -472,7 +472,7 @@ class BaseSocketChannel: SelectableChannel, Chan } /// Buffer a write in preparation for a flush. - func bufferPendingWrite(data: NIOAny, promise: EventLoopPromise?) { + func bufferPendingWrite(data: WriteType, promise: EventLoopPromise?) { fatalError("this must be overridden by sub class") } @@ -732,6 +732,10 @@ class BaseSocketChannel: SelectableChannel, Chan } } + internal func unwrapAsWriteType(_ data: NIOAny) -> WriteType { + return Self.unwrapData(data, as: WriteType.self) + } + public final func write0(_ data: NIOAny, promise: EventLoopPromise?) { self.eventLoop.assertInEventLoop() @@ -741,7 +745,8 @@ class BaseSocketChannel: SelectableChannel, Chan return } - bufferPendingWrite(data: data, promise: promise) + let data = self.unwrapAsWriteType(data) + self.bufferPendingWrite(data: data, promise: promise) } private func registerForWritable() { @@ -1401,12 +1406,14 @@ class BaseSocketChannel: SelectableChannel, Chan } extension BaseSocketChannel { + typealias BaseSocketChannelType = BaseSocketChannel + public struct SynchronousOptions: NIOSynchronousChannelOptions { @usableFromInline // should be private - internal let _channel: BaseSocketChannel + internal let _channel: BaseSocketChannelType @inlinable // should be fileprivate - internal init(_channel channel: BaseSocketChannel) { + internal init(_channel channel: BaseSocketChannelType) { self._channel = channel } diff --git a/Sources/NIOPosix/BaseStreamSocketChannel.swift b/Sources/NIOPosix/BaseStreamSocketChannel.swift index 9c915b1a29..df608173a1 100644 --- a/Sources/NIOPosix/BaseStreamSocketChannel.swift +++ b/Sources/NIOPosix/BaseStreamSocketChannel.swift @@ -13,7 +13,7 @@ //===----------------------------------------------------------------------===// import NIOCore -class BaseStreamSocketChannel: BaseSocketChannel, @unchecked Sendable { +class BaseStreamSocketChannel: BaseSocketChannel, @unchecked Sendable { internal var connectTimeoutScheduled: Optional> private var allowRemoteHalfClosure: Bool = false private var inputShutdown: Bool = false @@ -288,14 +288,12 @@ class BaseStreamSocketChannel: BaseSocketChannel super.read0() } - final override func bufferPendingWrite(data: NIOAny, promise: EventLoopPromise?) { + final override func bufferPendingWrite(data: IOData, promise: EventLoopPromise?) { if self.outputShutdown { promise?.fail(ChannelError._outputClosed) return } - let data = self.unwrapData(data, as: IOData.self) - if !self.pendingWrites.add(data: data, promise: promise) { self.pipeline.syncOperations.fireChannelWritabilityChanged() } diff --git a/Sources/NIOPosix/SocketChannel.swift b/Sources/NIOPosix/SocketChannel.swift index 504c2ad834..54ce9cf545 100644 --- a/Sources/NIOPosix/SocketChannel.swift +++ b/Sources/NIOPosix/SocketChannel.swift @@ -196,7 +196,7 @@ final class SocketChannel: BaseStreamSocketChannel, @unchecked Sendable /// A `Channel` for a server socket. /// /// - Note: All operations on `ServerSocketChannel` are thread-safe. -final class ServerSocketChannel: BaseSocketChannel, @unchecked Sendable { +final class ServerSocketChannel: BaseSocketChannel, @unchecked Sendable { private var backlog: Int32 = 128 private let group: EventLoopGroup @@ -422,8 +422,8 @@ final class ServerSocketChannel: BaseSocketChannel, @unchecked Sen false } - override func bufferPendingWrite(data: NIOAny, promise: EventLoopPromise?) { - promise?.fail(ChannelError._operationUnsupported) + override func bufferPendingWrite(data: Void, promise: EventLoopPromise?) { + promise?.fail(ChannelError.operationUnsupported) } override func markFlushPoint() { @@ -459,12 +459,19 @@ final class ServerSocketChannel: BaseSocketChannel, @unchecked Sen promise?.fail(ChannelError._operationUnsupported) } } + + override func unwrapAsWriteType(_ data: NIOAny) -> () {} +} + +enum DatagramWriteType { + case addressed(AddressedEnvelope) + case unaddressed(ByteBuffer) } /// A channel used with datagram sockets. /// /// Currently, it does not support connected mode which is well worth adding. -final class DatagramChannel: BaseSocketChannel, @unchecked Sendable { +final class DatagramChannel: BaseSocketChannel, @unchecked Sendable { private var reportExplicitCongestionNotifications = false private var receivePacketInfo = false @@ -894,6 +901,14 @@ final class DatagramChannel: BaseSocketChannel, @unchecked Sendable { } } + override func unwrapAsWriteType(_ data: NIOAny) -> DatagramWriteType { + if let envelope = self.tryUnwrapData(data, as: AddressedEnvelope.self) { + return .addressed(envelope) + } else { + return .unaddressed(Self.unwrapData(data, as: ByteBuffer.self)) + } + } + /// Buffer a write in preparation for a flush. /// /// When the channel is unconnected, `data` _must_ be of type `AddressedEnvelope`. @@ -902,14 +917,13 @@ final class DatagramChannel: BaseSocketChannel, @unchecked Sendable { /// `AddressedEnvelope` to allow users to provide protocol control messages via /// `AddressedEnvelope.metadata`. In this case, `AddressedEnvelope.remoteAddress` _must_ match /// the address of the connected peer. - override func bufferPendingWrite(data: NIOAny, promise: EventLoopPromise?) { - if let envelope = self.tryUnwrapData(data, as: AddressedEnvelope.self) { - return bufferPendingAddressedWrite(envelope: envelope, promise: promise) + override func bufferPendingWrite(data: DatagramWriteType, promise: EventLoopPromise?) { + switch data { + case .addressed(let addressedBytes): + return self.bufferPendingAddressedWrite(envelope: addressedBytes, promise: promise) + case .unaddressed(let bytes): + return self.bufferPendingUnaddressedWrite(data: bytes, promise: promise) } - // If it's not an `AddressedEnvelope` then it must be a `ByteBuffer` so we let the common - // `unwrapData(_:as:)` throw the fatal error if it's some other type. - let data = self.unwrapData(data, as: ByteBuffer.self) - return bufferPendingUnaddressedWrite(data: data, promise: promise) } /// Buffer a write in preparation for a flush.