15
15
import Foundation
16
16
17
17
// TODO: Extract most of this file into a service class similar to `GenerativeAIService`.
18
- public final class LiveSession : NSObject , URLSessionWebSocketDelegate , URLSessionTaskDelegate {
19
- private enum State {
20
- case notConnected
21
- case connecting
22
- case setupSent
23
- case ready
24
- case closed
25
- }
26
-
27
- private enum WebSocketError : Error {
28
- case connectionClosed
29
- }
30
-
18
+ public final class LiveSession : Sendable {
31
19
let modelResourceName : String
32
20
let generationConfig : LiveGenerationConfig ?
33
21
let webSocket : URLSessionWebSocketTask
34
22
35
- // TODO: Refactor this property, potentially returning responses after `connect`.
36
23
public let responses : AsyncThrowingStream < BidiGenerateContentServerMessage , Error >
24
+ private let responseContinuation : AsyncThrowingStream < BidiGenerateContentServerMessage , Error >
25
+ . Continuation
37
26
38
- private var state : State = . notConnected
39
- private var pendingMessages : [ ( String , CheckedContinuation < Void , Error > ) ] = [ ]
40
27
private let jsonEncoder = JSONEncoder ( )
41
28
private let jsonDecoder = JSONDecoder ( )
42
29
43
- // TODO: Properly wrap callback code using `withCheckedContinuation` or similar.
44
- private let responseContinuation : AsyncThrowingStream < BidiGenerateContentServerMessage , Error >
45
- . Continuation
46
-
47
30
init ( modelResourceName: String ,
48
31
generationConfig: LiveGenerationConfig ? ,
49
32
url: URL ,
@@ -54,166 +37,61 @@ public final class LiveSession: NSObject, URLSessionWebSocketDelegate, URLSessio
54
37
( responses, responseContinuation) = AsyncThrowingStream . makeStream ( )
55
38
}
56
39
57
- func open( ) async throws {
58
- guard state == . notConnected else {
59
- print ( " Web socket is not in a valid state to be opened: \( state) " )
60
- return
61
- }
62
-
63
- state = . connecting
64
- webSocket. delegate = self
65
- webSocket. resume ( )
66
-
67
- print ( " Opening websocket " )
68
- }
69
-
70
- private func failPendingMessages( with error: Error ) {
71
- for (_, continuation) in pendingMessages {
72
- continuation. resume ( throwing: error)
73
- }
74
- pendingMessages. removeAll ( )
75
- responseContinuation. finish ( throwing: error)
40
+ deinit {
41
+ webSocket. cancel ( with: . goingAway, reason: nil )
76
42
}
77
43
78
- private func processPendingMessages( ) {
79
- for (message, continuation) in pendingMessages {
80
- Task {
81
- do {
82
- try await send ( message)
83
- continuation. resume ( )
84
- } catch {
85
- continuation. resume ( throwing: error)
86
- }
87
- }
88
- }
89
- pendingMessages. removeAll ( )
90
- }
91
-
92
- private func send( _ message: String ) async throws {
44
+ public func sendMessage( _ message: String ) async throws {
93
45
let content = ModelContent ( role: " user " , parts: [ message] )
94
46
let clientContent = BidiGenerateContentClientContent ( turns: [ content] , turnComplete: true )
95
47
let clientMessage = BidiGenerateContentClientMessage . clientContent ( clientContent)
96
48
let clientMessageData = try jsonEncoder. encode ( clientMessage)
97
- let clientMessageJSON = String ( data: clientMessageData, encoding: . utf8)
98
- print ( " Client Message JSON: \( clientMessageJSON) " )
99
49
try await webSocket. send ( . data( clientMessageData) )
100
- setReceiveHandler ( )
101
50
}
102
51
103
- public func sendMessage( _ message: String ) async throws {
104
- if state == . ready {
105
- try await send ( message)
106
- } else {
107
- try await withCheckedThrowingContinuation { continuation in
108
- pendingMessages. append ( ( message, continuation) )
109
- }
52
+ func openConnection( ) {
53
+ webSocket. resume ( )
54
+ // TODO: Verify that this task gets cancelled on deinit
55
+ Task {
56
+ await startEventLoop ( )
110
57
}
111
58
}
112
59
113
- public func urlSession( _ session: URLSession ,
114
- webSocketTask: URLSessionWebSocketTask ,
115
- didOpenWithProtocol protocol: String ? ) {
116
- print ( " Web Socket opened. " )
117
-
118
- guard state == . connecting else {
119
- print ( " Web socket is not in a valid state to be opened: \( state) " )
120
- return
60
+ private func startEventLoop( ) async {
61
+ defer {
62
+ webSocket. cancel ( with: . goingAway, reason: nil )
121
63
}
122
64
123
65
do {
124
- let setup = BidiGenerateContentSetup (
125
- model: modelResourceName, generationConfig: generationConfig
126
- )
127
- let message = BidiGenerateContentClientMessage . setup ( setup)
128
- let messageData = try jsonEncoder. encode ( message)
129
- let messageJSON = String ( data: messageData, encoding: . utf8)
130
- print ( " JSON: \( messageJSON) " )
131
- webSocketTask. send ( . data( messageData) ) { error in
132
- if let error {
133
- print ( " Send Error: \( error) " )
134
- self . state = . closed
135
- self . failPendingMessages ( with: error)
136
- return
137
- }
138
-
139
- self . state = . setupSent
140
- self . setReceiveHandler ( )
141
- }
142
- } catch {
143
- print ( error)
144
- state = . closed
145
- failPendingMessages ( with: error)
146
- }
147
- }
148
-
149
- public func urlSession( _ session: URLSession ,
150
- webSocketTask: URLSessionWebSocketTask ,
151
- didCloseWith closeCode: URLSessionWebSocketTask . CloseCode ,
152
- reason: Data ? ) {
153
- print ( " Web Socket closed. " )
154
- state = . closed
155
- failPendingMessages ( with: WebSocketError . connectionClosed)
156
- responseContinuation. finish ( )
157
- }
158
-
159
- func setReceiveHandler( ) {
160
- guard state == . setupSent || state == . ready else {
161
- print ( " Web socket is not in a valid state to receive messages: \( state) " )
162
- return
163
- }
66
+ try await sendSetupMessage ( )
164
67
165
- webSocket. receive { result in
166
- do {
167
- let message = try result. get ( )
68
+ while !Task. isCancelled {
69
+ let message = try await webSocket. receive ( )
168
70
switch message {
169
71
case let . string( string) :
170
72
print ( " Unexpected string response: \( string) " )
171
- self . setReceiveHandler ( )
172
73
case let . data( data) :
173
- let response = try self . jsonDecoder. decode (
74
+ let response = try jsonDecoder. decode (
174
75
BidiGenerateContentServerMessage . self,
175
76
from: data
176
77
)
177
- let responseJSON = String ( data: data, encoding: . utf8)
178
-
179
- switch response. messageType {
180
- case . setupComplete:
181
- print ( " Setup Complete: \( responseJSON) " )
182
- self . state = . ready
183
- self . processPendingMessages ( )
184
- case . serverContent:
185
- print ( " Server Content: \( responseJSON) " )
186
- case . toolCall:
187
- // TODO: Tool calls not yet implemented
188
- print ( " Tool Call: \( responseJSON) " )
189
- case . toolCallCancellation:
190
- // TODO: Tool call cancellation not yet implemented
191
- print ( " Tool Call Cancellation: \( responseJSON) " )
192
- case let . goAway( goAway) :
193
- if let timeLeft = goAway. timeLeft {
194
- print ( " Server will disconnect in \( timeLeft) seconds. " )
195
- } else {
196
- print ( " Server will disconnect soon. " )
197
- }
198
- }
199
-
200
- self . responseContinuation. yield ( response)
201
-
202
- if self . state == . closed {
203
- print ( " Web socket is closed, not listening for more messages. " )
204
- } else {
205
- self . setReceiveHandler ( )
206
- }
78
+ responseContinuation. yield ( response)
207
79
@unknown default :
208
80
print ( " Unknown message received " )
209
- self . setReceiveHandler ( )
210
81
}
211
- } catch {
212
- // handle the error
213
- print ( error)
214
- self . state = . closed
215
- self . responseContinuation. finish ( throwing: error)
216
82
}
83
+ } catch {
84
+ responseContinuation. finish ( throwing: error)
217
85
}
86
+ responseContinuation. finish ( )
87
+ }
88
+
89
+ private func sendSetupMessage( ) async throws {
90
+ let setup = BidiGenerateContentSetup (
91
+ model: modelResourceName, generationConfig: generationConfig
92
+ )
93
+ let message = BidiGenerateContentClientMessage . setup ( setup)
94
+ let messageData = try jsonEncoder. encode ( message)
95
+ try await webSocket. send ( . data( messageData) )
218
96
}
219
97
}
0 commit comments