Skip to content

Commit c6f3ba8

Browse files
andrewhearddaymxn
authored andcommitted
Add AsyncWebSocket wrapper for URLSessionWebSocketTask
1 parent 95e1908 commit c6f3ba8

File tree

3 files changed

+130
-34
lines changed

3 files changed

+130
-34
lines changed
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import Foundation
16+
17+
final class AsyncWebSocket: NSObject, @unchecked Sendable, URLSessionWebSocketDelegate {
18+
private let webSocketTask: URLSessionWebSocketTask
19+
private let stream: AsyncThrowingStream<URLSessionWebSocketTask.Message, Error>
20+
private let continuation: AsyncThrowingStream<URLSessionWebSocketTask.Message, Error>.Continuation
21+
private var continuationFinished = false
22+
private let continuationLock = NSLock()
23+
24+
private var _isConnected = false
25+
private let isConnectedLock = NSLock()
26+
private(set) var isConnected: Bool {
27+
get { isConnectedLock.withLock { _isConnected } }
28+
set { isConnectedLock.withLock { _isConnected = newValue } }
29+
}
30+
31+
init(urlSession: URLSession = GenAIURLSession.default, urlRequest: URLRequest) {
32+
webSocketTask = urlSession.webSocketTask(with: urlRequest)
33+
(stream, continuation) = AsyncThrowingStream<URLSessionWebSocketTask.Message, Error>
34+
.makeStream()
35+
}
36+
37+
deinit {
38+
webSocketTask.cancel(with: .goingAway, reason: nil)
39+
}
40+
41+
func connect() -> AsyncThrowingStream<URLSessionWebSocketTask.Message, Error> {
42+
webSocketTask.resume()
43+
isConnected = true
44+
startReceiving()
45+
return stream
46+
}
47+
48+
func disconnect() {
49+
webSocketTask.cancel(with: .goingAway, reason: nil)
50+
isConnected = false
51+
continuationLock.withLock {
52+
self.continuation.finish()
53+
self.continuationFinished = true
54+
}
55+
}
56+
57+
func send(_ message: URLSessionWebSocketTask.Message) async throws {
58+
// TODO: Throw error if socket already closed
59+
try await webSocketTask.send(message)
60+
}
61+
62+
private func startReceiving() {
63+
Task {
64+
while !Task.isCancelled && self.webSocketTask.isOpen && self.isConnected {
65+
let message = try await webSocketTask.receive()
66+
// TODO: Check continuationFinished before yielding. Use the same thread for NSLock.
67+
continuation.yield(message)
68+
}
69+
}
70+
}
71+
72+
func urlSession(_ session: URLSession,
73+
webSocketTask: URLSessionWebSocketTask,
74+
didCloseWith closeCode: URLSessionWebSocketTask.CloseCode,
75+
reason: Data?) {
76+
continuationLock.withLock {
77+
guard !continuationFinished else { return }
78+
continuation.finish()
79+
continuationFinished = true
80+
}
81+
}
82+
}
83+
84+
private extension URLSessionWebSocketTask {
85+
var isOpen: Bool {
86+
return closeCode == .invalid
87+
}
88+
}
89+
90+
struct WebSocketClosedError: Error, Sendable, CustomNSError {
91+
let closeCode: URLSessionWebSocketTask.CloseCode
92+
let closeReason: String
93+
94+
init(closeCode: URLSessionWebSocketTask.CloseCode, closeReason: Data?) {
95+
self.closeCode = closeCode
96+
self.closeReason = closeReason
97+
.flatMap { String(data: $0, encoding: .utf8) } ?? "Unknown reason."
98+
}
99+
100+
var errorCode: Int { closeCode.rawValue }
101+
102+
var errorUserInfo: [String: Any] {
103+
[
104+
NSLocalizedDescriptionKey: "WebSocket closed with code \(closeCode.rawValue). Reason: \(closeReason)",
105+
]
106+
}
107+
}

FirebaseAI/Sources/Types/Internal/Live/BidiGenerateContentRealtimeInput.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ import Foundation
3131
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
3232
struct BidiGenerateContentRealtimeInput: Encodable {
3333
/// These form the realtime audio input stream.
34-
let audio: Data?
34+
let audio: InlineData?
3535

3636
/// Indicates that the audio stream has ended, e.g. because the microphone was
3737
/// turned off.

FirebaseAI/Sources/Types/Public/Live/LiveSession.swift

Lines changed: 22 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import Foundation
1919
public final class LiveSession: Sendable {
2020
let modelResourceName: String
2121
let generationConfig: LiveGenerationConfig?
22-
let webSocket: URLSessionWebSocketTask
22+
let webSocket: AsyncWebSocket
2323

2424
public let responses: AsyncThrowingStream<BidiGenerateContentServerMessage, Error>
2525
private let responseContinuation: AsyncThrowingStream<BidiGenerateContentServerMessage, Error>
@@ -34,12 +34,12 @@ public final class LiveSession: Sendable {
3434
urlSession: URLSession) {
3535
self.modelResourceName = modelResourceName
3636
self.generationConfig = generationConfig
37-
webSocket = urlSession.webSocketTask(with: url)
37+
webSocket = AsyncWebSocket(urlSession: urlSession, urlRequest: URLRequest(url: url))
3838
(responses, responseContinuation) = AsyncThrowingStream.makeStream()
3939
}
4040

4141
deinit {
42-
webSocket.cancel(with: .goingAway, reason: nil)
42+
webSocket.disconnect()
4343
}
4444

4545
public func sendMessage(_ message: String) async throws {
@@ -51,40 +51,29 @@ public final class LiveSession: Sendable {
5151
}
5252

5353
func openConnection() {
54-
webSocket.resume()
55-
// TODO: Verify that this task gets cancelled on deinit
5654
Task {
57-
await startEventLoop()
58-
}
59-
}
60-
61-
private func startEventLoop() async {
62-
defer {
63-
webSocket.cancel(with: .goingAway, reason: nil)
64-
}
65-
66-
do {
67-
try await sendSetupMessage()
68-
69-
while !Task.isCancelled {
70-
let message = try await webSocket.receive()
71-
switch message {
72-
case let .string(string):
73-
print("Unexpected string response: \(string)")
74-
case let .data(data):
75-
let response = try jsonDecoder.decode(
76-
BidiGenerateContentServerMessage.self,
77-
from: data
78-
)
79-
responseContinuation.yield(response)
80-
@unknown default:
81-
print("Unknown message received")
55+
do {
56+
let stream = webSocket.connect()
57+
try await sendSetupMessage()
58+
for try await message in stream {
59+
switch message {
60+
case let .string(string):
61+
print("Unexpected string response: \(string)")
62+
case let .data(data):
63+
let response = try jsonDecoder.decode(
64+
BidiGenerateContentServerMessage.self,
65+
from: data
66+
)
67+
responseContinuation.yield(response)
68+
@unknown default:
69+
print("Unknown message received")
70+
}
8271
}
72+
} catch {
73+
responseContinuation.finish(throwing: error)
8374
}
84-
} catch {
85-
responseContinuation.finish(throwing: error)
75+
responseContinuation.finish()
8676
}
87-
responseContinuation.finish()
8877
}
8978

9079
private func sendSetupMessage() async throws {

0 commit comments

Comments
 (0)