diff --git a/FirebaseAI/Sources/FirebaseAI.swift b/FirebaseAI/Sources/FirebaseAI.swift index fdd870ecfcf..f9ff5ea0424 100644 --- a/FirebaseAI/Sources/FirebaseAI.swift +++ b/FirebaseAI/Sources/FirebaseAI.swift @@ -47,7 +47,6 @@ public final class FirebaseAI: Sendable { useLimitedUseAppCheckTokens: Bool = false) -> FirebaseAI { let instance = createInstance( app: app, - location: backend.location, apiConfig: backend.apiConfig, useLimitedUseAppCheckTokens: useLimitedUseAppCheckTokens ) @@ -188,21 +187,14 @@ public final class FirebaseAI: Sendable { let apiConfig: APIConfig - /// A map of active `FirebaseAI` instances keyed by the `FirebaseApp` name and the `location`, - /// in the format `appName:location`. + /// A map of active `FirebaseAI` instances keyed by the `FirebaseApp`, the `APIConfig`, and + /// `useLimitedUseAppCheckTokens`. private nonisolated(unsafe) static var instances: [InstanceKey: FirebaseAI] = [:] /// Lock to manage access to the `instances` array to avoid race conditions. private nonisolated(unsafe) static var instancesLock: os_unfair_lock = .init() - let location: String? - - static let defaultVertexAIAPIConfig = APIConfig( - service: .vertexAI(endpoint: .firebaseProxyProd), - version: .v1beta - ) - - static func createInstance(app: FirebaseApp?, location: String?, + static func createInstance(app: FirebaseApp?, apiConfig: APIConfig, useLimitedUseAppCheckTokens: Bool) -> FirebaseAI { guard let app = app ?? FirebaseApp.app() else { @@ -216,7 +208,6 @@ public final class FirebaseAI: Sendable { let instanceKey = InstanceKey( appName: app.name, - location: location, apiConfig: apiConfig, useLimitedUseAppCheckTokens: useLimitedUseAppCheckTokens ) @@ -225,7 +216,6 @@ public final class FirebaseAI: Sendable { } let newInstance = FirebaseAI( app: app, - location: location, apiConfig: apiConfig, useLimitedUseAppCheckTokens: useLimitedUseAppCheckTokens ) @@ -233,7 +223,7 @@ public final class FirebaseAI: Sendable { return newInstance } - init(app: FirebaseApp, location: String?, apiConfig: APIConfig, + init(app: FirebaseApp, apiConfig: APIConfig, useLimitedUseAppCheckTokens: Bool) { guard let projectID = app.options.projectID else { fatalError("The Firebase app named \"\(app.name)\" has no project ID in its configuration.") @@ -254,7 +244,6 @@ public final class FirebaseAI: Sendable { useLimitedUseAppCheckTokens: useLimitedUseAppCheckTokens ) self.apiConfig = apiConfig - self.location = location } func modelResourceName(modelName: String) -> String { @@ -268,17 +257,14 @@ public final class FirebaseAI: Sendable { } switch apiConfig.service { - case .vertexAI: - return vertexAIModelResourceName(modelName: modelName) + case let .vertexAI(endpoint: _, location: location): + return vertexAIModelResourceName(modelName: modelName, location: location) case .googleAI: return developerModelResourceName(modelName: modelName) } } - private func vertexAIModelResourceName(modelName: String) -> String { - guard let location else { - fatalError("Location must be specified for the Firebase AI service.") - } + private func vertexAIModelResourceName(modelName: String, location: String) -> String { guard !location.isEmpty && location .allSatisfy({ !$0.isWhitespace && !$0.isNewline && $0 != "/" }) else { fatalError(""" @@ -307,7 +293,6 @@ public final class FirebaseAI: Sendable { /// This type is `Hashable` so that it can be used as a key in the `instances` dictionary. private struct InstanceKey: Sendable, Hashable { let appName: String - let location: String? let apiConfig: APIConfig let useLimitedUseAppCheckTokens: Bool } diff --git a/FirebaseAI/Sources/Types/Internal/APIConfig.swift b/FirebaseAI/Sources/Types/Internal/APIConfig.swift index e854db25c8c..97a8615e98a 100644 --- a/FirebaseAI/Sources/Types/Internal/APIConfig.swift +++ b/FirebaseAI/Sources/Types/Internal/APIConfig.swift @@ -45,7 +45,7 @@ extension APIConfig { /// See the [Cloud /// docs](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference) for /// more details. - case vertexAI(endpoint: Endpoint) + case vertexAI(endpoint: Endpoint, location: String) /// The Gemini Developer API provided by Google AI. /// @@ -57,7 +57,7 @@ extension APIConfig { /// This must correspond with the API set in `service`. var endpoint: Endpoint { switch self { - case let .vertexAI(endpoint: endpoint): + case let .vertexAI(endpoint: endpoint, _): return endpoint case let .googleAI(endpoint: endpoint): return endpoint diff --git a/FirebaseAI/Sources/Types/Internal/Live/LiveSessionService.swift b/FirebaseAI/Sources/Types/Internal/Live/LiveSessionService.swift index 42f8364b90f..a49e34e6671 100644 --- a/FirebaseAI/Sources/Types/Internal/Live/LiveSessionService.swift +++ b/FirebaseAI/Sources/Types/Internal/Live/LiveSessionService.swift @@ -309,12 +309,11 @@ actor LiveSessionService { /// Will apply the required app check and auth headers, as the backend expects them. private nonisolated func createWebsocket() async throws -> AsyncWebSocket { let host = apiConfig.service.endpoint.rawValue.withoutPrefix("https://") - // TODO: (b/448722577) Set a location based on the api config let urlString = switch apiConfig.service { - case .vertexAI: - "wss://\(host)/ws/google.firebase.vertexai.v1beta.LlmBidiService/BidiGenerateContent/locations/us-central1" + case let .vertexAI(_, location: location): + "wss://\(host)/ws/google.firebase.vertexai.\(apiConfig.version.rawValue).LlmBidiService/BidiGenerateContent/locations/\(location)" case .googleAI: - "wss://\(host)/ws/google.firebase.vertexai.v1beta.GenerativeService/BidiGenerateContent" + "wss://\(host)/ws/google.firebase.vertexai.\(apiConfig.version.rawValue).GenerativeService/BidiGenerateContent" } guard let url = URL(string: urlString) else { throw NSError( diff --git a/FirebaseAI/Sources/Types/Public/Backend.swift b/FirebaseAI/Sources/Types/Public/Backend.swift index 132f3a2cd72..b4b55699494 100644 --- a/FirebaseAI/Sources/Types/Public/Backend.swift +++ b/FirebaseAI/Sources/Types/Public/Backend.swift @@ -25,26 +25,28 @@ public struct Backend { /// for a list of supported locations. public static func vertexAI(location: String = "us-central1") -> Backend { return Backend( - apiConfig: APIConfig(service: .vertexAI(endpoint: .firebaseProxyProd), version: .v1beta), - location: location + apiConfig: APIConfig( + service: .vertexAI(endpoint: .firebaseProxyProd, location: location), + version: .v1beta + ) ) } /// Initializes a `Backend` configured for the Google Developer API. public static func googleAI() -> Backend { return Backend( - apiConfig: APIConfig(service: .googleAI(endpoint: .firebaseProxyProd), version: .v1beta), - location: nil + apiConfig: APIConfig( + service: .googleAI(endpoint: .firebaseProxyProd), + version: .v1beta + ) ) } // MARK: - Internal let apiConfig: APIConfig - let location: String? - init(apiConfig: APIConfig, location: String?) { + init(apiConfig: APIConfig) { self.apiConfig = apiConfig - self.location = location } } diff --git a/FirebaseAI/Tests/TestApp/Tests/Utilities/InstanceConfig.swift b/FirebaseAI/Tests/TestApp/Tests/Utilities/InstanceConfig.swift index df06f43c91f..bf9d32c6e0d 100644 --- a/FirebaseAI/Tests/TestApp/Tests/Utilities/InstanceConfig.swift +++ b/FirebaseAI/Tests/TestApp/Tests/Utilities/InstanceConfig.swift @@ -21,19 +21,29 @@ import Testing struct InstanceConfig: Equatable, Encodable { static let vertexAI_v1beta = InstanceConfig( - apiConfig: APIConfig(service: .vertexAI(endpoint: .firebaseProxyProd), version: .v1beta) + apiConfig: APIConfig( + service: .vertexAI(endpoint: .firebaseProxyProd, location: "us-central1"), + version: .v1beta + ) ) static let vertexAI_v1beta_global = InstanceConfig( - location: "global", - apiConfig: APIConfig(service: .vertexAI(endpoint: .firebaseProxyProd), version: .v1beta) + apiConfig: APIConfig( + service: .vertexAI(endpoint: .firebaseProxyProd, location: "global"), + version: .v1beta + ) ) static let vertexAI_v1beta_global_appCheckLimitedUse = InstanceConfig( - location: "global", useLimitedUseAppCheckTokens: true, - apiConfig: APIConfig(service: .vertexAI(endpoint: .firebaseProxyProd), version: .v1beta) + apiConfig: APIConfig( + service: .vertexAI(endpoint: .firebaseProxyProd, location: "global"), + version: .v1beta + ) ) static let vertexAI_v1beta_staging = InstanceConfig( - apiConfig: APIConfig(service: .vertexAI(endpoint: .firebaseProxyStaging), version: .v1beta) + apiConfig: APIConfig( + service: .vertexAI(endpoint: .firebaseProxyStaging, location: "us-central1"), + version: .v1beta + ) ) static let googleAI_v1beta = InstanceConfig( apiConfig: APIConfig(service: .googleAI(endpoint: .firebaseProxyProd), version: .v1beta) @@ -68,12 +78,18 @@ struct InstanceConfig: Equatable, Encodable { static let vertexAI_v1beta_appCheckNotConfigured = InstanceConfig( appName: FirebaseAppNames.appCheckNotConfigured, - apiConfig: APIConfig(service: .vertexAI(endpoint: .firebaseProxyProd), version: .v1beta) + apiConfig: APIConfig( + service: .vertexAI(endpoint: .firebaseProxyProd, location: "us-central1"), + version: .v1beta + ) ) static let vertexAI_v1beta_appCheckNotConfigured_limitedUseTokens = InstanceConfig( appName: FirebaseAppNames.appCheckNotConfigured, useLimitedUseAppCheckTokens: true, - apiConfig: APIConfig(service: .vertexAI(endpoint: .firebaseProxyProd), version: .v1beta) + apiConfig: APIConfig( + service: .vertexAI(endpoint: .firebaseProxyProd, location: "us-central1"), + version: .v1beta + ) ) static let googleAI_v1beta_appCheckNotConfigured = InstanceConfig( appName: FirebaseAppNames.appCheckNotConfigured, @@ -93,16 +109,11 @@ struct InstanceConfig: Equatable, Encodable { ] let appName: String? - let location: String? let useLimitedUseAppCheckTokens: Bool let apiConfig: APIConfig - init(appName: String? = nil, - location: String? = nil, - useLimitedUseAppCheckTokens: Bool = false, - apiConfig: APIConfig) { + init(appName: String? = nil, useLimitedUseAppCheckTokens: Bool = false, apiConfig: APIConfig) { self.appName = appName - self.location = location self.useLimitedUseAppCheckTokens = useLimitedUseAppCheckTokens self.apiConfig = apiConfig } @@ -136,7 +147,12 @@ extension InstanceConfig: CustomTestStringConvertible { case .googleAIBypassProxy: " - Bypass Proxy" } - let locationSuffix = location.map { " - \($0)" } ?? "" + let locationSuffix: String + if case let .vertexAI(_, location: location) = apiConfig.service { + locationSuffix = location + } else { + locationSuffix = "" + } let appCheckLimitedUseDesignator = useLimitedUseAppCheckTokens ? " - FAC Limited-Use" : "" return """ @@ -150,21 +166,14 @@ extension FirebaseAI { static func componentInstance(_ instanceConfig: InstanceConfig) -> FirebaseAI { switch instanceConfig.apiConfig.service { case .vertexAI: - let location = instanceConfig.location ?? "us-central1" return FirebaseAI.createInstance( app: instanceConfig.app, - location: location, apiConfig: instanceConfig.apiConfig, useLimitedUseAppCheckTokens: instanceConfig.useLimitedUseAppCheckTokens ) case .googleAI: - assert( - instanceConfig.location == nil, - "The Developer API is global and does not support `location`." - ) return FirebaseAI.createInstance( app: instanceConfig.app, - location: nil, apiConfig: instanceConfig.apiConfig, useLimitedUseAppCheckTokens: instanceConfig.useLimitedUseAppCheckTokens ) diff --git a/FirebaseAI/Tests/Unit/TestUtilities/FirebaseAI+DefaultAPIConfig.swift b/FirebaseAI/Tests/Unit/TestUtilities/FirebaseAI+DefaultAPIConfig.swift new file mode 100644 index 00000000000..48a596f01a4 --- /dev/null +++ b/FirebaseAI/Tests/Unit/TestUtilities/FirebaseAI+DefaultAPIConfig.swift @@ -0,0 +1,22 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +@testable import FirebaseAI + +extension FirebaseAI { + static let defaultVertexAIAPIConfig = APIConfig( + service: .vertexAI(endpoint: .firebaseProxyProd, location: "us-central1"), + version: .v1beta + ) +} diff --git a/FirebaseAI/Tests/Unit/Types/BackendTests.swift b/FirebaseAI/Tests/Unit/Types/BackendTests.swift index e4e87784e68..d0fe40c7cbc 100644 --- a/FirebaseAI/Tests/Unit/Types/BackendTests.swift +++ b/FirebaseAI/Tests/Unit/Types/BackendTests.swift @@ -19,27 +19,25 @@ import XCTest final class BackendTests: XCTestCase { func testVertexAI_defaultLocation() { let expectedAPIConfig = APIConfig( - service: .vertexAI(endpoint: .firebaseProxyProd), + service: .vertexAI(endpoint: .firebaseProxyProd, location: "us-central1"), version: .v1beta ) let backend = Backend.vertexAI() XCTAssertEqual(backend.apiConfig, expectedAPIConfig) - XCTAssertEqual(backend.location, "us-central1") } func testVertexAI_customLocation() { + let customLocation = "europe-west1" let expectedAPIConfig = APIConfig( - service: .vertexAI(endpoint: .firebaseProxyProd), + service: .vertexAI(endpoint: .firebaseProxyProd, location: customLocation), version: .v1beta ) - let customLocation = "europe-west1" let backend = Backend.vertexAI(location: customLocation) XCTAssertEqual(backend.apiConfig, expectedAPIConfig) - XCTAssertEqual(backend.location, customLocation) } func testGoogleAI() { @@ -51,6 +49,5 @@ final class BackendTests: XCTestCase { let backend = Backend.googleAI() XCTAssertEqual(backend.apiConfig, expectedAPIConfig) - XCTAssertNil(backend.location) } } diff --git a/FirebaseAI/Tests/Unit/Types/Internal/APIConfigTests.swift b/FirebaseAI/Tests/Unit/Types/Internal/APIConfigTests.swift index fe4c290831a..937b858d40b 100644 --- a/FirebaseAI/Tests/Unit/Types/Internal/APIConfigTests.swift +++ b/FirebaseAI/Tests/Unit/Types/Internal/APIConfigTests.swift @@ -18,38 +18,70 @@ import XCTest @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) final class APIConfigTests: XCTestCase { + let defaultLocation = "us-central1" + let globalLocation = "global" + func testInitialize_vertexAI_prod_v1() { - let apiConfig = APIConfig(service: .vertexAI(endpoint: .firebaseProxyProd), version: .v1) + let apiConfig = APIConfig( + service: .vertexAI(endpoint: .firebaseProxyProd, location: defaultLocation), + version: .v1 + ) - XCTAssertEqual(apiConfig.service.endpoint.rawValue, "https://firebasevertexai.googleapis.com") + switch apiConfig.service { + case let .vertexAI(endpoint: endpoint, location: location): + XCTAssertEqual(endpoint.rawValue, "https://firebasevertexai.googleapis.com") + XCTAssertEqual(location, defaultLocation) + case .googleAI: + XCTFail("Expected .vertexAI, got .googleAI") + } XCTAssertEqual(apiConfig.version.rawValue, "v1") } func testInitialize_vertexAI_prod_v1beta() { - let apiConfig = APIConfig(service: .vertexAI(endpoint: .firebaseProxyProd), version: .v1beta) + let apiConfig = APIConfig( + service: .vertexAI(endpoint: .firebaseProxyProd, location: defaultLocation), + version: .v1beta + ) - XCTAssertEqual(apiConfig.service.endpoint.rawValue, "https://firebasevertexai.googleapis.com") + switch apiConfig.service { + case let .vertexAI(endpoint: endpoint, location: location): + XCTAssertEqual(endpoint.rawValue, "https://firebasevertexai.googleapis.com") + XCTAssertEqual(location, defaultLocation) + case .googleAI: + XCTFail("Expected .vertexAI, got .googleAI") + } XCTAssertEqual(apiConfig.version.rawValue, "v1beta") } func testInitialize_vertexAI_staging_v1() { - let apiConfig = APIConfig(service: .vertexAI(endpoint: .firebaseProxyStaging), version: .v1) - - XCTAssertEqual( - apiConfig.service.endpoint.rawValue, "https://staging-firebasevertexai.sandbox.googleapis.com" + let apiConfig = APIConfig( + service: .vertexAI(endpoint: .firebaseProxyStaging, location: defaultLocation), + version: .v1 ) + + switch apiConfig.service { + case let .vertexAI(endpoint: endpoint, location: location): + XCTAssertEqual(endpoint.rawValue, "https://staging-firebasevertexai.sandbox.googleapis.com") + XCTAssertEqual(location, defaultLocation) + case .googleAI: + XCTFail("Expected .vertexAI, got .googleAI") + } XCTAssertEqual(apiConfig.version.rawValue, "v1") } func testInitialize_vertexAI_staging_v1beta() { let apiConfig = APIConfig( - service: .vertexAI(endpoint: .firebaseProxyStaging), + service: .vertexAI(endpoint: .firebaseProxyStaging, location: defaultLocation), version: .v1beta ) - XCTAssertEqual( - apiConfig.service.endpoint.rawValue, "https://staging-firebasevertexai.sandbox.googleapis.com" - ) + switch apiConfig.service { + case let .vertexAI(endpoint: endpoint, location: location): + XCTAssertEqual(endpoint.rawValue, "https://staging-firebasevertexai.sandbox.googleapis.com") + XCTAssertEqual(location, defaultLocation) + case .googleAI: + XCTFail("Expected .vertexAI, got .googleAI") + } XCTAssertEqual(apiConfig.version.rawValue, "v1beta") } @@ -58,16 +90,24 @@ final class APIConfigTests: XCTestCase { service: .googleAI(endpoint: .firebaseProxyStaging), version: .v1beta ) - XCTAssertEqual( - apiConfig.service.endpoint.rawValue, "https://staging-firebasevertexai.sandbox.googleapis.com" - ) + switch apiConfig.service { + case .vertexAI: + XCTFail("Expected .googleAI, got .vertexAI") + case let .googleAI(endpoint: endpoint): + XCTAssertEqual(endpoint.rawValue, "https://staging-firebasevertexai.sandbox.googleapis.com") + } XCTAssertEqual(apiConfig.version.rawValue, "v1beta") } func testInitialize_developer_generativeLanguage_v1beta() { let apiConfig = APIConfig(service: .googleAI(endpoint: .googleAIBypassProxy), version: .v1beta) - XCTAssertEqual(apiConfig.service.endpoint.rawValue, "https://generativelanguage.googleapis.com") + switch apiConfig.service { + case .vertexAI: + XCTFail("Expected .googleAI, got .vertexAI") + case let .googleAI(endpoint: endpoint): + XCTAssertEqual(endpoint.rawValue, "https://generativelanguage.googleapis.com") + } XCTAssertEqual(apiConfig.version.rawValue, "v1beta") } } diff --git a/FirebaseAI/Tests/Unit/VertexComponentTests.swift b/FirebaseAI/Tests/Unit/VertexComponentTests.swift index 702c6e50871..66b3ae68576 100644 --- a/FirebaseAI/Tests/Unit/VertexComponentTests.swift +++ b/FirebaseAI/Tests/Unit/VertexComponentTests.swift @@ -57,8 +57,9 @@ class VertexComponentTests: XCTestCase { XCTAssertNotNil(vertex) XCTAssertEqual(vertex.firebaseInfo.projectID, VertexComponentTests.projectID) XCTAssertEqual(vertex.firebaseInfo.apiKey, VertexComponentTests.apiKey) - XCTAssertEqual(vertex.location, "us-central1") - XCTAssertEqual(vertex.apiConfig.service, .vertexAI(endpoint: .firebaseProxyProd)) + XCTAssertEqual( + vertex.apiConfig.service, .vertexAI(endpoint: .firebaseProxyProd, location: "us-central1") + ) XCTAssertEqual(vertex.apiConfig.service.endpoint, .firebaseProxyProd) XCTAssertEqual(vertex.apiConfig.version, .v1beta) } @@ -71,8 +72,9 @@ class VertexComponentTests: XCTestCase { XCTAssertNotNil(vertex) XCTAssertEqual(vertex.firebaseInfo.projectID, VertexComponentTests.projectID) XCTAssertEqual(vertex.firebaseInfo.apiKey, VertexComponentTests.apiKey) - XCTAssertEqual(vertex.location, location) - XCTAssertEqual(vertex.apiConfig.service, .vertexAI(endpoint: .firebaseProxyProd)) + XCTAssertEqual( + vertex.apiConfig.service, .vertexAI(endpoint: .firebaseProxyProd, location: location) + ) XCTAssertEqual(vertex.apiConfig.service.endpoint, .firebaseProxyProd) XCTAssertEqual(vertex.apiConfig.version, .v1beta) } @@ -87,8 +89,9 @@ class VertexComponentTests: XCTestCase { XCTAssertNotNil(vertex) XCTAssertEqual(vertex.firebaseInfo.projectID, VertexComponentTests.projectID) XCTAssertEqual(vertex.firebaseInfo.apiKey, VertexComponentTests.apiKey) - XCTAssertEqual(vertex.location, location) - XCTAssertEqual(vertex.apiConfig.service, .vertexAI(endpoint: .firebaseProxyProd)) + XCTAssertEqual( + vertex.apiConfig.service, .vertexAI(endpoint: .firebaseProxyProd, location: location) + ) XCTAssertEqual(vertex.apiConfig.service.endpoint, .firebaseProxyProd) XCTAssertEqual(vertex.apiConfig.version, .v1beta) } @@ -154,14 +157,17 @@ class VertexComponentTests: XCTestCase { func testSameAppAndDifferentAPI_newInstanceCreated() throws { let vertex1 = FirebaseAI.createInstance( app: VertexComponentTests.app, - location: location, - apiConfig: APIConfig(service: .vertexAI(endpoint: .firebaseProxyProd), version: .v1beta), + apiConfig: APIConfig( + service: .vertexAI(endpoint: .firebaseProxyProd, location: location), + version: .v1beta + ), useLimitedUseAppCheckTokens: false ) let vertex2 = FirebaseAI.createInstance( app: VertexComponentTests.app, - location: location, - apiConfig: APIConfig(service: .vertexAI(endpoint: .firebaseProxyProd), version: .v1), + apiConfig: APIConfig( + service: .vertexAI(endpoint: .firebaseProxyProd, location: location), version: .v1 + ), useLimitedUseAppCheckTokens: false ) @@ -182,8 +188,10 @@ class VertexComponentTests: XCTestCase { weakApp = try XCTUnwrap(app1) let vertex = FirebaseAI( app: app1, - location: "transitory location", - apiConfig: FirebaseAI.defaultVertexAIAPIConfig, + apiConfig: APIConfig( + service: .vertexAI(endpoint: .firebaseProxyProd, location: "transitory location"), + version: .v1beta + ), useLimitedUseAppCheckTokens: false ) weakVertex = vertex @@ -195,13 +203,13 @@ class VertexComponentTests: XCTestCase { func testModelResourceName_vertexAI() throws { let app = try XCTUnwrap(VertexComponentTests.app) + let location = "test-location" let vertex = FirebaseAI.firebaseAI(app: app, backend: .vertexAI(location: location)) let model = "test-model-name" let projectID = vertex.firebaseInfo.projectID let modelResourceName = vertex.modelResourceName(modelName: model) - let location = try XCTUnwrap(vertex.location) XCTAssertEqual( modelResourceName, "projects/\(projectID)/locations/\(location)/publishers/google/models/\(model)" @@ -212,10 +220,7 @@ class VertexComponentTests: XCTestCase { let app = try XCTUnwrap(VertexComponentTests.app) let apiConfig = APIConfig(service: .googleAI(endpoint: .googleAIBypassProxy), version: .v1beta) let vertex = FirebaseAI.createInstance( - app: app, - location: nil, - apiConfig: apiConfig, - useLimitedUseAppCheckTokens: false + app: app, apiConfig: apiConfig, useLimitedUseAppCheckTokens: false ) let model = "test-model-name" @@ -231,10 +236,7 @@ class VertexComponentTests: XCTestCase { version: .v1beta ) let vertex = FirebaseAI.createInstance( - app: app, - location: nil, - apiConfig: apiConfig, - useLimitedUseAppCheckTokens: false + app: app, apiConfig: apiConfig, useLimitedUseAppCheckTokens: false ) let model = "test-model-name" let projectID = vertex.firebaseInfo.projectID @@ -244,15 +246,14 @@ class VertexComponentTests: XCTestCase { XCTAssertEqual(modelResourceName, "projects/\(projectID)/models/\(model)") } - func testGenerativeModel_vertexAI() async throws { + func testGenerativeModel_vertexAI_defaultLocation() async throws { let app = try XCTUnwrap(VertexComponentTests.app) - let vertex = FirebaseAI.firebaseAI(app: app, backend: .vertexAI(location: location)) + let vertex = FirebaseAI.firebaseAI(app: app, backend: .vertexAI()) let modelResourceName = vertex.modelResourceName(modelName: modelName) let expectedSystemInstruction = ModelContent(role: nil, parts: systemInstruction.parts) let generativeModel = vertex.generativeModel( - modelName: modelName, - systemInstruction: systemInstruction + modelName: modelName, systemInstruction: systemInstruction ) XCTAssertEqual(generativeModel.modelResourceName, modelResourceName) @@ -260,6 +261,24 @@ class VertexComponentTests: XCTestCase { XCTAssertEqual(generativeModel.apiConfig, FirebaseAI.defaultVertexAIAPIConfig) } + func testGenerativeModel_vertexAI_customLocation() async throws { + let app = try XCTUnwrap(VertexComponentTests.app) + let vertex = FirebaseAI.firebaseAI(app: app, backend: .vertexAI(location: location)) + let modelResourceName = vertex.modelResourceName(modelName: modelName) + let expectedAPIConfig = APIConfig( + service: .vertexAI(endpoint: .firebaseProxyProd, location: location), version: .v1beta + ) + let expectedSystemInstruction = ModelContent(role: nil, parts: systemInstruction.parts) + + let generativeModel = vertex.generativeModel( + modelName: modelName, systemInstruction: systemInstruction + ) + + XCTAssertEqual(generativeModel.modelResourceName, modelResourceName) + XCTAssertEqual(generativeModel.systemInstruction, expectedSystemInstruction) + XCTAssertEqual(generativeModel.apiConfig, expectedAPIConfig) + } + func testGenerativeModel_developerAPI() async throws { let app = try XCTUnwrap(VertexComponentTests.app) let apiConfig = APIConfig( @@ -267,10 +286,7 @@ class VertexComponentTests: XCTestCase { version: .v1beta ) let vertex = FirebaseAI.createInstance( - app: app, - location: nil, - apiConfig: apiConfig, - useLimitedUseAppCheckTokens: false + app: app, apiConfig: apiConfig, useLimitedUseAppCheckTokens: false ) let modelResourceName = vertex.modelResourceName(modelName: modelName) let expectedSystemInstruction = ModelContent(role: nil, parts: systemInstruction.parts)