1 change: 1 addition & 0 deletions FirebaseAI/Sources/AILog.swift
@@ -35,6 +35,7 @@ enum AILog {
case generativeModelInitialized = 1000
case unsupportedGeminiModel = 1001
case invalidSchemaFormat = 1002
case unsupportedConfig = 1003

// Imagen Model Configuration
case unsupportedImagenModel = 1200
32 changes: 22 additions & 10 deletions FirebaseAI/Sources/FirebaseAI.swift
@@ -52,8 +52,10 @@ public final class FirebaseAI: Sendable {
)
// Verify that the `FirebaseAI` instance is always configured with the production endpoint since
// this is the public API surface for creating an instance.
assert(instance.apiConfig.service.endpoint == .firebaseProxyProd)
assert(instance.apiConfig.version == .v1beta)
if case let .cloud(config) = instance.apiConfig {
assert(config.service.endpoint == .firebaseProxyProd)
assert(config.version == .v1beta)
}
return instance
}

@@ -86,8 +88,9 @@ public final class FirebaseAI: Sendable {
systemInstruction: ModelContent? = nil,
requestOptions: RequestOptions = RequestOptions())
-> GenerativeModel {
if !modelName.starts(with: GenerativeModel.geminiModelNamePrefix)
&& !modelName.starts(with: GenerativeModel.gemmaModelNamePrefix) {
if case .cloud = apiConfig,
!modelName.starts(with: GenerativeModel.geminiModelNamePrefix)
&& !modelName.starts(with: GenerativeModel.gemmaModelNamePrefix) {
AILog.warning(code: .unsupportedGeminiModel, """
Unsupported Gemini model "\(modelName)"; see \
https://firebase.google.com/docs/vertex-ai/models for a list of supported Gemini model names.
@@ -281,11 +284,16 @@ public final class FirebaseAI: Sendable {
""")
}

switch apiConfig.service {
case let .vertexAI(endpoint: _, location: location):
return vertexAIModelResourceName(modelName: modelName, location: location)
case .googleAI:
return developerModelResourceName(modelName: modelName)
switch apiConfig {
case let .cloud(config):
switch config.service {
case let .vertexAI(endpoint: _, location: location):
return vertexAIModelResourceName(modelName: modelName, location: location)
case .googleAI:
return developerModelResourceName(modelName: modelName)
}
case .onDevice:
return modelName
}
}

@@ -304,7 +312,11 @@
}

private func developerModelResourceName(modelName: String) -> String {
switch apiConfig.service.endpoint {
guard case let .cloud(config) = apiConfig else {
fatalError("Developer API is only supported for cloud configurations.")
}

switch config.service.endpoint {
case .firebaseProxyProd:
return "projects/\(firebaseInfo.projectID)/models/\(modelName)"
#if DEBUG
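Note (reviewer sketch, not part of this diff): with this change the Gemini/Gemma prefix warning only applies to cloud configurations; on-device model names pass through untouched. A minimal restatement of that predicate, assuming the same SDK constants referenced above remain accessible:

func shouldWarnAboutModelName(_ modelName: String, apiConfig: APIConfig) -> Bool {
  // Hypothetical helper: on-device configs never trigger the warning.
  guard case .cloud = apiConfig else { return false }
  return !modelName.starts(with: GenerativeModel.geminiModelNamePrefix)
    && !modelName.starts(with: GenerativeModel.gemmaModelNamePrefix)
}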
8 changes: 7 additions & 1 deletion FirebaseAI/Sources/GenerateContentRequest.swift
@@ -74,7 +74,13 @@ extension GenerateContentRequest: GenerativeAIRequest {
typealias Response = GenerateContentResponse

func getURL() throws -> URL {
let modelURL = "\(apiConfig.service.endpoint.rawValue)/\(apiConfig.version.rawValue)/\(model)"
guard case let .cloud(config) = apiConfig else {
throw AILog.makeInternalError(
message: "URL generation not supported for on-device models",
code: .unsupportedConfig
)
}
let modelURL = "\(config.service.endpoint.rawValue)/\(config.version.rawValue)/\(model)"
let urlString: String
switch apiMethod {
case .generateContent:
123 changes: 99 additions & 24 deletions FirebaseAI/Sources/GenerativeModel.swift
@@ -221,6 +221,51 @@ public final class GenerativeModel: Sendable {
generationConfig: GenerationConfig?) async throws
-> GenerateContentResponse {
try content.throwIfError()

if case .onDevice = apiConfig {
#if canImport(FoundationModels)
if #available(iOS 26.0, macOS 26.0, *) {
var prompt = content.map { modelContent in
modelContent.parts.compactMap { ($0 as? TextPart)?.text }.joined()
}.joined(separator: "\n")

// Inject JSON Schema if present (On-Device "JSON Mode" polyfill)
if let schema = generationConfig?.responseJSONSchema,
let schemaData = try? JSONSerialization.data(
withJSONObject: schema,
options: [.prettyPrinted, .sortedKeys]
),
let schemaString = String(data: schemaData, encoding: .utf8) {
prompt += "\n\nReturn a valid JSON object matching this schema:\n\(schemaString)"
}

let instructions = systemInstruction?.parts.compactMap { ($0 as? TextPart)?.text }
.joined()

// TODO: Map `tools` to FoundationModels tools when supported.
let session = LanguageModelSession(model: .default, instructions: instructions)
let response = try await session.respond(to: prompt)

return GenerateContentResponse(candidates: [Candidate(
content: ModelContent(role: "model", parts: [TextPart(response.content)]),
safetyRatings: [],
finishReason: .stop,
citationMetadata: nil
)], promptFeedback: nil, usageMetadata: nil)
} else {
throw AILog.makeInternalError(
message: "Foundation Models requires iOS 26.0+ / macOS 26.0+",
code: .unsupportedConfig
)
}
#else
throw AILog.makeInternalError(
message: "Foundation Models not available",
code: .unsupportedConfig
)
#endif
}

let response: GenerateContentResponse
let generateContentRequest = GenerateContentRequest(
model: modelResourceName,
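Note (reviewer sketch, not part of this diff): the on-device "JSON Mode" polyfill above only injects the schema into the prompt, so the reply arrives as plain text rather than a constrained JSON part. A hedged example of how a caller might decode it, assuming the SDK's `GenerateContentResponse.text` convenience; `WeatherReport` is hypothetical, and real on-device output may wrap the JSON in extra prose that this sketch does not strip:

import Foundation

struct WeatherReport: Decodable {
  let city: String
  let temperatureCelsius: Double
}

func decodeReport(from response: GenerateContentResponse) throws -> WeatherReport {
  // Assumption: `text` concatenates the text parts of the first candidate,
  // as it does for cloud responses.
  guard let text = response.text, let data = text.data(using: .utf8) else {
    throw NSError(domain: "JSONModeExample", code: -1)
  }
  return try JSONDecoder().decode(WeatherReport.self, from: data)
}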
@@ -296,6 +341,21 @@ public final class GenerativeModel: Sendable {
generationConfig: GenerationConfig?) throws
-> AsyncThrowingStream<GenerateContentResponse, Error> {
try content.throwIfError()

if case .onDevice = apiConfig {
return AsyncThrowingStream { continuation in
Task {
do {
let response = try await generateContent(content, generationConfig: generationConfig)
continuation.yield(response)
continuation.finish()
} catch {
continuation.finish(throwing: error)
}
}
}
}

let generateContentRequest = GenerateContentRequest(
model: modelResourceName,
contents: content,
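Note (reviewer sketch, not part of this diff): the on-device streaming path above bridges the unary call into a one-element `AsyncThrowingStream`. The same pattern in isolation, with illustrative names only:

func singleElementStream<T: Sendable>(
  _ operation: @escaping @Sendable () async throws -> T
) -> AsyncThrowingStream<T, Error> {
  AsyncThrowingStream { continuation in
    Task {
      do {
        // Yield the single response, then finish; a thrown error ends the stream.
        continuation.yield(try await operation())
        continuation.finish()
      } catch {
        continuation.finish(throwing: error)
      }
    }
  }
}

If Foundation Models exposes incremental output, a later revision could yield true partial responses instead of a single terminal one.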
@@ -389,36 +449,51 @@ public final class GenerativeModel: Sendable {
/// - Returns: The results of running the model's tokenizer on the input; contains
/// ``CountTokensResponse/totalTokens``.
public func countTokens(_ content: [ModelContent]) async throws -> CountTokensResponse {
let requestContent = switch apiConfig.service {
case .vertexAI:
content
case .googleAI:
// The `role` defaults to "user" but is ignored in `countTokens`. However, it is erroneously
// counted towards the prompt and total token count when using the Developer API
// backend; set to `nil` to avoid token count discrepancies between `countTokens` and
// `generateContent` and the two backend APIs.
content.map { ModelContent(role: nil, parts: $0.parts) }
let requestContent: [ModelContent]
switch apiConfig {
case let .cloud(config):
switch config.service {
case .vertexAI:
requestContent = content
case .googleAI:
// The `role` defaults to "user" but is ignored in `countTokens`. However, it is erroneously
// counted towards the prompt and total token count when using the Developer API
// backend; set to `nil` to avoid token count discrepancies between `countTokens` and
// `generateContent` and the two backend APIs.
requestContent = content.map { ModelContent(role: nil, parts: $0.parts) }
}
case .onDevice:
throw AILog.makeInternalError(
message: "countTokens() is not yet supported for on-device models.",
code: .unsupportedConfig
)
}

// When using the Developer API via the Firebase backend, the model name of the
// `GenerateContentRequest` nested in the `CountTokensRequest` must be of the form
// "models/model-name". This field is unaltered by the Firebase backend before forwarding the
// request to the Generative Language backend, which expects the form "models/model-name".
let generateContentRequestModelResourceName = switch apiConfig.service {
case .vertexAI:
modelResourceName
case .googleAI(endpoint: .firebaseProxyProd):
"models/\(modelName)"
#if DEBUG
case .googleAI(endpoint: .firebaseProxyStaging):
"models/\(modelName)"
case .googleAI(endpoint: .googleAIBypassProxy):
modelResourceName
case .googleAI(endpoint: .vertexAIStagingBypassProxy):
fatalError(
"The Vertex AI staging endpoint does not support the Gemini Developer API (Google AI)."
)
#endif // DEBUG
let generateContentRequestModelResourceName: String
switch apiConfig {
case let .cloud(config):
switch config.service {
case .vertexAI:
generateContentRequestModelResourceName = modelResourceName
case .googleAI(endpoint: .firebaseProxyProd):
generateContentRequestModelResourceName = "models/\(modelName)"
#if DEBUG
case .googleAI(endpoint: .firebaseProxyStaging):
generateContentRequestModelResourceName = "models/\(modelName)"
case .googleAI(endpoint: .googleAIBypassProxy):
generateContentRequestModelResourceName = modelResourceName
case .googleAI(endpoint: .vertexAIStagingBypassProxy):
fatalError(
"The Vertex AI staging endpoint does not support the Gemini Developer API (Google AI)."
)
#endif // DEBUG
}
case .onDevice:
generateContentRequestModelResourceName = modelResourceName
}

let generateContentRequest = GenerateContentRequest(
11 changes: 9 additions & 2 deletions FirebaseAI/Sources/TemplateGenerateContentRequest.swift
@@ -44,9 +44,16 @@ extension TemplateGenerateContentRequest: GenerativeAIRequest {
typealias Response = GenerateContentResponse

func getURL() throws -> URL {
guard case let .cloud(config) = apiConfig else {
throw AILog.makeInternalError(
message: "Templates are not supported for on-device models",
code: .unsupportedConfig
)
}

var urlString =
"\(apiConfig.service.endpoint.rawValue)/\(apiConfig.version.rawValue)/projects/\(projectID)"
if case let .vertexAI(_, location) = apiConfig.service {
"\(config.service.endpoint.rawValue)/\(config.version.rawValue)/projects/\(projectID)"
if case let .vertexAI(_, location) = config.service {
urlString += "/locations/\(location)"
}

11 changes: 9 additions & 2 deletions FirebaseAI/Sources/TemplateImagenGenerationRequest.swift
@@ -41,9 +41,16 @@ struct TemplateImagenGenerationRequest<ImageType: ImagenImageRepresentable>: Sendable {
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension TemplateImagenGenerationRequest: GenerativeAIRequest where ImageType: Decodable {
func getURL() throws -> URL {
guard case let .cloud(config) = apiConfig else {
throw AILog.makeInternalError(
message: "Templates are not supported for on-device models",
code: .unsupportedConfig
)
}

var urlString =
"\(apiConfig.service.endpoint.rawValue)/\(apiConfig.version.rawValue)/projects/\(projectID)"
if case let .vertexAI(_, location) = apiConfig.service {
"\(config.service.endpoint.rawValue)/\(config.version.rawValue)/projects/\(projectID)"
if case let .vertexAI(_, location) = config.service {
urlString += "/locations/\(location)"
}
urlString += "/templates/\(template):\(ImageAPIMethod.generateImages.rawValue)"
23 changes: 16 additions & 7 deletions FirebaseAI/Sources/Types/Internal/APIConfig.swift
@@ -12,8 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.

/// Configuration for the generative AI backend API used by this SDK.
struct APIConfig: Sendable, Hashable, Encodable {
/// Configuration for the generative AI backend.
enum APIConfig: Sendable, Hashable, Encodable {
/// Configuration for cloud-based inference (Vertex AI, Google AI).
case cloud(CloudConfig)

/// Configuration for on-device inference (Apple Foundation Models).
case onDevice
}

/// Configuration for cloud-based backend APIs.
struct CloudConfig: Sendable, Hashable, Encodable {
/// The service to use for generative AI.
///
/// This controls which backend API is used by the SDK.
@@ -22,7 +31,7 @@ struct APIConfig: Sendable, Hashable, Encodable {
/// The version of the selected API to use, e.g., "v1".
let version: Version

/// Initializes an API configuration.
/// Initializes a cloud configuration.
///
/// - Parameters:
/// - service: The API service to use for generative AI.
@@ -33,7 +42,7 @@
}
}

extension APIConfig {
extension CloudConfig {
/// API services providing generative AI functionality.
///
/// See [Vertex AI and Google AI
@@ -66,7 +75,7 @@ extension APIConfig {
}
}

extension APIConfig.Service {
extension CloudConfig.Service {
/// Network addresses for generative AI API services.
// TODO: maybe remove the https:// prefix and just add it as needed? websockets use these too.
enum Endpoint: String, Encodable {
@@ -98,8 +107,8 @@ extension APIConfig.Service {
}
}

extension APIConfig {
/// Versions of the configured API service (`APIConfig.Service`).
extension CloudConfig {
/// Versions of the configured API service (`CloudConfig.Service`).
enum Version: String, Encodable {
/// The beta channel for version 1 of the API.
case v1beta
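Note (reviewer sketch, not part of this diff): with the struct-to-enum split above, call sites choose a backend case explicitly and the compiler flags every cloud-only path that has not yet handled `.onDevice`. A hedged construction and matching example; the exact `CloudConfig` initializer labels are inferred from these hunks:

let cloud: APIConfig = .cloud(
  CloudConfig(service: .googleAI(endpoint: .firebaseProxyProd), version: .v1beta)
)
let onDevice: APIConfig = .onDevice

func describeBackend(_ apiConfig: APIConfig) -> String {
  switch apiConfig {
  case let .cloud(config):
    // `Version` is a String-backed enum, e.g. "v1beta".
    return "cloud backend, API version \(config.version.rawValue)"
  case .onDevice:
    return "on-device Foundation Models"
  }
}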
@@ -40,8 +40,14 @@ extension ImagenGenerationRequest: GenerativeAIRequest where ImageType: Decodable {
typealias Response = ImagenGenerationResponse<ImageType>

func getURL() throws -> URL {
guard case let .cloud(config) = apiConfig else {
throw AILog.makeInternalError(
message: "Imagen is not supported for on-device models",
code: .unsupportedConfig
)
}
let urlString =
"\(apiConfig.service.endpoint.rawValue)/\(apiConfig.version.rawValue)/\(model):predict"
"\(config.service.endpoint.rawValue)/\(config.version.rawValue)/\(model):predict"
guard let url = URL(string: urlString) else {
throw AILog.makeInternalError(message: "Malformed URL: \(urlString)", code: .malformedURL)
}
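Note (reviewer sketch, not part of this diff): the same cloud-only guard now appears in several `getURL()` implementations. A hypothetical shared helper (name and placement are assumptions) would keep the error text in one place:

extension APIConfig {
  /// Unwraps the cloud configuration or throws the shared unsupported-config error.
  func requireCloudConfig(for feature: String) throws -> CloudConfig {
    guard case let .cloud(config) = self else {
      throw AILog.makeInternalError(
        message: "\(feature) is not supported for on-device models.",
        code: .unsupportedConfig
      )
    }
    return config
  }
}

Each call site would then reduce to something like `let config = try apiConfig.requireCloudConfig(for: "Imagen")`.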
23 changes: 19 additions & 4 deletions FirebaseAI/Sources/Types/Internal/Live/LiveSessionService.swift
@@ -346,12 +346,27 @@ actor LiveSessionService {
///
/// Will apply the required app check and auth headers, as the backend expects them.
private nonisolated func createWebsocket() async throws -> AsyncWebSocket {
let host = apiConfig.service.endpoint.rawValue.withoutPrefix("https://")
let urlString = switch apiConfig.service {
guard case let .cloud(config) = apiConfig else {
let message = "The Live API is not supported for on-device foundation models."
AILog.error(code: .unsupportedConfig, message)
throw NSError(
domain: "\(Constants.baseErrorDomain).\(Self.self)",
code: AILog.MessageCode.unsupportedConfig.rawValue,
userInfo: [
NSLocalizedDescriptionKey: message,
]
)
}

let host = config.service.endpoint.rawValue.withoutPrefix("https://")
let urlString: String
switch config.service {
case let .vertexAI(_, location: location):
"wss://\(host)/ws/google.firebase.vertexai.\(apiConfig.version.rawValue).LlmBidiService/BidiGenerateContent/locations/\(location)"
urlString =
"wss://\(host)/ws/google.firebase.vertexai.\(config.version.rawValue).LlmBidiService/BidiGenerateContent/locations/\(location)"
case .googleAI:
"wss://\(host)/ws/google.firebase.vertexai.\(apiConfig.version.rawValue).GenerativeService/BidiGenerateContent"
urlString =
"wss://\(host)/ws/google.firebase.vertexai.\(config.version.rawValue).GenerativeService/BidiGenerateContent"
}
guard let url = URL(string: urlString) else {
throw NSError(
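Note (reviewer sketch, not part of this diff): `withoutPrefix(_:)` above is an existing SDK helper; for readers following the host derivation, a minimal equivalent might look like this (illustrative only):

extension String {
  /// Drops `prefix` from the start of the string when present.
  func withoutPrefix(_ prefix: String) -> String {
    hasPrefix(prefix) ? String(dropFirst(prefix.count)) : self
  }
}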