diff --git a/FirebaseAI/Sources/TemplateChatSession.swift b/FirebaseAI/Sources/TemplateChatSession.swift index abba669a1dd..a5979ac472e 100644 --- a/FirebaseAI/Sources/TemplateChatSession.swift +++ b/FirebaseAI/Sources/TemplateChatSession.swift @@ -52,15 +52,14 @@ final class TemplateChatSession: Sendable { /// - Returns: The content generated by the model. /// - Throws: A ``GenerateContentError`` if the request failed. func sendMessage(_ content: [ModelContent], - inputs: [String: Any], + inputs: [String: any TemplateInputRepresentable], options: RequestOptions = RequestOptions()) async throws -> GenerateContentResponse { - let templateInputs = try inputs.mapValues { try TemplateInput(value: $0) } let newContent = content.map(populateContentRole) let response = try await model.generateContentWithHistory( history: _history.history + newContent, template: templateID, - inputs: templateInputs, + inputs: inputs.mapValues { $0.templateInputRepresentation }, options: options ) _history.append(contentsOf: newContent) @@ -82,7 +81,7 @@ final class TemplateChatSession: Sendable { /// - Returns: The content generated by the model. /// - Throws: A ``GenerateContentError`` if the request failed. func sendMessage(_ message: any PartsRepresentable, - inputs: [String: Any], + inputs: [String: any TemplateInputRepresentable], options: RequestOptions = RequestOptions()) async throws -> GenerateContentResponse { return try await sendMessage([ModelContent(parts: message.partsValue)], @@ -103,15 +102,14 @@ final class TemplateChatSession: Sendable { /// - Returns: An `AsyncThrowingStream` that yields `GenerateContentResponse` objects. /// - Throws: A ``GenerateContentError`` if the request failed. func sendMessageStream(_ content: [ModelContent], - inputs: [String: Any], + inputs: [String: any TemplateInputRepresentable], options: RequestOptions = RequestOptions()) throws -> AsyncThrowingStream { - let templateInputs = try inputs.mapValues { try TemplateInput(value: $0) } let newContent = content.map(populateContentRole) let stream = try model.generateContentStreamWithHistory( history: _history.history + newContent, template: templateID, - inputs: templateInputs, + inputs: inputs.mapValues { $0.templateInputRepresentation }, options: options ) return AsyncThrowingStream { continuation in @@ -158,7 +156,7 @@ final class TemplateChatSession: Sendable { /// - Returns: An `AsyncThrowingStream` that yields `GenerateContentResponse` objects. /// - Throws: A ``GenerateContentError`` if the request failed. func sendMessageStream(_ message: any PartsRepresentable, - inputs: [String: Any], + inputs: [String: any TemplateInputRepresentable], options: RequestOptions = RequestOptions()) throws -> AsyncThrowingStream { return try sendMessageStream([ModelContent(parts: message.partsValue)], diff --git a/FirebaseAI/Sources/TemplateGenerateContentRequest.swift b/FirebaseAI/Sources/TemplateGenerateContentRequest.swift index 20ba84b3571..e6430abea2b 100644 --- a/FirebaseAI/Sources/TemplateGenerateContentRequest.swift +++ b/FirebaseAI/Sources/TemplateGenerateContentRequest.swift @@ -34,7 +34,7 @@ extension TemplateGenerateContentRequest: Encodable { func encode(to encoder: any Encoder) throws { var container = encoder.container(keyedBy: CodingKeys.self) - try container.encode(inputs, forKey: .inputs) + try container.encode(inputs.mapValues { $0.value }, forKey: .inputs) try container.encode(history, forKey: .history) } } diff --git a/FirebaseAI/Sources/TemplateGenerativeModel.swift b/FirebaseAI/Sources/TemplateGenerativeModel.swift index bf727021c0f..a0c15a5eb5e 100644 --- a/FirebaseAI/Sources/TemplateGenerativeModel.swift +++ b/FirebaseAI/Sources/TemplateGenerativeModel.swift @@ -40,14 +40,13 @@ public final class TemplateGenerativeModel: Sendable { /// - Returns: The content generated by the model. /// - Throws: A ``GenerateContentError`` if the request failed. public func generateContent(templateID: String, - inputs: [String: Any], + inputs: [String: any TemplateInputRepresentable], options: RequestOptions = RequestOptions()) async throws -> GenerateContentResponse { - let templateInputs = try inputs.mapValues { try TemplateInput(value: $0) } return try await generateContentWithHistory( history: [], template: templateID, - inputs: templateInputs, + inputs: inputs.mapValues { $0.templateInputRepresentation }, options: options ) } @@ -90,13 +89,12 @@ public final class TemplateGenerativeModel: Sendable { /// - Returns: An `AsyncThrowingStream` that yields `GenerateContentResponse` objects. /// - Throws: A ``GenerateContentError`` if the request failed. public func generateContentStream(templateID: String, - inputs: [String: Any], + inputs: [String: any TemplateInputRepresentable], options: RequestOptions = RequestOptions()) throws -> AsyncThrowingStream { - let templateInputs = try inputs.mapValues { try TemplateInput(value: $0) } let request = TemplateGenerateContentRequest( template: templateID, - inputs: templateInputs, + inputs: inputs.mapValues { $0.templateInputRepresentation }, history: [], projectID: generativeAIService.firebaseInfo.projectID, stream: true, diff --git a/FirebaseAI/Sources/TemplateImagenGenerationRequest.swift b/FirebaseAI/Sources/TemplateImagenGenerationRequest.swift index c155b66fe55..7693fcba67d 100644 --- a/FirebaseAI/Sources/TemplateImagenGenerationRequest.swift +++ b/FirebaseAI/Sources/TemplateImagenGenerationRequest.swift @@ -62,6 +62,6 @@ extension TemplateImagenGenerationRequest: Encodable { func encode(to encoder: Encoder) throws { var container = encoder.container(keyedBy: CodingKeys.self) - try container.encode(inputs, forKey: .inputs) + try container.encode(inputs.mapValues { $0.value }, forKey: .inputs) } } diff --git a/FirebaseAI/Sources/TemplateImagenModel.swift b/FirebaseAI/Sources/TemplateImagenModel.swift index 794965364bd..c90dc64bc1c 100644 --- a/FirebaseAI/Sources/TemplateImagenModel.swift +++ b/FirebaseAI/Sources/TemplateImagenModel.swift @@ -39,14 +39,13 @@ public final class TemplateImagenModel: Sendable { /// - Returns: The images generated by the model. /// - Throws: An error if the request failed. public func generateImages(templateID: String, - inputs: [String: Any], + inputs: [String: any TemplateInputRepresentable], options: RequestOptions = RequestOptions()) async throws -> ImagenGenerationResponse { - let templateInputs = try inputs.mapValues { try TemplateInput(value: $0) } let projectID = generativeAIService.firebaseInfo.projectID let request = TemplateImagenGenerationRequest( template: templateID, - inputs: templateInputs, + inputs: inputs.mapValues { $0.templateInputRepresentation }, projectID: projectID, apiConfig: apiConfig, options: options diff --git a/FirebaseAI/Sources/TemplateInput.swift b/FirebaseAI/Sources/TemplateInput.swift index 606150ed824..e3c7360ffd9 100644 --- a/FirebaseAI/Sources/TemplateInput.swift +++ b/FirebaseAI/Sources/TemplateInput.swift @@ -14,53 +14,20 @@ import Foundation -enum TemplateInput: Encodable, Sendable { - case string(String) - case int(Int) - case double(Double) - case bool(Bool) - case array([TemplateInput]) - case dictionary([String: TemplateInput]) +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +public struct TemplateInput: Sendable { + let value: JSONValue - init(value: Any) throws { - switch value { - case let value as String: - self = .string(value) - case let value as Int: - self = .int(value) - case let value as Double: - self = .double(value) - case let value as Float: - self = .double(Double(value)) - case let value as Bool: - self = .bool(value) - case let value as [Any]: - self = try .array(value.map { try TemplateInput(value: $0) }) - case let value as [String: Any]: - self = try .dictionary(value.mapValues { try TemplateInput(value: $0) }) - default: - throw EncodingError.invalidValue( - value, - EncodingError.Context(codingPath: [], debugDescription: "Invalid value") - ) - } + public init(_ input: some TemplateInputRepresentable) { + self = .init(value: input.templateInputRepresentation.value) } - func encode(to encoder: Encoder) throws { - var container = encoder.singleValueContainer() - switch self { - case let .string(value): - try container.encode(value) - case let .int(value): - try container.encode(value) - case let .double(value): - try container.encode(value) - case let .bool(value): - try container.encode(value) - case let .array(value): - try container.encode(value) - case let .dictionary(value): - try container.encode(value) - } + init(value: JSONValue) { + self.value = value } } + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +extension TemplateInput: TemplateInputRepresentable { + public var templateInputRepresentation: TemplateInput { self } +} diff --git a/FirebaseAI/Sources/Types/Public/TemplateInputRepresentable.swift b/FirebaseAI/Sources/Types/Public/TemplateInputRepresentable.swift new file mode 100644 index 00000000000..28ab413d23c --- /dev/null +++ b/FirebaseAI/Sources/Types/Public/TemplateInputRepresentable.swift @@ -0,0 +1,65 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation + +/// A type that can be represented as a ``TemplateInput``. +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +public protocol TemplateInputRepresentable: Encodable, Sendable { + var templateInputRepresentation: TemplateInput { get } +} + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +extension String: TemplateInputRepresentable { + public var templateInputRepresentation: TemplateInput { TemplateInput(value: .string(self)) } +} + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +extension Int: TemplateInputRepresentable { + public var templateInputRepresentation: TemplateInput { + TemplateInput(value: .number(Double(self))) + } +} + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +extension Double: TemplateInputRepresentable { + public var templateInputRepresentation: TemplateInput { TemplateInput(value: .number(self)) } +} + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +extension Float: TemplateInputRepresentable { + public var templateInputRepresentation: TemplateInput { + TemplateInput(value: .number(Double(self))) + } +} + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +extension Bool: TemplateInputRepresentable { + public var templateInputRepresentation: TemplateInput { TemplateInput(value: .bool(self)) } +} + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +extension Array: TemplateInputRepresentable where Element: TemplateInputRepresentable { + public var templateInputRepresentation: TemplateInput { + TemplateInput(value: .array(map { TemplateInput($0).value })) + } +} + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +extension Dictionary: TemplateInputRepresentable + where Key == String, Value: TemplateInputRepresentable { + public var templateInputRepresentation: TemplateInput { + TemplateInput(value: .object(mapValues { TemplateInput($0).value })) + } +} diff --git a/FirebaseAI/Tests/Unit/TemplateInputTests.swift b/FirebaseAI/Tests/Unit/TemplateInputTests.swift index 2ed428be12b..55dd87dd696 100644 --- a/FirebaseAI/Tests/Unit/TemplateInputTests.swift +++ b/FirebaseAI/Tests/Unit/TemplateInputTests.swift @@ -19,8 +19,8 @@ import XCTest final class TemplateInputTests: XCTestCase { func testInitWithFloat() throws { let floatValue: Float = 3.14 - let templateInput = try TemplateInput(value: floatValue) - guard case let .double(doubleValue) = templateInput else { + let templateInput = TemplateInput(floatValue) + guard case let .number(doubleValue) = templateInput.value else { XCTFail("Expected a .double case, but got \(templateInput)") return }