diff --git a/FirebaseVertexAI.podspec b/FirebaseVertexAI.podspec index a9ee6fb77eb..0a9b059d3c1 100644 --- a/FirebaseVertexAI.podspec +++ b/FirebaseVertexAI.podspec @@ -62,6 +62,7 @@ Firebase SDK. ] unit_tests.resources = [ unit_tests_dir + 'vertexai-sdk-test-data/mock-responses/**/*.{txt,json}', + unit_tests_dir + 'Resources/**/*', ] end end diff --git a/FirebaseVertexAI/CHANGELOG.md b/FirebaseVertexAI/CHANGELOG.md index 9a7d5c70e7c..13ef2aad590 100644 --- a/FirebaseVertexAI/CHANGELOG.md +++ b/FirebaseVertexAI/CHANGELOG.md @@ -28,6 +28,12 @@ - [changed] **Breaking Change**: The `CountTokensError` enum has been removed; errors occurring in `GenerativeModel.countTokens(...)` are now thrown directly instead of being wrapped in a `CountTokensError.internalError`. (#13736) +- [changed] **Breaking Change**: The enum `ModelContent.Part` has been replaced + with a protocol named `Part` to avoid future breaking changes with new part + types. The new types `TextPart` and `FunctionCallPart` may be received when + generating content the types `TextPart`; additionally the types + `InlineDataPart`, `FileDataPart` and `FunctionResponsePart` may be provided + as input. (#13767) - [changed] The default request timeout is now 180 seconds instead of the platform-default value of 60 seconds for a `URLRequest`; this timeout may still be customized in `RequestOptions`. (#13722) diff --git a/FirebaseVertexAI/Sample/FunctionCallingSample/ViewModels/FunctionCallingViewModel.swift b/FirebaseVertexAI/Sample/FunctionCallingSample/ViewModels/FunctionCallingViewModel.swift index 110cab9ce27..f39540eb1a9 100644 --- a/FirebaseVertexAI/Sample/FunctionCallingSample/ViewModels/FunctionCallingViewModel.swift +++ b/FirebaseVertexAI/Sample/FunctionCallingSample/ViewModels/FunctionCallingViewModel.swift @@ -30,7 +30,7 @@ class FunctionCallingViewModel: ObservableObject { } /// Function calls pending processing - private var functionCalls = [FunctionCall]() + private var functionCalls = [FunctionCallPart]() private var model: GenerativeModel private var chat: Chat @@ -144,26 +144,26 @@ class FunctionCallingViewModel: ObservableObject { for part in candidate.content.parts { switch part { - case let .text(text): + case let textPart as TextPart: // replace pending message with backend response - messages[messages.count - 1].message += text + messages[messages.count - 1].message += textPart.text messages[messages.count - 1].pending = false - case let .functionCall(functionCall): - messages.insert(functionCall.chatMessage(), at: messages.count - 1) - functionCalls.append(functionCall) - case .inlineData, .fileData, .functionResponse: - fatalError("Unsupported response content.") + case let functionCallPart as FunctionCallPart: + messages.insert(functionCallPart.chatMessage(), at: messages.count - 1) + functionCalls.append(functionCallPart) + default: + fatalError("Unsupported response part: \(part)") } } } - func processFunctionCalls() async throws -> [FunctionResponse] { - var functionResponses = [FunctionResponse]() + func processFunctionCalls() async throws -> [FunctionResponsePart] { + var functionResponses = [FunctionResponsePart]() for functionCall in functionCalls { switch functionCall.name { case "get_exchange_rate": let exchangeRates = getExchangeRate(args: functionCall.args) - functionResponses.append(FunctionResponse( + functionResponses.append(FunctionResponsePart( name: "get_exchange_rate", response: exchangeRates )) @@ -208,7 +208,7 @@ class FunctionCallingViewModel: ObservableObject { } } -private extension FunctionCall { +private extension FunctionCallPart { func chatMessage() -> ChatMessage { let encoder = JSONEncoder() encoder.outputFormatting = .prettyPrinted @@ -228,7 +228,7 @@ private extension FunctionCall { } } -private extension FunctionResponse { +private extension FunctionResponsePart { func chatMessage() -> ChatMessage { let encoder = JSONEncoder() encoder.outputFormatting = .prettyPrinted @@ -248,12 +248,8 @@ private extension FunctionResponse { } } -private extension [FunctionResponse] { +private extension [FunctionResponsePart] { func modelContent() -> [ModelContent] { - return self.map { ModelContent( - role: "function", - parts: [ModelContent.Part.functionResponse($0)] - ) - } + return self.map { ModelContent(role: "function", parts: [$0]) } } } diff --git a/FirebaseVertexAI/Sample/GenerativeAIMultimodalSample/ViewModels/PhotoReasoningViewModel.swift b/FirebaseVertexAI/Sample/GenerativeAIMultimodalSample/ViewModels/PhotoReasoningViewModel.swift index 3a85da05102..e433d96dfc2 100644 --- a/FirebaseVertexAI/Sample/GenerativeAIMultimodalSample/ViewModels/PhotoReasoningViewModel.swift +++ b/FirebaseVertexAI/Sample/GenerativeAIMultimodalSample/ViewModels/PhotoReasoningViewModel.swift @@ -62,7 +62,7 @@ class PhotoReasoningViewModel: ObservableObject { let prompt = "Look at the image(s), and then answer the following question: \(userInput)" - var images = [any ThrowingPartsRepresentable]() + var images = [any PartsRepresentable]() for item in selectedItems { if let data = try? await item.loadTransferable(type: Data.self) { guard let image = UIImage(data: data) else { diff --git a/FirebaseVertexAI/Sources/Chat.swift b/FirebaseVertexAI/Sources/Chat.swift index 10df040ab30..2ebab217ca8 100644 --- a/FirebaseVertexAI/Sources/Chat.swift +++ b/FirebaseVertexAI/Sources/Chat.swift @@ -35,7 +35,7 @@ public class Chat { /// - Parameter parts: The new content to send as a single chat message. /// - Returns: The model's response if no error occurred. /// - Throws: A ``GenerateContentError`` if an error occurred. - public func sendMessage(_ parts: any ThrowingPartsRepresentable...) async throws + public func sendMessage(_ parts: any PartsRepresentable...) async throws -> GenerateContentResponse { return try await sendMessage([ModelContent(parts: parts)]) } @@ -45,19 +45,10 @@ public class Chat { /// - Parameter content: The new content to send as a single chat message. /// - Returns: The model's response if no error occurred. /// - Throws: A ``GenerateContentError`` if an error occurred. - public func sendMessage(_ content: @autoclosure () throws -> [ModelContent]) async throws + public func sendMessage(_ content: [ModelContent]) async throws -> GenerateContentResponse { // Ensure that the new content has the role set. - let newContent: [ModelContent] - do { - newContent = try content().map(populateContentRole(_:)) - } catch let underlying { - if let contentError = underlying as? ImageConversionError { - throw GenerateContentError.promptImageContentError(underlying: contentError) - } else { - throw GenerateContentError.internalError(underlying: underlying) - } - } + let newContent = content.map(populateContentRole(_:)) // Send the history alongside the new message as context. let request = history + newContent @@ -85,7 +76,7 @@ public class Chat { /// - Parameter parts: The new content to send as a single chat message. /// - Returns: A stream containing the model's response or an error if an error occurred. @available(macOS 12.0, *) - public func sendMessageStream(_ parts: any ThrowingPartsRepresentable...) throws + public func sendMessageStream(_ parts: any PartsRepresentable...) throws -> AsyncThrowingStream { return try sendMessageStream([ModelContent(parts: parts)]) } @@ -95,24 +86,14 @@ public class Chat { /// - Parameter content: The new content to send as a single chat message. /// - Returns: A stream containing the model's response or an error if an error occurred. @available(macOS 12.0, *) - public func sendMessageStream(_ content: @autoclosure () throws -> [ModelContent]) throws + public func sendMessageStream(_ content: [ModelContent]) throws -> AsyncThrowingStream { - let resolvedContent: [ModelContent] - do { - resolvedContent = try content() - } catch let underlying { - if let contentError = underlying as? ImageConversionError { - throw GenerateContentError.promptImageContentError(underlying: contentError) - } - throw GenerateContentError.internalError(underlying: underlying) - } - return AsyncThrowingStream { continuation in Task { var aggregatedContent: [ModelContent] = [] // Ensure that the new content has the role set. - let newContent: [ModelContent] = resolvedContent.map(populateContentRole(_:)) + let newContent: [ModelContent] = content.map(populateContentRole(_:)) // Send the history alongside the new message as context. let request = history + newContent @@ -146,20 +127,20 @@ public class Chat { } private func aggregatedChunks(_ chunks: [ModelContent]) -> ModelContent { - var parts: [ModelContent.Part] = [] + var parts: [any Part] = [] var combinedText = "" for aggregate in chunks { // Loop through all the parts, aggregating the text and adding the images. for part in aggregate.parts { switch part { - case let .text(str): - combinedText += str + case let textPart as TextPart: + combinedText += textPart.text - case .inlineData, .fileData, .functionCall, .functionResponse: + default: // Don't combine it, just add to the content. If there's any text pending, add that as // a part. if !combinedText.isEmpty { - parts.append(.text(combinedText)) + parts.append(TextPart(combinedText)) combinedText = "" } @@ -169,7 +150,7 @@ public class Chat { } if !combinedText.isEmpty { - parts.append(.text(combinedText)) + parts.append(TextPart(combinedText)) } return ModelContent(role: "model", parts: parts) diff --git a/FirebaseVertexAI/Sources/FunctionCalling.swift b/FirebaseVertexAI/Sources/FunctionCalling.swift index 3fb17838d4c..69924f3cc4b 100644 --- a/FirebaseVertexAI/Sources/FunctionCalling.swift +++ b/FirebaseVertexAI/Sources/FunctionCalling.swift @@ -14,27 +14,6 @@ import Foundation -/// A predicted function call returned from the model. -public struct FunctionCall: Equatable, Sendable { - /// The name of the function to call. - public let name: String - - /// The function parameters and values. - public let args: JSONObject - - /// Constructs a new function call. - /// - /// > Note: A `FunctionCall` is typically received from the model, rather than created manually. - /// - /// - Parameters: - /// - name: The name of the function to call. - /// - args: The function parameters and values. - public init(name: String, args: JSONObject) { - self.name = name - self.args = args - } -} - /// Structured representation of a function declaration. /// /// This `FunctionDeclaration` is a representation of a block of code that can be used as a ``Tool`` @@ -136,50 +115,8 @@ public struct ToolConfig { } } -/// Result output from a ``FunctionCall``. -/// -/// Contains a string representing the `FunctionDeclaration.name` and a structured JSON object -/// containing any output from the function is used as context to the model. This should contain the -/// result of a ``FunctionCall`` made based on model prediction. -public struct FunctionResponse: Equatable, Sendable { - /// The name of the function that was called. - let name: String - - /// The function's response. - let response: JSONObject - - /// Constructs a new `FunctionResponse`. - /// - /// - Parameters: - /// - name: The name of the function that was called. - /// - response: The function's response. - public init(name: String, response: JSONObject) { - self.name = name - self.response = response - } -} - // MARK: - Codable Conformance -extension FunctionCall: Decodable { - enum CodingKeys: CodingKey { - case name - case args - } - - public init(from decoder: Decoder) throws { - let container = try decoder.container(keyedBy: CodingKeys.self) - name = try container.decode(String.self, forKey: .name) - if let args = try container.decodeIfPresent(JSONObject.self, forKey: .args) { - self.args = args - } else { - args = JSONObject() - } - } -} - -extension FunctionCall: Encodable {} - extension FunctionDeclaration: Encodable { enum CodingKeys: String, CodingKey { case name @@ -202,5 +139,3 @@ extension FunctionCallingConfig: Encodable {} extension FunctionCallingConfig.Mode: Encodable {} extension ToolConfig: Encodable {} - -extension FunctionResponse: Codable {} diff --git a/FirebaseVertexAI/Sources/GenerateContentResponse.swift b/FirebaseVertexAI/Sources/GenerateContentResponse.swift index 66dd83aec16..c4ca48f2264 100644 --- a/FirebaseVertexAI/Sources/GenerateContentResponse.swift +++ b/FirebaseVertexAI/Sources/GenerateContentResponse.swift @@ -49,10 +49,12 @@ public struct GenerateContentResponse: Sendable { return nil } let textValues: [String] = candidate.content.parts.compactMap { part in - guard case let .text(text) = part else { + switch part { + case let textPart as TextPart: + return textPart.text + default: return nil } - return text } guard textValues.count > 0 else { VertexLog.error( @@ -65,15 +67,17 @@ public struct GenerateContentResponse: Sendable { } /// Returns function calls found in any `Part`s of the first candidate of the response, if any. - public var functionCalls: [FunctionCall] { + public var functionCalls: [FunctionCallPart] { guard let candidate = candidates.first else { return [] } return candidate.content.parts.compactMap { part in - guard case let .functionCall(functionCall) = part else { + switch part { + case let functionCallPart as FunctionCallPart: + return functionCallPart + default: return nil } - return functionCall } } diff --git a/FirebaseVertexAI/Sources/GenerativeModel.swift b/FirebaseVertexAI/Sources/GenerativeModel.swift index dc069d88d03..0ac58732e11 100644 --- a/FirebaseVertexAI/Sources/GenerativeModel.swift +++ b/FirebaseVertexAI/Sources/GenerativeModel.swift @@ -106,12 +106,11 @@ public final class GenerativeModel { /// [few-shot](https://developers.google.com/machine-learning/glossary/generative#few-shot-prompting) /// prompts, see `generateContent(_ content: @autoclosure () throws -> [ModelContent])`. /// - /// - Parameter content: The input(s) given to the model as a prompt (see - /// ``ThrowingPartsRepresentable`` + /// - Parameter content: The input(s) given to the model as a prompt (see ``PartsRepresentable`` /// for conforming types). /// - Returns: The content generated by the model. /// - Throws: A ``GenerateContentError`` if the request failed. - public func generateContent(_ parts: any ThrowingPartsRepresentable...) + public func generateContent(_ parts: any PartsRepresentable...) async throws -> GenerateContentResponse { return try await generateContent([ModelContent(parts: parts)]) } @@ -121,24 +120,22 @@ public final class GenerativeModel { /// - Parameter content: The input(s) given to the model as a prompt. /// - Returns: The generated content response from the model. /// - Throws: A ``GenerateContentError`` if the request failed. - public func generateContent(_ content: @autoclosure () throws -> [ModelContent]) async throws + public func generateContent(_ content: [ModelContent]) async throws -> GenerateContentResponse { + try content.throwIfError() let response: GenerateContentResponse + let generateContentRequest = GenerateContentRequest(model: modelResourceName, + contents: content, + generationConfig: generationConfig, + safetySettings: safetySettings, + tools: tools, + toolConfig: toolConfig, + systemInstruction: systemInstruction, + isStreaming: false, + options: requestOptions) do { - let generateContentRequest = try GenerateContentRequest(model: modelResourceName, - contents: content(), - generationConfig: generationConfig, - safetySettings: safetySettings, - tools: tools, - toolConfig: toolConfig, - systemInstruction: systemInstruction, - isStreaming: false, - options: requestOptions) response = try await generativeAIService.loadRequest(request: generateContentRequest) } catch { - if let imageError = error as? ImageConversionError { - throw GenerateContentError.promptImageContentError(underlying: imageError) - } throw GenerativeModel.generateContentError(from: error) } @@ -166,12 +163,11 @@ public final class GenerativeModel { /// prompts, see `generateContentStream(_ content: @autoclosure () throws -> [ModelContent])`. /// /// - Parameter content: The input(s) given to the model as a prompt (see - /// ``ThrowingPartsRepresentable`` - /// for conforming types). + /// ``PartsRepresentable`` for conforming types). /// - Returns: A stream wrapping content generated by the model or a ``GenerateContentError`` /// error if an error occurred. @available(macOS 12.0, *) - public func generateContentStream(_ parts: any ThrowingPartsRepresentable...) throws + public func generateContentStream(_ parts: any PartsRepresentable...) throws -> AsyncThrowingStream { return try generateContentStream([ModelContent(parts: parts)]) } @@ -182,20 +178,11 @@ public final class GenerativeModel { /// - Returns: A stream wrapping content generated by the model or a ``GenerateContentError`` /// error if an error occurred. @available(macOS 12.0, *) - public func generateContentStream(_ content: @autoclosure () throws -> [ModelContent]) throws + public func generateContentStream(_ content: [ModelContent]) throws -> AsyncThrowingStream { - let evaluatedContent: [ModelContent] - do { - evaluatedContent = try content() - } catch let underlying { - if let contentError = underlying as? ImageConversionError { - throw GenerateContentError.promptImageContentError(underlying: contentError) - } - throw GenerateContentError.internalError(underlying: underlying) - } - + try content.throwIfError() let generateContentRequest = GenerateContentRequest(model: modelResourceName, - contents: evaluatedContent, + contents: content, generationConfig: generationConfig, safetySettings: safetySettings, tools: tools, @@ -249,13 +236,12 @@ public final class GenerativeModel { /// [few-shot](https://developers.google.com/machine-learning/glossary/generative#few-shot-prompting) /// input, see `countTokens(_ content: @autoclosure () throws -> [ModelContent])`. /// - /// - Parameter content: The input(s) given to the model as a prompt (see - /// ``ThrowingPartsRepresentable`` + /// - Parameter content: The input(s) given to the model as a prompt (see ``PartsRepresentable`` /// for conforming types). /// - Returns: The results of running the model's tokenizer on the input; contains /// ``CountTokensResponse/totalTokens``. /// - Throws: A ``CountTokensError`` if the tokenization request failed. - public func countTokens(_ parts: any ThrowingPartsRepresentable...) async throws + public func countTokens(_ parts: any PartsRepresentable...) async throws -> CountTokensResponse { return try await countTokens([ModelContent(parts: parts)]) } @@ -267,11 +253,11 @@ public final class GenerativeModel { /// ``CountTokensResponse/totalTokens``. /// - Throws: A ``CountTokensError`` if the tokenization request failed or the input content was /// invalid. - public func countTokens(_ content: @autoclosure () throws -> [ModelContent]) async throws + public func countTokens(_ content: [ModelContent]) async throws -> CountTokensResponse { - let countTokensRequest = try CountTokensRequest( + let countTokensRequest = CountTokensRequest( model: modelResourceName, - contents: content(), + contents: content, systemInstruction: systemInstruction, tools: tools, generationConfig: generationConfig, diff --git a/FirebaseVertexAI/Sources/ModelContent.swift b/FirebaseVertexAI/Sources/ModelContent.swift index f5699a600fb..d215dd4ba15 100644 --- a/FirebaseVertexAI/Sources/ModelContent.swift +++ b/FirebaseVertexAI/Sources/ModelContent.swift @@ -14,60 +14,34 @@ import Foundation +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +extension [ModelContent] { + // TODO: Rename and refactor this. + func throwIfError() throws { + for content in self { + for part in content.parts { + switch part { + case let errorPart as ErrorPart: + throw errorPart.error + default: + break + } + } + } + } +} + /// A type describing data in media formats interpretable by an AI model. Each generative AI /// request or response contains an `Array` of ``ModelContent``s, and each ``ModelContent`` value /// may comprise multiple heterogeneous ``ModelContent/Part``s. @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) public struct ModelContent: Equatable, Sendable { - /// A discrete piece of data in a media format interpretable by an AI model. Within a single value - /// of ``Part``, different data types may not mix. - public enum Part: Equatable, Sendable { - /// Text value. + enum InternalPart: Equatable, Sendable { case text(String) - - /// Data with a specified media type. Not all media types may be supported by the AI model. case inlineData(mimetype: String, Data) - - /// File data stored in Cloud Storage for Firebase, referenced by URI. - /// - /// > Note: Supported media types depends on the model; see [media requirements - /// > ](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/send-multimodal-prompts#media_requirements) - /// > for details. - /// - /// - Parameters: - /// - mimetype: The IANA standard MIME type of the uploaded file, for example, `"image/jpeg"` - /// or `"video/mp4"`; see [media requirements - /// ](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/send-multimodal-prompts#media_requirements) - /// for supported values. - /// - uri: The `"gs://"`-prefixed URI of the file in Cloud Storage for Firebase, for example, - /// `"gs://bucket-name/path/image.jpg"`. case fileData(mimetype: String, uri: String) - - /// A predicted function call returned from the model. case functionCall(FunctionCall) - - /// A response to a function call. case functionResponse(FunctionResponse) - - // MARK: Convenience Initializers - - /// Convenience function for populating a Part with JPEG data. - public static func jpeg(_ data: Data) -> Self { - return .inlineData(mimetype: "image/jpeg", data) - } - - /// Convenience function for populating a Part with PNG data. - public static func png(_ data: Data) -> Self { - return .inlineData(mimetype: "image/png", data) - } - - /// Returns the text contents of this ``Part``, if it contains text. - public var text: String? { - switch self { - case let .text(contents): return contents - default: return nil - } - } } /// The role of the entity creating the ``ModelContent``. For user-generated client requests, @@ -75,39 +49,88 @@ public struct ModelContent: Equatable, Sendable { public let role: String? /// The data parts comprising this ``ModelContent`` value. - public let parts: [Part] - - /// Creates a new value from any data or `Array` of data interpretable as a - /// ``Part``. See ``ThrowingPartsRepresentable`` for types that can be interpreted as `Part`s. - public init(role: String? = "user", parts: some ThrowingPartsRepresentable) throws { - self.role = role - try self.parts = parts.tryPartsValue() + public var parts: [any Part] { + var convertedParts = [any Part]() + for part in internalParts { + switch part { + case let .text(text): + convertedParts.append(TextPart(text)) + case let .inlineData(mimetype, data): + convertedParts.append(InlineDataPart(data: data, mimeType: mimetype)) + case let .fileData(mimetype, uri): + convertedParts.append(FileDataPart(uri: uri, mimeType: mimetype)) + case let .functionCall(functionCall): + convertedParts.append(FunctionCallPart(functionCall)) + case let .functionResponse(functionResponse): + convertedParts.append(FunctionResponsePart(functionResponse)) + } + } + return convertedParts } + // TODO: Refactor this + let internalParts: [InternalPart] + /// Creates a new value from any data or `Array` of data interpretable as a - /// ``Part``. See ``ThrowingPartsRepresentable`` for types that can be interpreted as `Part`s. + /// ``Part``. See ``PartsRepresentable`` for types that can be interpreted as `Part`s. public init(role: String? = "user", parts: some PartsRepresentable) { self.role = role - self.parts = parts.partsValue + var convertedParts = [InternalPart]() + for part in parts.partsValue { + switch part { + case let textPart as TextPart: + convertedParts.append(.text(textPart.text)) + case let inlineDataPart as InlineDataPart: + let inlineData = inlineDataPart.inlineData + convertedParts.append(.inlineData(mimetype: inlineData.mimeType, inlineData.data)) + case let fileDataPart as FileDataPart: + let fileData = fileDataPart.fileData + convertedParts.append(.fileData(mimetype: fileData.mimeType, uri: fileData.fileURI)) + case let functionCallPart as FunctionCallPart: + convertedParts.append(.functionCall(functionCallPart.functionCall)) + case let functionResponsePart as FunctionResponsePart: + convertedParts.append(.functionResponse(functionResponsePart.functionResponse)) + default: + fatalError() + } + } + internalParts = convertedParts } /// Creates a new value from a list of ``Part``s. - public init(role: String? = "user", parts: [Part]) { + public init(role: String? = "user", parts: [any Part]) { self.role = role - self.parts = parts + var convertedParts = [InternalPart]() + for part in parts { + switch part { + case let textPart as TextPart: + convertedParts.append(.text(textPart.text)) + case let inlineDataPart as InlineDataPart: + let inlineData = inlineDataPart.inlineData + convertedParts.append(.inlineData(mimetype: inlineData.mimeType, inlineData.data)) + case let fileDataPart as FileDataPart: + let fileData = fileDataPart.fileData + convertedParts.append(.fileData(mimetype: fileData.mimeType, uri: fileData.fileURI)) + case let functionCallPart as FunctionCallPart: + convertedParts.append(.functionCall(functionCallPart.functionCall)) + case let functionResponsePart as FunctionResponsePart: + convertedParts.append(.functionResponse(functionResponsePart.functionResponse)) + default: + fatalError() + } + } + internalParts = convertedParts } - /// Creates a new value from any data interpretable as a ``Part``. See - /// ``ThrowingPartsRepresentable`` - /// for types that can be interpreted as `Part`s. - public init(role: String? = "user", _ parts: any ThrowingPartsRepresentable...) throws { - let content = try parts.flatMap { try $0.tryPartsValue() } + /// Creates a new value from any data interpretable as a ``Part``. + /// See ``PartsRepresentable`` for types that can be interpreted as `Part`s. + public init(role: String? = "user", _ parts: any PartsRepresentable...) { + let content = parts.flatMap { $0.partsValue } self.init(role: role, parts: content) } - /// Creates a new value from any data interpretable as a ``Part``. See - /// ``ThrowingPartsRepresentable`` - /// for types that can be interpreted as `Part`s. + /// Creates a new value from any data interpretable as a ``Part``. + /// See ``PartsRepresentable``for types that can be interpreted as `Part`s. public init(role: String? = "user", _ parts: [PartsRepresentable]) { let content = parts.flatMap { $0.partsValue } self.init(role: role, parts: content) @@ -117,10 +140,15 @@ public struct ModelContent: Equatable, Sendable { // MARK: Codable Conformances @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) -extension ModelContent: Codable {} +extension ModelContent: Codable { + enum CodingKeys: String, CodingKey { + case role + case internalParts = "parts" + } +} @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) -extension ModelContent.Part: Codable { +extension ModelContent.InternalPart: Codable { enum CodingKeys: String, CodingKey { case text case inlineData @@ -129,35 +157,15 @@ extension ModelContent.Part: Codable { case functionResponse } - enum InlineDataKeys: String, CodingKey { - case mimeType = "mime_type" - case bytes = "data" - } - - enum FileDataKeys: String, CodingKey { - case mimeType = "mime_type" - case uri = "file_uri" - } - public func encode(to encoder: Encoder) throws { var container = encoder.container(keyedBy: CodingKeys.self) switch self { - case let .text(a0): - try container.encode(a0, forKey: .text) + case let .text(text): + try container.encode(text, forKey: .text) case let .inlineData(mimetype, bytes): - var inlineDataContainer = container.nestedContainer( - keyedBy: InlineDataKeys.self, - forKey: .inlineData - ) - try inlineDataContainer.encode(mimetype, forKey: .mimeType) - try inlineDataContainer.encode(bytes, forKey: .bytes) + try container.encode(InlineData(data: bytes, mimeType: mimetype), forKey: .inlineData) case let .fileData(mimetype: mimetype, url): - var fileDataContainer = container.nestedContainer( - keyedBy: FileDataKeys.self, - forKey: .fileData - ) - try fileDataContainer.encode(mimetype, forKey: .mimeType) - try fileDataContainer.encode(url, forKey: .uri) + try container.encode(FileData(fileURI: url, mimeType: mimetype), forKey: .fileData) case let .functionCall(functionCall): try container.encode(functionCall, forKey: .functionCall) case let .functionResponse(functionResponse): @@ -170,13 +178,11 @@ extension ModelContent.Part: Codable { if values.contains(.text) { self = try .text(values.decode(String.self, forKey: .text)) } else if values.contains(.inlineData) { - let dataContainer = try values.nestedContainer( - keyedBy: InlineDataKeys.self, - forKey: .inlineData - ) - let mimetype = try dataContainer.decode(String.self, forKey: .mimeType) - let bytes = try dataContainer.decode(Data.self, forKey: .bytes) - self = .inlineData(mimetype: mimetype, bytes) + let inlineData = try values.decode(InlineData.self, forKey: .inlineData) + self = .inlineData(mimetype: inlineData.mimeType, inlineData.data) + } else if values.contains(.fileData) { + let fileData = try values.decode(FileData.self, forKey: .fileData) + self = .fileData(mimetype: fileData.mimeType, uri: fileData.fileURI) } else if values.contains(.functionCall) { self = try .functionCall(values.decode(FunctionCall.self, forKey: .functionCall)) } else if values.contains(.functionResponse) { @@ -185,7 +191,7 @@ extension ModelContent.Part: Codable { let unexpectedKeys = values.allKeys.map { $0.stringValue } throw DecodingError.dataCorrupted(DecodingError.Context( codingPath: values.codingPath, - debugDescription: "Unexpected ModelContent.Part type(s): \(unexpectedKeys)" + debugDescription: "Unexpected Part type(s): \(unexpectedKeys)" )) } } diff --git a/FirebaseVertexAI/Sources/PartsRepresentable+Image.swift b/FirebaseVertexAI/Sources/PartsRepresentable+Image.swift index 6b2cc977889..24d11be2c46 100644 --- a/FirebaseVertexAI/Sources/PartsRepresentable+Image.swift +++ b/FirebaseVertexAI/Sources/PartsRepresentable+Image.swift @@ -36,31 +36,31 @@ enum ImageConversionError: Error { } #if canImport(UIKit) - /// Enables images to be representable as ``ThrowingPartsRepresentable``. + /// Enables images to be representable as ``PartsRepresentable``. @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) - extension UIImage: ThrowingPartsRepresentable { - public func tryPartsValue() throws -> [ModelContent.Part] { + extension UIImage: PartsRepresentable { + public var partsValue: [any Part] { guard let data = jpegData(compressionQuality: imageCompressionQuality) else { - throw ImageConversionError.couldNotConvertToJPEG + return [ErrorPart(ImageConversionError.couldNotConvertToJPEG)] } - return [ModelContent.Part.inlineData(mimetype: "image/jpeg", data)] + return [InlineDataPart(data: data, mimeType: "image/jpeg")] } } #elseif canImport(AppKit) - /// Enables images to be representable as ``ThrowingPartsRepresentable``. + /// Enables images to be representable as ``PartsRepresentable``. @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) - extension NSImage: ThrowingPartsRepresentable { - public func tryPartsValue() throws -> [ModelContent.Part] { + extension NSImage: PartsRepresentable { + public var partsValue: [any Part] { guard let cgImage = cgImage(forProposedRect: nil, context: nil, hints: nil) else { - throw ImageConversionError.invalidUnderlyingImage + return [ErrorPart(ImageConversionError.invalidUnderlyingImage)] } let bmp = NSBitmapImageRep(cgImage: cgImage) guard let data = bmp.representation(using: .jpeg, properties: [.compressionFactor: 0.8]) else { - throw ImageConversionError.couldNotConvertToJPEG + return [ErrorPart(ImageConversionError.couldNotConvertToJPEG)] } - return [ModelContent.Part.inlineData(mimetype: "image/jpeg", data)] + return [InlineDataPart(data: data, mimeType: "image/jpeg")] } } #endif @@ -68,22 +68,22 @@ enum ImageConversionError: Error { #if !os(watchOS) // This code does not build on watchOS. /// Enables `CGImages` to be representable as model content. @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, *) - extension CGImage: ThrowingPartsRepresentable { - public func tryPartsValue() throws -> [ModelContent.Part] { + extension CGImage: PartsRepresentable { + public var partsValue: [any Part] { let output = NSMutableData() guard let imageDestination = CGImageDestinationCreateWithData( output, UTType.jpeg.identifier as CFString, 1, nil ) else { - throw ImageConversionError.couldNotAllocateDestination + return [ErrorPart(ImageConversionError.couldNotAllocateDestination)] } CGImageDestinationAddImage(imageDestination, self, nil) CGImageDestinationSetProperties(imageDestination, [ kCGImageDestinationLossyCompressionQuality: imageCompressionQuality, ] as CFDictionary) if CGImageDestinationFinalize(imageDestination) { - return [.inlineData(mimetype: "image/jpeg", output as Data)] + return [InlineDataPart(data: output as Data, mimeType: "image/jpeg")] } - throw ImageConversionError.couldNotConvertToJPEG + return [ErrorPart(ImageConversionError.couldNotConvertToJPEG)] } } #endif // !os(watchOS) @@ -91,8 +91,8 @@ enum ImageConversionError: Error { #if canImport(CoreImage) /// Enables `CIImages` to be representable as model content. @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, *) - extension CIImage: ThrowingPartsRepresentable { - public func tryPartsValue() throws -> [ModelContent.Part] { + extension CIImage: PartsRepresentable { + public var partsValue: [any Part] { let context = CIContext() let jpegData = (colorSpace ?? CGColorSpace(name: CGColorSpace.sRGB)) .flatMap { @@ -102,9 +102,9 @@ enum ImageConversionError: Error { context.jpegRepresentation(of: self, colorSpace: $0, options: [:]) } if let jpegData = jpegData { - return [.inlineData(mimetype: "image/jpeg", jpegData)] + return [InlineDataPart(data: jpegData, mimeType: "image/jpeg")] } - throw ImageConversionError.couldNotConvertToJPEG + return [ErrorPart(ImageConversionError.couldNotConvertToJPEG)] } } #endif // canImport(CoreImage) diff --git a/FirebaseVertexAI/Sources/PartsRepresentable.swift b/FirebaseVertexAI/Sources/PartsRepresentable.swift index 7b9d9524b67..6ef63d3f182 100644 --- a/FirebaseVertexAI/Sources/PartsRepresentable.swift +++ b/FirebaseVertexAI/Sources/PartsRepresentable.swift @@ -15,52 +15,33 @@ import Foundation /// A protocol describing any data that could be serialized to model-interpretable input data, -/// where the serialization process might fail with an error. +/// where the serialization process cannot fail with an error. @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) -public protocol ThrowingPartsRepresentable { - func tryPartsValue() throws -> [ModelContent.Part] +public protocol PartsRepresentable { + var partsValue: [any Part] { get } } -/// A protocol describing any data that could be serialized to model-interpretable input data, -/// where the serialization process cannot fail with an error. For a failable conversion, see -/// ``ThrowingPartsRepresentable`` -@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) -public protocol PartsRepresentable: ThrowingPartsRepresentable { - var partsValue: [ModelContent.Part] { get } -} - -@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) -public extension PartsRepresentable { - func tryPartsValue() throws -> [ModelContent.Part] { - return partsValue - } -} - -/// Enables a ``ModelContent.Part`` to be passed in as ``ThrowingPartsRepresentable``. +/// Enables a ``Part`` to be used as a ``PartsRepresentable``. @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) -extension ModelContent.Part: ThrowingPartsRepresentable { - public typealias ErrorType = Never - public func tryPartsValue() throws -> [ModelContent.Part] { +public extension Part { + var partsValue: [any Part] { return [self] } } -/// Enable an `Array` of ``ThrowingPartsRepresentable`` values to be passed in as a single -/// ``ThrowingPartsRepresentable``. +/// Enable an `Array` of ``PartsRepresentable`` values to be passed in as a single +/// ``PartsRepresentable``. @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) -extension [ThrowingPartsRepresentable]: ThrowingPartsRepresentable { - public func tryPartsValue() throws -> [ModelContent.Part] { - return try compactMap { element in - try element.tryPartsValue() - } - .flatMap { $0 } +extension [PartsRepresentable]: PartsRepresentable { + public var partsValue: [any Part] { + return flatMap { $0.partsValue } } } -/// Enables a `String` to be passed in as ``ThrowingPartsRepresentable``. +/// Enables a `String` to be passed in as ``PartsRepresentable``. @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) extension String: PartsRepresentable { - public var partsValue: [ModelContent.Part] { - return [.text(self)] + public var partsValue: [any Part] { + return [TextPart(self)] } } diff --git a/FirebaseVertexAI/Sources/Types/Internal/InternalPart.swift b/FirebaseVertexAI/Sources/Types/Internal/InternalPart.swift new file mode 100644 index 00000000000..8a62ae4fdd9 --- /dev/null +++ b/FirebaseVertexAI/Sources/Types/Internal/InternalPart.swift @@ -0,0 +1,103 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation + +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +struct InlineData: Codable, Equatable, Sendable { + let mimeType: String + let data: Data + + init(data: Data, mimeType: String) { + self.data = data + self.mimeType = mimeType + } +} + +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +struct FileData: Codable, Equatable, Sendable { + let fileURI: String + let mimeType: String + + init(fileURI: String, mimeType: String) { + self.fileURI = fileURI + self.mimeType = mimeType + } +} + +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +struct FunctionCall: Equatable, Sendable { + let name: String + let args: JSONObject + + init(name: String, args: JSONObject) { + self.name = name + self.args = args + } +} + +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +struct FunctionResponse: Codable, Equatable, Sendable { + let name: String + let response: JSONObject + + init(name: String, response: JSONObject) { + self.name = name + self.response = response + } +} + +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +struct ErrorPart: Part, Error { + let error: Error + + init(_ error: Error) { + self.error = error + } +} + +// MARK: - Codable Conformances + +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +extension FunctionCall: Codable { + init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + name = try container.decode(String.self, forKey: .name) + if let args = try container.decodeIfPresent(JSONObject.self, forKey: .args) { + self.args = args + } else { + args = JSONObject() + } + } +} + +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +extension ErrorPart: Codable { + init(from decoder: any Decoder) throws { + fatalError("Decoding an ErrorPart is not supported.") + } + + func encode(to encoder: any Encoder) throws { + fatalError("Encoding an ErrorPart is not supported.") + } +} + +// MARK: - Equatable Conformances + +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +extension ErrorPart: Equatable { + static func == (lhs: ErrorPart, rhs: ErrorPart) -> Bool { + fatalError("Comparing ErrorParts for equality is not supported.") + } +} diff --git a/FirebaseVertexAI/Sources/Types/Public/Part.swift b/FirebaseVertexAI/Sources/Types/Public/Part.swift new file mode 100644 index 00000000000..1eba33ae018 --- /dev/null +++ b/FirebaseVertexAI/Sources/Types/Public/Part.swift @@ -0,0 +1,134 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation + +/// A discrete piece of data in a media format interpretable by an AI model. +/// +/// Within a single value of ``Part``, different data types may not mix. +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +public protocol Part: PartsRepresentable, Codable, Sendable, Equatable {} + +/// A text part containing a string value. +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +public struct TextPart: Part { + /// Text value. + public let text: String + + public init(_ text: String) { + self.text = text + } +} + +/// Data with a specified media type. +/// +/// > Note: Not all media types may be supported by the AI model. +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +public struct InlineDataPart: Part { + let inlineData: InlineData + + public var data: Data { inlineData.data } + public var mimeType: String { inlineData.mimeType } + + public init(data: Data, mimeType: String) { + self.init(InlineData(data: data, mimeType: mimeType)) + } + + init(_ inlineData: InlineData) { + self.inlineData = inlineData + } +} + +/// File data stored in Cloud Storage for Firebase, referenced by URI. +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +public struct FileDataPart: Part { + let fileData: FileData + + public var uri: String { fileData.fileURI } + public var mimeType: String { fileData.mimeType } + + /// Constructs a new file data part. + /// + /// - Parameters: + /// - uri: The `"gs://"`-prefixed URI of the file in Cloud Storage for Firebase, for example, + /// `"gs://bucket-name/path/image.jpg"`. + /// - mimeType: The IANA standard MIME type of the uploaded file, for example, `"image/jpeg"` + /// or `"video/mp4"`; see [media requirements + /// ](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/send-multimodal-prompts#media_requirements) + /// for supported values. + public init(uri: String, mimeType: String) { + self.init(FileData(fileURI: uri, mimeType: mimeType)) + } + + init(_ fileData: FileData) { + self.fileData = fileData + } +} + +/// A predicted function call returned from the model. +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +public struct FunctionCallPart: Part { + let functionCall: FunctionCall + + /// The name of the function to call. + public var name: String { functionCall.name } + + /// The function parameters and values. + public var args: JSONObject { functionCall.args } + + /// Constructs a new function call part. + /// + /// > Note: A `FunctionCallPart` is typically received from the model, rather than created + /// manually. + /// + /// - Parameters: + /// - name: The name of the function to call. + /// - args: The function parameters and values. + public init(name: String, args: JSONObject) { + self.init(FunctionCall(name: name, args: args)) + } + + init(_ functionCall: FunctionCall) { + self.functionCall = functionCall + } +} + +/// Result output from a ``FunctionCall``. +/// +/// Contains a string representing the `FunctionDeclaration.name` and a structured JSON object +/// containing any output from the function is used as context to the model. This should contain the +/// result of a ``FunctionCall`` made based on model prediction. +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +public struct FunctionResponsePart: Part { + let functionResponse: FunctionResponse + + /// The name of the function that was called. + public var name: String { functionResponse.name } + + /// The function's response or return value. + public var response: JSONObject { functionResponse.response } + + /// Constructs a new `FunctionResponse`. + /// + /// - Parameters: + /// - name: The name of the function that was called. + /// - response: The function's response. + public init(name: String, response: JSONObject) { + self.init(FunctionResponse(name: name, response: response)) + } + + init(_ functionResponse: FunctionResponse) { + self.functionResponse = functionResponse + } +} diff --git a/FirebaseVertexAI/Tests/Integration/IntegrationTests.swift b/FirebaseVertexAI/Tests/Integration/IntegrationTests.swift index 8eddf79a648..d2789c574df 100644 --- a/FirebaseVertexAI/Tests/Integration/IntegrationTests.swift +++ b/FirebaseVertexAI/Tests/Integration/IntegrationTests.swift @@ -96,12 +96,12 @@ final class IntegrationTests: XCTestCase { } func testCountTokens_image_fileData() async throws { - let fileData = ModelContent(parts: [.fileData( - mimetype: "image/jpeg", - uri: "gs://ios-opensource-samples.appspot.com/ios/public/blank.jpg" - )]) + let fileData = FileDataPart( + uri: "gs://ios-opensource-samples.appspot.com/ios/public/blank.jpg", + mimeType: "image/jpeg" + ) - let response = try await model.countTokens([fileData]) + let response = try await model.countTokens(fileData) XCTAssertEqual(response.totalTokens, 266) XCTAssertEqual(response.totalBillableCharacters, 35) @@ -118,13 +118,13 @@ final class IntegrationTests: XCTestCase { tools: [Tool(functionDeclarations: [sumDeclaration])] ) let prompt = "What is 10 + 32?" - let sumCall = FunctionCall(name: "sum", args: ["x": .number(10), "y": .number(32)]) - let sumResponse = FunctionResponse(name: "sum", response: ["result": .number(42)]) + let sumCall = FunctionCallPart(name: "sum", args: ["x": .number(10), "y": .number(32)]) + let sumResponse = FunctionResponsePart(name: "sum", response: ["result": .number(42)]) let response = try await model.countTokens([ - ModelContent(role: "user", parts: [.text(prompt)]), - ModelContent(role: "model", parts: [.functionCall(sumCall)]), - ModelContent(role: "function", parts: [.functionResponse(sumResponse)]), + ModelContent(role: "user", parts: prompt), + ModelContent(role: "model", parts: sumCall), + ModelContent(role: "function", parts: sumResponse), ]) XCTAssertEqual(response.totalTokens, 24) diff --git a/FirebaseVertexAI/Tests/Unit/ChatTests.swift b/FirebaseVertexAI/Tests/Unit/ChatTests.swift index 614559fe011..95ce8e7e43d 100644 --- a/FirebaseVertexAI/Tests/Unit/ChatTests.swift +++ b/FirebaseVertexAI/Tests/Unit/ChatTests.swift @@ -77,11 +77,12 @@ final class ChatTests: XCTestCase { } XCTAssertEqual(chat.history.count, 2) - XCTAssertEqual(chat.history[0].parts[0].text, input) + let part = try XCTUnwrap(chat.history[0].parts[0]) + let textPart = try XCTUnwrap(part as? TextPart) + XCTAssertEqual(textPart.text, input) let finalText = "1 2 3 4 5 6 7 8" let assembledExpectation = ModelContent(role: "model", parts: finalText) - XCTAssertEqual(chat.history[0].parts[0].text, input) XCTAssertEqual(chat.history[1], assembledExpectation) } } diff --git a/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift b/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift index c5e8332d2b8..86d3e7f9c11 100644 --- a/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift +++ b/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift @@ -72,7 +72,7 @@ final class GenerativeModelTests: XCTestCase { XCTAssertEqual(candidate.safetyRatings.sorted(), safetyRatingsNegligible) XCTAssertEqual(candidate.content.parts.count, 1) let part = try XCTUnwrap(candidate.content.parts.first) - let partText = try XCTUnwrap(part.text) + let partText = try XCTUnwrap(part as? TextPart).text XCTAssertTrue(partText.hasPrefix("1. **Use Freshly Ground Coffee**:")) XCTAssertEqual(response.text, partText) XCTAssertEqual(response.functionCalls, []) @@ -94,8 +94,9 @@ final class GenerativeModelTests: XCTestCase { XCTAssertEqual(candidate.safetyRatings.sorted(), safetyRatingsNegligible) XCTAssertEqual(candidate.content.parts.count, 1) let part = try XCTUnwrap(candidate.content.parts.first) - XCTAssertEqual(part.text, "Mountain View, California") - XCTAssertEqual(response.text, part.text) + let textPart = try XCTUnwrap(part as? TextPart) + XCTAssertEqual(textPart.text, "Mountain View, California") + XCTAssertEqual(response.text, textPart.text) XCTAssertEqual(response.functionCalls, []) } @@ -150,9 +151,9 @@ final class GenerativeModelTests: XCTestCase { XCTAssertEqual(candidate.safetyRatings.sorted(), safetyRatingsNegligible) XCTAssertEqual(candidate.content.parts.count, 1) let part = try XCTUnwrap(candidate.content.parts.first) - let partText = try XCTUnwrap(part.text) - XCTAssertTrue(partText.hasPrefix("Google")) - XCTAssertEqual(response.text, part.text) + let textPart = try XCTUnwrap(part as? TextPart) + XCTAssertTrue(textPart.text.hasPrefix("Google")) + XCTAssertEqual(response.text, textPart.text) let promptFeedback = try XCTUnwrap(response.promptFeedback) XCTAssertNil(promptFeedback.blockReason) XCTAssertEqual(promptFeedback.safetyRatings.sorted(), safetyRatingsNegligible) @@ -211,7 +212,7 @@ final class GenerativeModelTests: XCTestCase { let candidate = try XCTUnwrap(response.candidates.first) XCTAssertEqual(candidate.content.parts.count, 1) let part = try XCTUnwrap(candidate.content.parts.first) - guard case let .functionCall(functionCall) = part else { + guard let functionCall = part as? FunctionCallPart else { XCTFail("Part is not a FunctionCall.") return } @@ -233,7 +234,7 @@ final class GenerativeModelTests: XCTestCase { let candidate = try XCTUnwrap(response.candidates.first) XCTAssertEqual(candidate.content.parts.count, 1) let part = try XCTUnwrap(candidate.content.parts.first) - guard case let .functionCall(functionCall) = part else { + guard let functionCall = part as? FunctionCallPart else { XCTFail("Part is not a FunctionCall.") return } @@ -255,7 +256,7 @@ final class GenerativeModelTests: XCTestCase { let candidate = try XCTUnwrap(response.candidates.first) XCTAssertEqual(candidate.content.parts.count, 1) let part = try XCTUnwrap(candidate.content.parts.first) - guard case let .functionCall(functionCall) = part else { + guard let functionCall = part as? FunctionCallPart else { XCTFail("Part is not a FunctionCall.") return } @@ -1282,10 +1283,7 @@ final class GenerativeModelTests: XCTestCase { withExtension: "json" ) - let response = try await model.countTokens(ModelContent.Part.inlineData( - mimetype: "image/jpeg", - Data() - )) + let response = try await model.countTokens(InlineDataPart(data: Data(), mimeType: "image/jpeg")) XCTAssertEqual(response.totalTokens, 258) XCTAssertNil(response.totalBillableCharacters) diff --git a/FirebaseVertexAI/Tests/Unit/ModelContentTests.swift b/FirebaseVertexAI/Tests/Unit/ModelContentTests.swift deleted file mode 100644 index 67175af739b..00000000000 --- a/FirebaseVertexAI/Tests/Unit/ModelContentTests.swift +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2024 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -import Foundation -import XCTest - -@testable import FirebaseVertexAI - -@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) -final class ModelContentTests: XCTestCase { - let decoder = JSONDecoder() - let encoder = JSONEncoder() - - override func setUp() { - encoder.outputFormatting = .init( - arrayLiteral: .prettyPrinted, .sortedKeys, .withoutEscapingSlashes - ) - } - - // MARK: - ModelContent.Part Decoding - - func testDecodeFunctionResponsePart() throws { - let functionName = "test-function-name" - let resultParameter = "test-result-parameter" - let resultValue = "test-result-value" - let json = """ - { - "functionResponse" : { - "name" : "\(functionName)", - "response" : { - "\(resultParameter)" : "\(resultValue)" - } - } - } - """ - let jsonData = try XCTUnwrap(json.data(using: .utf8)) - - let part = try decoder.decode(ModelContent.Part.self, from: jsonData) - - guard case let .functionResponse(functionResponse) = part else { - XCTFail("Decoded Part was not a FunctionResponse.") - return - } - XCTAssertEqual(functionResponse.name, functionName) - XCTAssertEqual(functionResponse.response, [resultParameter: .string(resultValue)]) - } - - // MARK: - ModelContent.Part Encoding - - func testEncodeFileDataPart() throws { - let mimeType = "image/jpeg" - let fileURI = "gs://test-bucket/image.jpg" - let fileDataPart = ModelContent.Part.fileData(mimetype: mimeType, uri: fileURI) - - let jsonData = try encoder.encode(fileDataPart) - - let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8)) - XCTAssertEqual(json, """ - { - "fileData" : { - "file_uri" : "\(fileURI)", - "mime_type" : "\(mimeType)" - } - } - """) - } -} diff --git a/FirebaseVertexAI/Tests/Unit/PartTests.swift b/FirebaseVertexAI/Tests/Unit/PartTests.swift new file mode 100644 index 00000000000..35c3b441d9f --- /dev/null +++ b/FirebaseVertexAI/Tests/Unit/PartTests.swift @@ -0,0 +1,157 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation +import XCTest + +@testable import FirebaseVertexAI + +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +final class PartTests: XCTestCase { + let decoder = JSONDecoder() + let encoder = JSONEncoder() + + override func setUp() { + encoder.outputFormatting = .init( + arrayLiteral: .prettyPrinted, .sortedKeys, .withoutEscapingSlashes + ) + } + + // MARK: - Part Decoding + + func testDecodeTextPart() throws { + let expectedText = "Hello, world!" + let json = """ + { + "text" : "\(expectedText)" + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + let part = try decoder.decode(TextPart.self, from: jsonData) + + XCTAssertEqual(part.text, expectedText) + } + + func testDecodeInlineDataPart() throws { + let imageBase64 = try PartTests.blueSquareImage() + let mimeType = "image/png" + let json = """ + { + "inlineData" : { + "data" : "\(imageBase64)", + "mimeType" : "\(mimeType)" + } + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + let part = try decoder.decode(InlineDataPart.self, from: jsonData) + + XCTAssertEqual(part.data, Data(base64Encoded: imageBase64)) + XCTAssertEqual(part.mimeType, mimeType) + } + + func testDecodeFunctionResponsePart() throws { + let functionName = "test-function-name" + let resultParameter = "test-result-parameter" + let resultValue = "test-result-value" + let json = """ + { + "functionResponse" : { + "name" : "\(functionName)", + "response" : { + "\(resultParameter)" : "\(resultValue)" + } + } + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + let part = try decoder.decode(FunctionResponsePart.self, from: jsonData) + + let functionResponse = part.functionResponse + XCTAssertEqual(functionResponse.name, functionName) + XCTAssertEqual(functionResponse.response, [resultParameter: .string(resultValue)]) + } + + // MARK: - Part Encoding + + func testEncodeTextPart() throws { + let expectedText = "Hello, world!" + let textPart = TextPart(expectedText) + + let jsonData = try encoder.encode(textPart) + + let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8)) + XCTAssertEqual(json, """ + { + "text" : "\(expectedText)" + } + """) + } + + func testEncodeInlineDataPart() throws { + let mimeType = "image/png" + let imageBase64 = try PartTests.blueSquareImage() + let imageBase64Data = Data(base64Encoded: imageBase64) + let inlineDataPart = InlineDataPart(data: imageBase64Data!, mimeType: mimeType) + + let jsonData = try encoder.encode(inlineDataPart) + + let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8)) + XCTAssertEqual(json, """ + { + "inlineData" : { + "data" : "\(imageBase64)", + "mimeType" : "\(mimeType)" + } + } + """) + } + + func testEncodeFileDataPart() throws { + let mimeType = "image/jpeg" + let fileURI = "gs://test-bucket/image.jpg" + let fileDataPart = FileDataPart(uri: fileURI, mimeType: mimeType) + + let jsonData = try encoder.encode(fileDataPart) + + let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8)) + XCTAssertEqual(json, """ + { + "fileData" : { + "fileURI" : "\(fileURI)", + "mimeType" : "\(mimeType)" + } + } + """) + } + + // MARK: - Helpers + + private static func bundle() -> Bundle { + #if SWIFT_PACKAGE + return Bundle.module + #else // SWIFT_PACKAGE + return Bundle(for: Self.self) + #endif // SWIFT_PACKAGE + } + + private static func blueSquareImage() throws -> String { + let imageURL = bundle().url(forResource: "blue", withExtension: "png")! + let imageData = try Data(contentsOf: imageURL) + return imageData.base64EncodedString() + } +} diff --git a/FirebaseVertexAI/Tests/Unit/PartsRepresentableTests.swift b/FirebaseVertexAI/Tests/Unit/PartsRepresentableTests.swift index 073f6582721..859b77f58c7 100644 --- a/FirebaseVertexAI/Tests/Unit/PartsRepresentableTests.swift +++ b/FirebaseVertexAI/Tests/Unit/PartsRepresentableTests.swift @@ -44,7 +44,7 @@ final class PartsRepresentableTests: XCTestCase { )! return ctx.makeImage()! } - let modelContent = try image.tryPartsValue() + let modelContent = image.partsValue XCTAssert(modelContent.count > 0, "Expected non-empty model content for CGImage: \(image)") } #endif // !os(watchOS) @@ -53,22 +53,22 @@ final class PartsRepresentableTests: XCTestCase { func testModelContentFromCIImageIsNotEmpty() throws { let image = CIImage(color: CIColor.red) .cropped(to: CGRect(origin: CGPointZero, size: CGSize(width: 16, height: 16))) - let modelContent = try image.tryPartsValue() + let modelContent = image.partsValue XCTAssert(modelContent.count > 0, "Expected non-empty model content for CGImage: \(image)") } func testModelContentFromInvalidCIImageThrows() throws { let image = CIImage.empty() - do { - _ = try image.tryPartsValue() - XCTFail("Expected model content from invalid image to error") - } catch let imageError as ImageConversionError { - guard case .couldNotConvertToJPEG = imageError else { - XCTFail("Expected JPEG conversion error, got \(imageError) instead.") - return - } - } catch { - XCTFail("Got unexpected error type: \(error)") + let modelContent = image.partsValue + let part = try XCTUnwrap(modelContent.first) + let errorPart = try XCTUnwrap(part as? ErrorPart, "Expected ErrorPart.") + let imageError = try XCTUnwrap( + errorPart.error as? ImageConversionError, + "Got unexpected error type: \(errorPart.error)" + ) + guard case .couldNotConvertToJPEG = imageError else { + XCTFail("Expected JPEG conversion error, got \(imageError) instead.") + return } } #endif // canImport(CoreImage) @@ -76,22 +76,22 @@ final class PartsRepresentableTests: XCTestCase { #if canImport(UIKit) && !os(visionOS) // These tests are stalling in CI on visionOS. func testModelContentFromInvalidUIImageThrows() throws { let image = UIImage() - do { - _ = try image.tryPartsValue() - XCTFail("Expected model content from invalid image to error") - } catch let imageError as ImageConversionError { - guard case .couldNotConvertToJPEG = imageError else { - XCTFail("Expected JPEG conversion error, got \(imageError) instead.") - return - } - } catch { - XCTFail("Got unexpected error type: \(error)") + let modelContent = image.partsValue + let part = try XCTUnwrap(modelContent.first) + let errorPart = try XCTUnwrap(part as? ErrorPart, "Expected ErrorPart.") + let imageError = try XCTUnwrap( + errorPart.error as? ImageConversionError, + "Got unexpected error type: \(errorPart.error)" + ) + guard case .couldNotConvertToJPEG = imageError else { + XCTFail("Expected JPEG conversion error, got \(imageError) instead.") + return } } func testModelContentFromUIImageIsNotEmpty() throws { let image = try XCTUnwrap(UIImage(systemName: "star.fill")) - let modelContent = try image.tryPartsValue() + let modelContent = image.partsValue XCTAssert(modelContent.count > 0, "Expected non-empty model content for UIImage: \(image)") } @@ -102,29 +102,23 @@ final class PartsRepresentableTests: XCTestCase { let rep = NSCIImageRep(ciImage: coreImage) let image = NSImage(size: rep.size) image.addRepresentation(rep) - let modelContent = try image.tryPartsValue() + let modelContent = image.partsValue XCTAssert(modelContent.count > 0, "Expected non-empty model content for NSImage: \(image)") } func testModelContentFromInvalidNSImageThrows() throws { let image = NSImage() - do { - _ = try image.tryPartsValue() - } catch { - guard let imageError = (error as? ImageConversionError) else { - XCTFail("Got unexpected error type: \(error)") - return - } - switch imageError { - case .invalidUnderlyingImage: - // Pass - return - case _: - XCTFail("Expected image conversion error, got \(imageError) instead") - return - } + let modelContent = image.partsValue + let part = try XCTUnwrap(modelContent.first) + let errorPart = try XCTUnwrap(part as? ErrorPart, "Expected ErrorPart.") + let imageError = try XCTUnwrap( + errorPart.error as? ImageConversionError, + "Got unexpected error type: \(errorPart.error)" + ) + guard case .invalidUnderlyingImage = imageError else { + XCTFail("Expected invalid underyling image conversion error, got \(imageError) instead.") + return } - XCTFail("Expected model content from invalid image to error") } #endif } diff --git a/FirebaseVertexAI/Tests/Unit/Resources/blue.png b/FirebaseVertexAI/Tests/Unit/Resources/blue.png new file mode 100644 index 00000000000..a0cf28c6edb Binary files /dev/null and b/FirebaseVertexAI/Tests/Unit/Resources/blue.png differ diff --git a/FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift b/FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift index 1c469867f76..c187a997b6c 100644 --- a/FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift +++ b/FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift @@ -33,7 +33,10 @@ final class VertexAIAPITests: XCTestCase { stopSequences: ["..."], responseMIMEType: "text/plain") let filters = [SafetySetting(harmCategory: .dangerousContent, threshold: .blockOnlyHigh)] - let systemInstruction = ModelContent(role: "system", parts: [.text("Talk like a pirate.")]) + let systemInstruction = ModelContent( + role: "system", + parts: TextPart("Talk like a pirate.") + ) // Instantiate Vertex AI SDK - Default App let vertexAI = VertexAI.vertexAI() @@ -72,11 +75,13 @@ final class VertexAIAPITests: XCTestCase { // Full Typed Usage let pngData = Data() // .... - let contents = [ModelContent(role: "user", - parts: [ - .text("Is it a cat?"), - .png(pngData), - ])] + let contents = [ModelContent( + role: "user", + parts: [ + TextPart("Is it a cat?"), + InlineDataPart(data: pngData, mimeType: "image/png"), + ] + )] do { let response = try await genAI.generateContent(contents) @@ -93,13 +98,12 @@ final class VertexAIAPITests: XCTestCase { let _ = try await genAI.generateContent(str, "abc", "def") let _ = try await genAI.generateContent( str, - ModelContent.Part.fileData(mimetype: "image/jpeg", uri: "gs://test-bucket/image.jpg") + FileDataPart(uri: "gs://test-bucket/image.jpg", mimeType: "image/jpeg") ) #if canImport(UIKit) _ = try await genAI.generateContent(UIImage()) _ = try await genAI.generateContent([UIImage()]) - _ = try await genAI - .generateContent([str, UIImage(), ModelContent.Part.text(str)]) + _ = try await genAI.generateContent([str, UIImage(), TextPart(str)]) _ = try await genAI.generateContent(str, UIImage(), "def", UIImage()) _ = try await genAI.generateContent([str, UIImage(), "def", UIImage()]) _ = try await genAI.generateContent([ModelContent("def", UIImage()), @@ -111,51 +115,43 @@ final class VertexAIAPITests: XCTestCase { _ = try await genAI.generateContent([str, NSImage(), "def", NSImage()]) #endif - // ThrowingPartsRepresentable combinations. - let _ = ModelContent(parts: [.text(str)]) - let _ = ModelContent(role: "model", parts: [.text(str)]) + // PartsRepresentable combinations. + let _ = ModelContent(parts: [TextPart(str)]) + let _ = ModelContent(role: "model", parts: [TextPart(str)]) let _ = ModelContent(parts: "Constant String") let _ = ModelContent(parts: str) - // Note: This requires the `try` for some reason. Casting to explicit [PartsRepresentable] also - // doesn't work. - let _ = try ModelContent(parts: [str]) - // Note: without `as [any ThrowingPartsRepresentable]` this will fail to compile with "Cannot + let _ = ModelContent(parts: [str]) + // Note: without `as [any PartsRepresentable]` this will fail to compile with "Cannot // convert value of type 'String' to expected element type - // 'Array.ArrayLiteralElement'. Not sure if there's a way we can get it to + // 'Array.ArrayLiteralElement'. Not sure if there's a way we can get it to // work. - let _ = try ModelContent(parts: [str, ModelContent.Part.inlineData( - mimetype: "foo", - Data() - )] as [any ThrowingPartsRepresentable]) + let _ = ModelContent( + parts: [str, InlineDataPart(data: Data(), mimeType: "foo")] as [any PartsRepresentable] + ) #if canImport(UIKit) - _ = try ModelContent(role: "user", parts: UIImage()) - _ = try ModelContent(role: "user", parts: [UIImage()]) - // Note: without `as [any ThrowingPartsRepresentable]` this will fail to compile with "Cannot - // convert - // value of type `[Any]` to expected type `[any ThrowingPartsRepresentable]`. Not sure if - // there's a + _ = ModelContent(role: "user", parts: UIImage()) + _ = ModelContent(role: "user", parts: [UIImage()]) + // Note: without `as [any PartsRepresentable]` this will fail to compile with "Cannot convert + // value of type `[Any]` to expected type `[any PartsRepresentable]`. Not sure if there's a // way we can get it to work. - _ = try ModelContent(parts: [str, UIImage()] as [any ThrowingPartsRepresentable]) + _ = ModelContent(parts: [str, UIImage()] as [any PartsRepresentable]) // Alternatively, you can explicitly declare the type in a variable and pass it in. - let representable2: [any ThrowingPartsRepresentable] = [str, UIImage()] - _ = try ModelContent(parts: representable2) - _ = try ModelContent(parts: [str, UIImage(), - ModelContent.Part.text(str)] as [any ThrowingPartsRepresentable]) + let representable2: [any PartsRepresentable] = [str, UIImage()] + _ = ModelContent(parts: representable2) + _ = + ModelContent(parts: [str, UIImage(), TextPart(str)] as [any PartsRepresentable]) #elseif canImport(AppKit) - _ = try ModelContent(role: "user", parts: NSImage()) - _ = try ModelContent(role: "user", parts: [NSImage()]) - // Note: without `as [any ThrowingPartsRepresentable]` this will fail to compile with "Cannot - // convert - // value of type `[Any]` to expected type `[any ThrowingPartsRepresentable]`. Not sure if - // there's a + _ = ModelContent(role: "user", parts: NSImage()) + _ = ModelContent(role: "user", parts: [NSImage()]) + // Note: without `as [any PartsRepresentable]` this will fail to compile with "Cannot convert + // value of type `[Any]` to expected type `[any PartsRepresentable]`. Not sure if there's a // way we can get it to work. - _ = try ModelContent(parts: [str, NSImage()] as [any ThrowingPartsRepresentable]) + _ = ModelContent(parts: [str, NSImage()] as [any PartsRepresentable]) // Alternatively, you can explicitly declare the type in a variable and pass it in. - let representable2: [any ThrowingPartsRepresentable] = [str, NSImage()] - _ = try ModelContent(parts: representable2) + let representable2: [any PartsRepresentable] = [str, NSImage()] + _ = ModelContent(parts: representable2) _ = - try ModelContent(parts: [str, NSImage(), - ModelContent.Part.text(str)] as [any ThrowingPartsRepresentable]) + ModelContent(parts: [str, NSImage(), TextPart(str)] as [any PartsRepresentable]) #endif // countTokens API @@ -189,7 +185,7 @@ final class VertexAIAPITests: XCTestCase { // Computed Properties let _: String? = response.text - let _: [FunctionCall] = response.functionCalls + let _: [FunctionCallPart] = response.functionCalls } // Result builder alternative diff --git a/Package.swift b/Package.swift index 4bfb5a9b228..f35d72bc623 100644 --- a/Package.swift +++ b/Package.swift @@ -1313,6 +1313,7 @@ let package = Package( path: "FirebaseVertexAI/Tests/Unit", resources: [ .process("vertexai-sdk-test-data/mock-responses"), + .process("Resources"), ], cSettings: [ .headerSearchPath("../../../"),