Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 23 additions & 9 deletions Sources/SwiftOCADevice/OCP.1/Backend/CF/Ocp1CFController.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import Foundation
import SocketAddress
@_spi(SwiftOCAPrivate)
import SwiftOCA
import Synchronization
import SystemPackage
#if canImport(Darwin)
import Darwin
Expand Down Expand Up @@ -91,19 +92,24 @@ actor Ocp1CFStreamController: Ocp1CFControllerPrivate, CustomStringConvertible {

private let _messages: AsyncThrowingStream<Ocp1MessageList, Error>
private let _messagesContinuation: AsyncThrowingStream<Ocp1MessageList, Error>.Continuation
private let socket: _CFSocketWrapper
private let _socket: Mutex<_CFSocketWrapper?>
let notificationSocket: _CFSocketWrapper

nonisolated var description: String {
"\(type(of: self))(socket: \(socket))"
let socket = socket
return "\(type(of: self))(socket: \(socket != nil ? String(describing: socket!) : "<disconnected>"))"
}

private nonisolated var socket: _CFSocketWrapper? {
_socket.withLock { $0 }
}

init(
endpoint: Ocp1CFStreamDeviceEndpoint,
socket: _CFSocketWrapper,
notificationSocket: _CFSocketWrapper
) async {
self.socket = socket
_socket = .init(socket)
self.notificationSocket = notificationSocket
peerAddress = socket.peerAddress!

Expand All @@ -118,10 +124,10 @@ actor Ocp1CFStreamController: Ocp1CFControllerPrivate, CustomStringConvertible {
connectionPrefix = OcaTcpConnectionPrefix
}

receiveMessageTask = Task { [weak self, socket] in
receiveMessageTask = Task { [weak self] in
do {
repeat {
guard !Task.isCancelled else { break }
guard !Task.isCancelled, let socket = self?.socket else { break }
let messages = try await OcaDevice
.receiveMessages { try await Array(socket.read(count: $0)) }
self?._messagesContinuation.yield(messages)
Expand All @@ -132,16 +138,23 @@ actor Ocp1CFStreamController: Ocp1CFControllerPrivate, CustomStringConvertible {
}
}

func close() async throws {
// don't close the socket, it will be closed when last reference is released
private func closeSocket() {
_socket.withLock { $0 = nil }
}

func close() async throws {
keepAliveTask?.cancel()
keepAliveTask = nil

receiveMessageTask?.cancel()
receiveMessageTask = nil
if let receiveMessageTask {
receiveMessageTask.cancel()
_ = await receiveMessageTask.result
self.receiveMessageTask = nil
}

_messagesContinuation.finish()

closeSocket()
}

deinit {
Expand All @@ -157,6 +170,7 @@ actor Ocp1CFStreamController: Ocp1CFControllerPrivate, CustomStringConvertible {
}

func sendOcp1EncodedData(_ data: Data) async throws {
guard let socket else { throw Errno.badFileDescriptor }
_ = try await socket.write(data: data)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ public import IORing

import SocketAddress
import SwiftOCA
import Synchronization
import struct SystemPackage.Errno

protocol Ocp1IORingControllerPrivate: Ocp1ControllerInternal,
Ocp1ControllerInternalLightweightNotifyingInternal, Actor,
Expand Down Expand Up @@ -82,19 +84,24 @@ actor Ocp1IORingStreamController: Ocp1IORingControllerPrivate, CustomStringConve

private let _messages: AsyncThrowingStream<Ocp1MessageList, Error>
private let _messagesContinuation: AsyncThrowingStream<Ocp1MessageList, Error>.Continuation
private let socket: Socket
private let _socket: Mutex<Socket?>
let notificationSocket: Socket

nonisolated var description: String {
"\(type(of: self))(socket: \(socket))"
let socket = socket
return "\(type(of: self))(socket: \(socket != nil ? String(describing: socket!) : "<disconnected>"))"
}

private nonisolated var socket: Socket? {
_socket.withLock { $0 }
}

init(
endpoint: Ocp1IORingStreamDeviceEndpoint,
socket: Socket,
notificationSocket: Socket
) async throws {
self.socket = socket
_socket = .init(socket)
self.notificationSocket = notificationSocket
self.endpoint = endpoint

Expand All @@ -103,17 +110,17 @@ actor Ocp1IORingStreamController: Ocp1IORingControllerPrivate, CustomStringConve
throwing: Error.self
)

peerAddress = try AnySocketAddress(self.socket.peerAddress)
peerAddress = try AnySocketAddress(socket.peerAddress)
if peerAddress.family == AF_LOCAL {
connectionPrefix = OcaLocalConnectionPrefix
} else {
connectionPrefix = OcaTcpConnectionPrefix
}

receiveMessageTask = Task { [weak self, socket] in
receiveMessageTask = Task { [weak self] in
do {
repeat {
guard !Task.isCancelled else { break }
guard !Task.isCancelled, let socket = self?.socket else { break }
let messages = try await OcaDevice.receiveMessages { try await socket.read(
count: $0,
awaitingAllRead: true
Expand All @@ -126,16 +133,23 @@ actor Ocp1IORingStreamController: Ocp1IORingControllerPrivate, CustomStringConve
}
}

func close() {
// don't close the socket, it will be closed when last reference is released
private func closeSocket() {
_socket.withLock { $0 = nil }
}

func close() async {
keepAliveTask?.cancel()
keepAliveTask = nil

receiveMessageTask?.cancel()
receiveMessageTask = nil
if let receiveMessageTask {
receiveMessageTask.cancel()
_ = await receiveMessageTask.result
self.receiveMessageTask = nil
}

_messagesContinuation.finish()

closeSocket()
}

deinit {
Expand All @@ -151,6 +165,7 @@ actor Ocp1IORingStreamController: Ocp1IORingControllerPrivate, CustomStringConve
}

func sendOcp1EncodedData(_ data: Data) async throws {
guard let socket else { throw Errno.badFileDescriptor }
_ = try await socket.write(
[UInt8](data),
count: data.count,
Expand All @@ -163,7 +178,7 @@ actor Ocp1IORingStreamController: Ocp1IORingControllerPrivate, CustomStringConve
}

nonisolated var identifier: String {
(try? socket.peerName) ?? "unknown"
(try? socket?.peerName) ?? "unknown"
}
}

Expand Down
Loading