Skip to content

Commit f27e34d

Browse files
authored
[Vertex AI] Replace ModelContent.Part enum with protocol/structs (#13767)
1 parent a3e7a20 commit f27e34d

22 files changed

+697
-495
lines changed

FirebaseVertexAI.podspec

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ Firebase SDK.
6262
]
6363
unit_tests.resources = [
6464
unit_tests_dir + 'vertexai-sdk-test-data/mock-responses/**/*.{txt,json}',
65+
unit_tests_dir + 'Resources/**/*',
6566
]
6667
end
6768
end

FirebaseVertexAI/CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@
2828
- [changed] **Breaking Change**: The `CountTokensError` enum has been removed;
2929
errors occurring in `GenerativeModel.countTokens(...)` are now thrown directly
3030
instead of being wrapped in a `CountTokensError.internalError`. (#13736)
31+
- [changed] **Breaking Change**: The enum `ModelContent.Part` has been replaced
32+
with a protocol named `Part` to avoid future breaking changes with new part
33+
types. The new types `TextPart` and `FunctionCallPart` may be received when
34+
generating content; additionally the types
35+
`InlineDataPart`, `FileDataPart` and `FunctionResponsePart` may be provided
36+
as input. (#13767)
3137
- [changed] The default request timeout is now 180 seconds instead of the
3238
platform-default value of 60 seconds for a `URLRequest`; this timeout may
3339
still be customized in `RequestOptions`. (#13722)

FirebaseVertexAI/Sample/FunctionCallingSample/ViewModels/FunctionCallingViewModel.swift

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class FunctionCallingViewModel: ObservableObject {
3030
}
3131

3232
/// Function calls pending processing
33-
private var functionCalls = [FunctionCall]()
33+
private var functionCalls = [FunctionCallPart]()
3434

3535
private var model: GenerativeModel
3636
private var chat: Chat
@@ -144,26 +144,26 @@ class FunctionCallingViewModel: ObservableObject {
144144

145145
for part in candidate.content.parts {
146146
switch part {
147-
case let .text(text):
147+
case let textPart as TextPart:
148148
// replace pending message with backend response
149-
messages[messages.count - 1].message += text
149+
messages[messages.count - 1].message += textPart.text
150150
messages[messages.count - 1].pending = false
151-
case let .functionCall(functionCall):
152-
messages.insert(functionCall.chatMessage(), at: messages.count - 1)
153-
functionCalls.append(functionCall)
154-
case .inlineData, .fileData, .functionResponse:
155-
fatalError("Unsupported response content.")
151+
case let functionCallPart as FunctionCallPart:
152+
messages.insert(functionCallPart.chatMessage(), at: messages.count - 1)
153+
functionCalls.append(functionCallPart)
154+
default:
155+
fatalError("Unsupported response part: \(part)")
156156
}
157157
}
158158
}
159159

160-
func processFunctionCalls() async throws -> [FunctionResponse] {
161-
var functionResponses = [FunctionResponse]()
160+
func processFunctionCalls() async throws -> [FunctionResponsePart] {
161+
var functionResponses = [FunctionResponsePart]()
162162
for functionCall in functionCalls {
163163
switch functionCall.name {
164164
case "get_exchange_rate":
165165
let exchangeRates = getExchangeRate(args: functionCall.args)
166-
functionResponses.append(FunctionResponse(
166+
functionResponses.append(FunctionResponsePart(
167167
name: "get_exchange_rate",
168168
response: exchangeRates
169169
))
@@ -208,7 +208,7 @@ class FunctionCallingViewModel: ObservableObject {
208208
}
209209
}
210210

211-
private extension FunctionCall {
211+
private extension FunctionCallPart {
212212
func chatMessage() -> ChatMessage {
213213
let encoder = JSONEncoder()
214214
encoder.outputFormatting = .prettyPrinted
@@ -228,7 +228,7 @@ private extension FunctionCall {
228228
}
229229
}
230230

231-
private extension FunctionResponse {
231+
private extension FunctionResponsePart {
232232
func chatMessage() -> ChatMessage {
233233
let encoder = JSONEncoder()
234234
encoder.outputFormatting = .prettyPrinted
@@ -248,12 +248,8 @@ private extension FunctionResponse {
248248
}
249249
}
250250

251-
private extension [FunctionResponse] {
251+
private extension [FunctionResponsePart] {
252252
func modelContent() -> [ModelContent] {
253-
return self.map { ModelContent(
254-
role: "function",
255-
parts: [ModelContent.Part.functionResponse($0)]
256-
)
257-
}
253+
return self.map { ModelContent(role: "function", parts: [$0]) }
258254
}
259255
}

FirebaseVertexAI/Sample/GenerativeAIMultimodalSample/ViewModels/PhotoReasoningViewModel.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ class PhotoReasoningViewModel: ObservableObject {
6262

6363
let prompt = "Look at the image(s), and then answer the following question: \(userInput)"
6464

65-
var images = [any ThrowingPartsRepresentable]()
65+
var images = [any PartsRepresentable]()
6666
for item in selectedItems {
6767
if let data = try? await item.loadTransferable(type: Data.self) {
6868
guard let image = UIImage(data: data) else {

FirebaseVertexAI/Sources/Chat.swift

Lines changed: 12 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public class Chat {
3535
/// - Parameter parts: The new content to send as a single chat message.
3636
/// - Returns: The model's response if no error occurred.
3737
/// - Throws: A ``GenerateContentError`` if an error occurred.
38-
public func sendMessage(_ parts: any ThrowingPartsRepresentable...) async throws
38+
public func sendMessage(_ parts: any PartsRepresentable...) async throws
3939
-> GenerateContentResponse {
4040
return try await sendMessage([ModelContent(parts: parts)])
4141
}
@@ -45,19 +45,10 @@ public class Chat {
4545
/// - Parameter content: The new content to send as a single chat message.
4646
/// - Returns: The model's response if no error occurred.
4747
/// - Throws: A ``GenerateContentError`` if an error occurred.
48-
public func sendMessage(_ content: @autoclosure () throws -> [ModelContent]) async throws
48+
public func sendMessage(_ content: [ModelContent]) async throws
4949
-> GenerateContentResponse {
5050
// Ensure that the new content has the role set.
51-
let newContent: [ModelContent]
52-
do {
53-
newContent = try content().map(populateContentRole(_:))
54-
} catch let underlying {
55-
if let contentError = underlying as? ImageConversionError {
56-
throw GenerateContentError.promptImageContentError(underlying: contentError)
57-
} else {
58-
throw GenerateContentError.internalError(underlying: underlying)
59-
}
60-
}
51+
let newContent = content.map(populateContentRole(_:))
6152

6253
// Send the history alongside the new message as context.
6354
let request = history + newContent
@@ -85,7 +76,7 @@ public class Chat {
8576
/// - Parameter parts: The new content to send as a single chat message.
8677
/// - Returns: A stream containing the model's response or an error if an error occurred.
8778
@available(macOS 12.0, *)
88-
public func sendMessageStream(_ parts: any ThrowingPartsRepresentable...) throws
79+
public func sendMessageStream(_ parts: any PartsRepresentable...) throws
8980
-> AsyncThrowingStream<GenerateContentResponse, Error> {
9081
return try sendMessageStream([ModelContent(parts: parts)])
9182
}
@@ -95,24 +86,14 @@ public class Chat {
9586
/// - Parameter content: The new content to send as a single chat message.
9687
/// - Returns: A stream containing the model's response or an error if an error occurred.
9788
@available(macOS 12.0, *)
98-
public func sendMessageStream(_ content: @autoclosure () throws -> [ModelContent]) throws
89+
public func sendMessageStream(_ content: [ModelContent]) throws
9990
-> AsyncThrowingStream<GenerateContentResponse, Error> {
100-
let resolvedContent: [ModelContent]
101-
do {
102-
resolvedContent = try content()
103-
} catch let underlying {
104-
if let contentError = underlying as? ImageConversionError {
105-
throw GenerateContentError.promptImageContentError(underlying: contentError)
106-
}
107-
throw GenerateContentError.internalError(underlying: underlying)
108-
}
109-
11091
return AsyncThrowingStream { continuation in
11192
Task {
11293
var aggregatedContent: [ModelContent] = []
11394

11495
// Ensure that the new content has the role set.
115-
let newContent: [ModelContent] = resolvedContent.map(populateContentRole(_:))
96+
let newContent: [ModelContent] = content.map(populateContentRole(_:))
11697

11798
// Send the history alongside the new message as context.
11899
let request = history + newContent
@@ -146,20 +127,20 @@ public class Chat {
146127
}
147128

148129
private func aggregatedChunks(_ chunks: [ModelContent]) -> ModelContent {
149-
var parts: [ModelContent.Part] = []
130+
var parts: [any Part] = []
150131
var combinedText = ""
151132
for aggregate in chunks {
152133
// Loop through all the parts, aggregating the text and adding the non-text parts.
153134
for part in aggregate.parts {
154135
switch part {
155-
case let .text(str):
156-
combinedText += str
136+
case let textPart as TextPart:
137+
combinedText += textPart.text
157138

158-
case .inlineData, .fileData, .functionCall, .functionResponse:
139+
default:
159140
// Don't combine it, just add to the content. If there's any text pending, add that as
160141
// a part.
161142
if !combinedText.isEmpty {
162-
parts.append(.text(combinedText))
143+
parts.append(TextPart(combinedText))
163144
combinedText = ""
164145
}
165146

@@ -169,7 +150,7 @@ public class Chat {
169150
}
170151

171152
if !combinedText.isEmpty {
172-
parts.append(.text(combinedText))
153+
parts.append(TextPart(combinedText))
173154
}
174155

175156
return ModelContent(role: "model", parts: parts)

FirebaseVertexAI/Sources/FunctionCalling.swift

Lines changed: 0 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -14,27 +14,6 @@
1414

1515
import Foundation
1616

17-
/// A predicted function call returned from the model.
18-
public struct FunctionCall: Equatable, Sendable {
19-
/// The name of the function to call.
20-
public let name: String
21-
22-
/// The function parameters and values.
23-
public let args: JSONObject
24-
25-
/// Constructs a new function call.
26-
///
27-
/// > Note: A `FunctionCall` is typically received from the model, rather than created manually.
28-
///
29-
/// - Parameters:
30-
/// - name: The name of the function to call.
31-
/// - args: The function parameters and values.
32-
public init(name: String, args: JSONObject) {
33-
self.name = name
34-
self.args = args
35-
}
36-
}
37-
3817
/// Structured representation of a function declaration.
3918
///
4019
/// This `FunctionDeclaration` is a representation of a block of code that can be used as a ``Tool``
@@ -136,50 +115,8 @@ public struct ToolConfig {
136115
}
137116
}
138117

139-
/// Result output from a ``FunctionCall``.
140-
///
141-
/// Contains a string representing the `FunctionDeclaration.name` and a structured JSON object
142-
/// containing any output from the function is used as context to the model. This should contain the
143-
/// result of a ``FunctionCall`` made based on model prediction.
144-
public struct FunctionResponse: Equatable, Sendable {
145-
/// The name of the function that was called.
146-
let name: String
147-
148-
/// The function's response.
149-
let response: JSONObject
150-
151-
/// Constructs a new `FunctionResponse`.
152-
///
153-
/// - Parameters:
154-
/// - name: The name of the function that was called.
155-
/// - response: The function's response.
156-
public init(name: String, response: JSONObject) {
157-
self.name = name
158-
self.response = response
159-
}
160-
}
161-
162118
// MARK: - Codable Conformance
163119

164-
extension FunctionCall: Decodable {
165-
enum CodingKeys: CodingKey {
166-
case name
167-
case args
168-
}
169-
170-
public init(from decoder: Decoder) throws {
171-
let container = try decoder.container(keyedBy: CodingKeys.self)
172-
name = try container.decode(String.self, forKey: .name)
173-
if let args = try container.decodeIfPresent(JSONObject.self, forKey: .args) {
174-
self.args = args
175-
} else {
176-
args = JSONObject()
177-
}
178-
}
179-
}
180-
181-
extension FunctionCall: Encodable {}
182-
183120
extension FunctionDeclaration: Encodable {
184121
enum CodingKeys: String, CodingKey {
185122
case name
@@ -202,5 +139,3 @@ extension FunctionCallingConfig: Encodable {}
202139
extension FunctionCallingConfig.Mode: Encodable {}
203140

204141
extension ToolConfig: Encodable {}
205-
206-
extension FunctionResponse: Codable {}

FirebaseVertexAI/Sources/GenerateContentResponse.swift

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,12 @@ public struct GenerateContentResponse: Sendable {
4949
return nil
5050
}
5151
let textValues: [String] = candidate.content.parts.compactMap { part in
52-
guard case let .text(text) = part else {
52+
switch part {
53+
case let textPart as TextPart:
54+
return textPart.text
55+
default:
5356
return nil
5457
}
55-
return text
5658
}
5759
guard textValues.count > 0 else {
5860
VertexLog.error(
@@ -65,15 +67,17 @@ public struct GenerateContentResponse: Sendable {
6567
}
6668

6769
/// Returns function calls found in any `Part`s of the first candidate of the response, if any.
68-
public var functionCalls: [FunctionCall] {
70+
public var functionCalls: [FunctionCallPart] {
6971
guard let candidate = candidates.first else {
7072
return []
7173
}
7274
return candidate.content.parts.compactMap { part in
73-
guard case let .functionCall(functionCall) = part else {
75+
switch part {
76+
case let functionCallPart as FunctionCallPart:
77+
return functionCallPart
78+
default:
7479
return nil
7580
}
76-
return functionCall
7781
}
7882
}
7983

0 commit comments

Comments
 (0)