Skip to content

Commit 850914e

Browse files
committed
Migrate to protocol instead of enum
1 parent c5b1567 commit 850914e

File tree

7 files changed

+64
-85
lines changed

7 files changed

+64
-85
lines changed

FirebaseAI/Sources/Types/Internal/Live/BidiGenerateContentServerMessage.swift

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ struct BidiGenerateContentServerMessage: Sendable {
4040
}
4141

4242
/// The message type.
43-
let messageType: MessageType
43+
let messageType: MessageType?
4444

4545
/// Usage metadata about the response(s).
4646
let usageMetadata: GenerateContentResponse.UsageMetadata?
@@ -86,11 +86,7 @@ extension BidiGenerateContentServerMessage: Decodable {
8686
} else if let goAway = try container.decodeIfPresent(GoAway.self, forKey: .goAway) {
8787
messageType = .goAway(goAway)
8888
} else {
89-
let context = DecodingError.Context(
90-
codingPath: decoder.codingPath,
91-
debugDescription: "Could not decode server message."
92-
)
93-
throw DecodingError.dataCorrupted(context)
89+
messageType = nil
9490
}
9591

9692
usageMetadata = try container.decodeIfPresent(

FirebaseAI/Sources/Types/Internal/Live/LiveSessionService.swift

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -206,10 +206,18 @@ actor LiveSessionService {
206206
from: message
207207
)
208208
} catch {
209-
throw LiveSessionUnsupportedMessageError(underlyingError: error)
209+
let error = LiveSessionUnsupportedMessageError(underlyingError: error)
210+
211+
// if we've already finished setting up, then only surface the error through responses
212+
// otherwise, make the setup task error as well
213+
if !resumed {
214+
setupComplete.resume(throwing: error)
215+
}
216+
throw error
210217
}
211218

212-
if case .setupComplete = response.messageType {
219+
switch response.messageType {
220+
case .setupComplete:
213221
if resumed {
214222
AILog.debug(
215223
code: .duplicateLiveSessionSetupComplete,
@@ -221,17 +229,33 @@ actor LiveSessionService {
221229
resumed = true
222230
setupComplete.resume()
223231
}
224-
} else if let liveMessage = LiveServerMessage(from: response) {
225-
if case let .goingAwayNotice(message) = liveMessage.messageType {
226-
// TODO: (b/444045023) When auto session resumption is enabled, call `connect` again
227-
AILog.debug(
228-
code: .liveSessionGoingAwaySoon,
229-
"Session expires in: \(message.goAway.timeLeft?.timeInterval ?? 0)"
230-
)
231-
}
232-
233-
responseContinuation.yield(liveMessage)
234-
} else {
232+
case let .goAway(goAway):
233+
// TODO: (b/444045023) When auto session resumption is enabled, call `connect` again
234+
AILog.debug(
235+
code: .liveSessionGoingAwaySoon,
236+
"Session expires in: \(goAway.timeLeft?.timeInterval ?? 0)"
237+
)
238+
239+
responseContinuation.yield(LiveServerGoingAwayNotice(
240+
goAway,
241+
usageMetadata: response.usageMetadata
242+
))
243+
case let .serverContent(serverContent):
244+
responseContinuation.yield(LiveServerContent(
245+
serverContent,
246+
usageMetadata: response.usageMetadata
247+
))
248+
case let .toolCall(toolCall):
249+
responseContinuation.yield(LiveServerToolCall(
250+
toolCall,
251+
usageMetadata: response.usageMetadata
252+
))
253+
case let .toolCallCancellation(toolCallCancellation):
254+
responseContinuation.yield(LiveServerToolCallCancellation(
255+
toolCallCancellation,
256+
usageMetadata: response.usageMetadata
257+
))
258+
case .none:
235259
// we don't raise an error, since this allows us to add support internally but not
236260
// publicly. We still log it in debug though, in case it's not expected.
237261
AILog.debug(

FirebaseAI/Sources/Types/Public/Live/LiveServerContent.swift

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
/// may choose to buffer and play it out in realtime.
2020
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, *)
2121
@available(watchOS, unavailable)
22-
public struct LiveServerContent: Sendable {
22+
public struct LiveServerContent: LiveServerMessage {
2323
let serverContent: BidiGenerateContentServerContent
2424

2525
/// The content that the model has generated as part of the current
@@ -77,7 +77,11 @@ public struct LiveServerContent: Sendable {
7777
serverContent.outputTranscription.map { LiveTranscription($0) }
7878
}
7979

80-
init(_ serverContent: BidiGenerateContentServerContent) {
80+
public var usageMetadata: GenerateContentResponse.UsageMetadata?
81+
82+
init(_ serverContent: BidiGenerateContentServerContent,
83+
usageMetadata: GenerateContentResponse.UsageMetadata?) {
8184
self.serverContent = serverContent
85+
self.usageMetadata = usageMetadata
8286
}
8387
}

FirebaseAI/Sources/Types/Public/Live/LiveServerGoingAwayNotice.swift

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,18 @@ import Foundation
1919
/// To learn more about session limits, see the docs on [Maximum session duration](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/multimodal-live#maximum-session-duration)\.
2020
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, *)
2121
@available(watchOS, unavailable)
22-
public struct LiveServerGoingAwayNotice: Sendable {
22+
public struct LiveServerGoingAwayNotice: LiveServerMessage {
2323
let goAway: GoAway
2424
/// The remaining time before the connection will be terminated as ABORTED.
2525
///
2626
/// The minimal time returned here is specified differently together with
2727
/// the rate limits for a given model.
2828
public var timeLeft: TimeInterval? { goAway.timeLeft?.timeInterval }
2929

30-
init(_ goAway: GoAway) {
30+
public var usageMetadata: GenerateContentResponse.UsageMetadata?
31+
32+
init(_ goAway: GoAway, usageMetadata: GenerateContentResponse.UsageMetadata?) {
3133
self.goAway = goAway
34+
self.usageMetadata = usageMetadata
3235
}
3336
}

FirebaseAI/Sources/Types/Public/Live/LiveServerMessage.swift

Lines changed: 2 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -15,63 +15,7 @@
1515
/// Update from the server, generated from the model in response to client messages.
1616
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, *)
1717
@available(watchOS, unavailable)
18-
public struct LiveServerMessage: Sendable {
19-
let serverMessage: BidiGenerateContentServerMessage
20-
21-
/// The type of message sent from the server.
22-
public enum MessageType: Sendable {
23-
/// Content generated by the model in response to client messages.
24-
case content(LiveServerContent)
25-
26-
/// Request for the client to execute the provided functions.
27-
case toolCall(LiveServerToolCall)
28-
29-
/// Notification for the client that a previously issued ``LiveServerToolCall`` should be
30-
/// cancelled.
31-
case toolCallCancellation(LiveServerToolCallCancellation)
32-
33-
/// Server will disconnect soon.
34-
case goingAwayNotice(LiveServerGoingAwayNotice)
35-
}
36-
37-
/// The actual message sent from the server.
38-
public var messageType: MessageType
39-
18+
public protocol LiveServerMessage: Sendable {
4019
/// Metadata on the usage of the cached content.
41-
public var usageMetadata: GenerateContentResponse.UsageMetadata? { serverMessage.usageMetadata }
42-
}
43-
44-
// MARK: - Internal parsing
45-
46-
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, *)
47-
@available(watchOS, unavailable)
48-
extension LiveServerMessage {
49-
init?(from serverMessage: BidiGenerateContentServerMessage) {
50-
guard let messageType = LiveServerMessage.MessageType(from: serverMessage.messageType) else {
51-
return nil
52-
}
53-
54-
self.serverMessage = serverMessage
55-
self.messageType = messageType
56-
}
57-
}
58-
59-
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, *)
60-
@available(watchOS, unavailable)
61-
extension LiveServerMessage.MessageType {
62-
init?(from serverMessage: BidiGenerateContentServerMessage.MessageType) {
63-
switch serverMessage {
64-
case .setupComplete:
65-
// this is handled internally, and should not be surfaced to users
66-
return nil
67-
case let .serverContent(msg):
68-
self = .content(LiveServerContent(msg))
69-
case let .toolCall(msg):
70-
self = .toolCall(LiveServerToolCall(msg))
71-
case let .toolCallCancellation(msg):
72-
self = .toolCallCancellation(LiveServerToolCallCancellation(msg))
73-
case let .goAway(msg):
74-
self = .goingAwayNotice(LiveServerGoingAwayNotice(msg))
75-
}
76-
}
20+
var usageMetadata: GenerateContentResponse.UsageMetadata? { get }
7721
}

FirebaseAI/Sources/Types/Public/Live/LiveServerToolCall.swift

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,19 @@
1818
/// correspond to individual ``FunctionCallPart``s.
1919
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
2020
@available(watchOS, unavailable)
21-
public struct LiveServerToolCall: Sendable {
21+
public struct LiveServerToolCall: LiveServerMessage {
2222
let serverToolCall: BidiGenerateContentToolCall
2323

2424
/// A list of ``FunctionCallPart`` to run and return responses for.
2525
public var functionCalls: [FunctionCallPart]? {
2626
serverToolCall.functionCalls?.map { FunctionCallPart($0) }
2727
}
2828

29-
init(_ serverToolCall: BidiGenerateContentToolCall) {
29+
public var usageMetadata: GenerateContentResponse.UsageMetadata?
30+
31+
init(_ serverToolCall: BidiGenerateContentToolCall,
32+
usageMetadata: GenerateContentResponse.UsageMetadata?) {
3033
self.serverToolCall = serverToolCall
34+
self.usageMetadata = usageMetadata
3135
}
3236
}

FirebaseAI/Sources/Types/Public/Live/LiveServerToolCallCancellation.swift

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,17 @@
1818
/// ``FunctionCallPart``s.
1919
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, *)
2020
@available(watchOS, unavailable)
21-
public struct LiveServerToolCallCancellation: Sendable {
21+
public struct LiveServerToolCallCancellation: LiveServerMessage {
2222
let serverToolCallCancellation: BidiGenerateContentToolCallCancellation
2323
/// A list of `functionId`s matching the `functionId` provided in a previous
2424
/// ``LiveServerToolCall``, where only the provided `functionId`s should be cancelled.
2525
public var ids: [String]? { serverToolCallCancellation.ids }
2626

27-
init(_ serverToolCallCancellation: BidiGenerateContentToolCallCancellation) {
27+
public var usageMetadata: GenerateContentResponse.UsageMetadata?
28+
29+
init(_ serverToolCallCancellation: BidiGenerateContentToolCallCancellation,
30+
usageMetadata: GenerateContentResponse.UsageMetadata?) {
2831
self.serverToolCallCancellation = serverToolCallCancellation
32+
self.usageMetadata = usageMetadata
2933
}
3034
}

0 commit comments

Comments
 (0)