Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .changes/dc-leak
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
patch type="fixed" "Memory leaks in data channel cancellation code"
31 changes: 16 additions & 15 deletions Sources/LiveKit/Agent/Session.swift
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ open class Session: ObservableObject {

// MARK: - Internal state

private var waitForAgentTask: Task<Void, Swift.Error>?
private var tasks = Set<AnyTaskCancellable>()
private var waitForAgentTask: AnyTaskCancellable?

// MARK: - Init

Expand Down Expand Up @@ -193,17 +194,10 @@ open class Session: ObservableObject {
receivers: receivers)
}

deinit {
waitForAgentTask?.cancel()
}

private func observe(room: Room) {
Task { [weak self] in
for try await _ in room.changes {
guard let self else { return }
updateAgent(in: room)
}
}
Task.observingOnMainActor(room.changes, by: self) { observer, _ in
observer.updateAgent(in: room)
}.store(in: &tasks)
}

private func updateAgent(in room: Room) {
Expand All @@ -221,18 +215,25 @@ open class Session: ObservableObject {
}

private func observe(receivers: [any MessageReceiver]) {
let (stream, continuation) = AsyncStream.makeStream(of: ReceivedMessage.self)

// Multiple producers → single stream
for receiver in receivers {
Task { [weak self] in
do {
for await message in try await receiver.messages() {
guard let self else { return }
messagesDict.updateValue(message, forKey: message.id)
continuation.yield(message)
}
} catch {
self?.error = .receiver(error)
}
}
}.cancellable().store(in: &tasks)
}

// Single consumer
Task.observingOnMainActor(stream, by: self) { owner, message in
owner.messagesDict.updateValue(message, forKey: message.id)
}.store(in: &tasks)
}

// MARK: - Lifecycle
Expand Down Expand Up @@ -278,7 +279,7 @@ open class Session: ObservableObject {
if isConnected, !agent.isConnected {
agent.failed(error: .timeout)
}
}
}.cancellable()
}
} catch {
self.error = .connection(error)
Expand Down
31 changes: 20 additions & 11 deletions Sources/LiveKit/Broadcast/IPC/BroadcastUploader.swift
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ final class BroadcastUploader: Sendable, Loggable {
private struct State {
var isUploadingImage = false
var shouldUploadAudio = false
var messageLoopTask: AnyTaskCancellable?
}

private let state = StateSync(State())
Expand All @@ -40,8 +41,19 @@ final class BroadcastUploader: Sendable, Loggable {

/// Creates an uploader with an open connection to another process.
init(socketPath: SocketPath) async throws {
channel = try await IPCChannel(connectingTo: socketPath)
Task { try await handleIncomingMessages() }
let channel = try await IPCChannel(connectingTo: socketPath)
self.channel = channel

let messageLoopTask = Task.observing(channel.incomingMessages(BroadcastIPCHeader.self), by: self) { observer, message in
observer.processMessageHeader(message.0)
} onFailure: { observer, error in
observer.log("IPCChannel returned error: \(error)")
}
state.mutate { $0.messageLoopTask = messageLoopTask }
}

deinit {
close()
}

/// Whether or not the connection to the receiver has been closed.
Expand Down Expand Up @@ -92,15 +104,12 @@ final class BroadcastUploader: Sendable, Loggable {
}
}

private func handleIncomingMessages() async throws {
for try await (header, _) in channel.incomingMessages(BroadcastIPCHeader.self) {
switch header {
case let .wantsAudio(wantsAudio):
state.mutate { $0.shouldUploadAudio = wantsAudio }
default:
log("Unhandled incoming message: \(header)", .debug)
continue
}
private func processMessageHeader(_ header: BroadcastIPCHeader) {
switch header {
case let .wantsAudio(wantsAudio):
state.mutate { $0.shouldUploadAudio = wantsAudio }
default:
log("Unhandled incoming message: \(header)", .debug)
}
}
}
Expand Down
111 changes: 55 additions & 56 deletions Sources/LiveKit/Core/DataChannelPair.swift
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,19 @@ class DataChannelPair: NSObject, @unchecked Sendable, Loggable {
guard let lossy, let reliable else { return false }
return reliable.readyState == .open && lossy.readyState == .open
}
}

var eventContinuation: AsyncStream<ChannelEvent>.Continuation?
private struct Buffers: Sendable {
var lossyBuffer = SendBuffer()
var reliableBuffer = SendBuffer()
var reliableRetryBuffer = RetryBuffer(minAmount: DataChannelPair.reliableRetryAmount)
}

private let _state: StateSync<State>

private let eventContinuation: AsyncStream<ChannelEvent>.Continuation
private var eventLoopTask: AnyTaskCancellable?

fileprivate enum ChannelKind {
case lossy, reliable
}
Expand Down Expand Up @@ -133,56 +140,47 @@ class DataChannelPair: NSObject, @unchecked Sendable, Loggable {

// MARK: - Event handling

private func handleEvents(
events: AsyncStream<ChannelEvent>
) async {
var lossyBuffer = SendBuffer()
var reliableBuffer = SendBuffer()

var reliableRetryBuffer = RetryBuffer(minAmount: Self.reliableRetryAmount)

for await event in events {
switch event.detail {
case let .publishData(request):
switch event.channelKind {
case .lossy: lossyBuffer.enqueue(request)
case .reliable: reliableBuffer.enqueue(request)
}
case let .publishedData(request):
switch event.channelKind {
case .lossy: ()
case .reliable: reliableRetryBuffer.enqueue(request)
}
case let .bufferedAmountChanged(amount):
switch event.channelKind {
case .lossy:
updateTarget(buffer: &lossyBuffer, newAmount: amount)
case .reliable:
updateTarget(buffer: &reliableBuffer, newAmount: amount)
reliableRetryBuffer.trim(toAmount: amount)
}
case let .retryRequested(lastSeq):
switch event.channelKind {
case .lossy: ()
case .reliable: retry(buffer: &reliableRetryBuffer, from: lastSeq)
}
private func processEvent(_ event: ChannelEvent, buffers: inout Buffers) {
switch event.detail {
case let .publishData(request):
switch event.channelKind {
case .lossy: buffers.lossyBuffer.enqueue(request)
case .reliable: buffers.reliableBuffer.enqueue(request)
}

case let .publishedData(request):
switch event.channelKind {
case .lossy: ()
case .reliable: buffers.reliableRetryBuffer.enqueue(request)
}
case let .bufferedAmountChanged(amount):
switch event.channelKind {
case .lossy:
processSendQueue(
threshold: Self.lossyLowThreshold,
buffer: &lossyBuffer,
kind: .lossy
)
updateTarget(buffer: &buffers.lossyBuffer, newAmount: amount)
case .reliable:
processSendQueue(
threshold: Self.reliableLowThreshold,
buffer: &reliableBuffer,
kind: .reliable
)
updateTarget(buffer: &buffers.reliableBuffer, newAmount: amount)
buffers.reliableRetryBuffer.trim(toAmount: amount)
}
case let .retryRequested(lastSeq):
switch event.channelKind {
case .lossy: ()
case .reliable: retry(buffer: &buffers.reliableRetryBuffer, from: lastSeq)
}
}

switch event.channelKind {
case .lossy:
processSendQueue(
threshold: Self.lossyLowThreshold,
buffer: &buffers.lossyBuffer,
kind: .lossy
)
case .reliable:
processSendQueue(
threshold: Self.reliableLowThreshold,
buffer: &buffers.reliableBuffer,
kind: .reliable
)
}
}

private func channel(for kind: ChannelKind) -> LKRTCDataChannel? {
Expand Down Expand Up @@ -215,7 +213,7 @@ class DataChannelPair: NSObject, @unchecked Sendable, Loggable {
request.continuation?.resume()

let event = ChannelEvent(channelKind: kind, detail: .publishedData(request))
_state.eventContinuation?.yield(event)
eventContinuation.yield(event)
}
}

Expand Down Expand Up @@ -244,7 +242,7 @@ class DataChannelPair: NSObject, @unchecked Sendable, Loggable {
assert(request.continuation == nil, "Continuation may fire multiple times while retrying causing crash")
if request.sequence > lastSeq {
let event = ChannelEvent(channelKind: .reliable, detail: .publishData(request))
_state.eventContinuation?.yield(event)
eventContinuation.yield(event)
}
}
}
Expand All @@ -261,13 +259,14 @@ class DataChannelPair: NSObject, @unchecked Sendable, Loggable {
if let delegate {
delegates.add(delegate: delegate)
}

let (eventStream, continuation) = AsyncStream.makeStream(of: ChannelEvent.self)
eventContinuation = continuation

super.init()

Task {
let eventStream = AsyncStream<ChannelEvent> { continuation in
_state.mutate { $0.eventContinuation = continuation }
}
await handleEvents(events: eventStream)
eventLoopTask = Task.observing(eventStream, by: self, withMutableState: Buffers()) { observer, event, buffers in
observer.processEvent(event, buffers: &buffers)
}
}

Expand Down Expand Up @@ -337,7 +336,7 @@ class DataChannelPair: NSObject, @unchecked Sendable, Loggable {
channelKind: ChannelKind(packet.kind), // TODO: field is deprecated
detail: .publishData(request)
)
_state.eventContinuation?.yield(event)
eventContinuation.yield(event)
}
}

Expand Down Expand Up @@ -367,7 +366,7 @@ class DataChannelPair: NSObject, @unchecked Sendable, Loggable {

func retryReliable(lastSequence: UInt32) {
let event = ChannelEvent(channelKind: .reliable, detail: .retryRequested(lastSequence))
_state.eventContinuation?.yield(event)
eventContinuation.yield(event)
}

// MARK: - Sync state
Expand Down Expand Up @@ -398,7 +397,7 @@ class DataChannelPair: NSObject, @unchecked Sendable, Loggable {
private static let reliableReceivedStateTTL: TimeInterval = 30

deinit {
_state.eventContinuation?.finish()
eventContinuation.finish()
}
}

Expand All @@ -410,7 +409,7 @@ extension DataChannelPair: LKRTCDataChannelDelegate {
channelKind: dataChannel.kind,
detail: .bufferedAmountChanged(amount)
)
_state.eventContinuation?.yield(event)
eventContinuation.yield(event)
}

func dataChannelDidChangeState(_: LKRTCDataChannel) {
Expand Down
23 changes: 10 additions & 13 deletions Sources/LiveKit/Core/SignalClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ actor SignalClient: Loggable {
var connectionState: ConnectionState = .disconnected
var disconnectError: LiveKitError?
var socket: WebSocket?
var messageLoopTask: Task<Void, Never>?
var messageLoopTask: AnyTaskCancellable?
var lastJoinResponse: Livekit_JoinResponse?
var rtt: Int64 = 0
}
Expand Down Expand Up @@ -139,17 +139,12 @@ actor SignalClient: Loggable {
token: token,
connectOptions: connectOptions)

let task = Task.detached {
self.log("Did enter WebSocket message loop...")
do {
for try await message in socket {
await self._onWebSocketMessage(message: message)
}
} catch {
await self.cleanUp(withError: error)
}
let messageLoopTask = Task.observing(socket, by: self) { observer, message in
await observer.onWebSocketMessage(message)
} onFailure: { observer, error in
await observer.cleanUp(withError: error)
}
_state.mutate { $0.messageLoopTask = task }
_state.mutate { $0.messageLoopTask = messageLoopTask }

let connectResponse = try await _connectResponseCompleter.wait()
// Check cancellation after received join response
Expand Down Expand Up @@ -207,14 +202,16 @@ actor SignalClient: Loggable {
}

func cleanUp(withError disconnectError: Error? = nil) async {
if disconnectError is CancellationError { return }
if let lkError = disconnectError as? LiveKitError, lkError.type == .cancelled { return }

log("withError: \(String(describing: disconnectError))")

// Cancel ping/pong timers immediately to prevent stale timers from affecting future connections
_pingIntervalTimer.cancel()
_pingTimeoutTimer.cancel()

_state.mutate {
$0.messageLoopTask?.cancel()
$0.messageLoopTask = nil
$0.socket?.close()
$0.socket = nil
Expand Down Expand Up @@ -247,7 +244,7 @@ private extension SignalClient {
await _requestQueue.processIfResumed(request, elseEnqueue: request.canBeQueued())
}

func _onWebSocketMessage(message: URLSessionWebSocketTask.Message) async {
func onWebSocketMessage(_ message: URLSessionWebSocketTask.Message) async {
let response: Livekit_SignalResponse? = switch message {
case let .data(data): try? Livekit_SignalResponse(serializedBytes: data)
case let .string(string): try? Livekit_SignalResponse(jsonString: string)
Expand Down
Loading
Loading