Skip to content

Commit 8314558

Browse files
daymxnncooke3andrewheard
authored
fix(ai): Use location in websocket endpoint (#15373)
Co-authored-by: Nick Cooke <[email protected]> Co-authored-by: Andrew Heard <[email protected]>
1 parent 7575b77 commit 8314558

File tree

9 files changed

+178
-108
lines changed

9 files changed

+178
-108
lines changed

FirebaseAI/Sources/FirebaseAI.swift

Lines changed: 7 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ public final class FirebaseAI: Sendable {
4747
useLimitedUseAppCheckTokens: Bool = false) -> FirebaseAI {
4848
let instance = createInstance(
4949
app: app,
50-
location: backend.location,
5150
apiConfig: backend.apiConfig,
5251
useLimitedUseAppCheckTokens: useLimitedUseAppCheckTokens
5352
)
@@ -188,21 +187,14 @@ public final class FirebaseAI: Sendable {
188187

189188
let apiConfig: APIConfig
190189

191-
/// A map of active `FirebaseAI` instances keyed by the `FirebaseApp` name and the `location`,
192-
/// in the format `appName:location`.
190+
/// A map of active `FirebaseAI` instances keyed by the `FirebaseApp`, the `APIConfig`, and
191+
/// `useLimitedUseAppCheckTokens`.
193192
private nonisolated(unsafe) static var instances: [InstanceKey: FirebaseAI] = [:]
194193

195194
/// Lock to manage access to the `instances` array to avoid race conditions.
196195
private nonisolated(unsafe) static var instancesLock: os_unfair_lock = .init()
197196

198-
let location: String?
199-
200-
static let defaultVertexAIAPIConfig = APIConfig(
201-
service: .vertexAI(endpoint: .firebaseProxyProd),
202-
version: .v1beta
203-
)
204-
205-
static func createInstance(app: FirebaseApp?, location: String?,
197+
static func createInstance(app: FirebaseApp?,
206198
apiConfig: APIConfig,
207199
useLimitedUseAppCheckTokens: Bool) -> FirebaseAI {
208200
guard let app = app ?? FirebaseApp.app() else {
@@ -216,7 +208,6 @@ public final class FirebaseAI: Sendable {
216208

217209
let instanceKey = InstanceKey(
218210
appName: app.name,
219-
location: location,
220211
apiConfig: apiConfig,
221212
useLimitedUseAppCheckTokens: useLimitedUseAppCheckTokens
222213
)
@@ -225,15 +216,14 @@ public final class FirebaseAI: Sendable {
225216
}
226217
let newInstance = FirebaseAI(
227218
app: app,
228-
location: location,
229219
apiConfig: apiConfig,
230220
useLimitedUseAppCheckTokens: useLimitedUseAppCheckTokens
231221
)
232222
instances[instanceKey] = newInstance
233223
return newInstance
234224
}
235225

236-
init(app: FirebaseApp, location: String?, apiConfig: APIConfig,
226+
init(app: FirebaseApp, apiConfig: APIConfig,
237227
useLimitedUseAppCheckTokens: Bool) {
238228
guard let projectID = app.options.projectID else {
239229
fatalError("The Firebase app named \"\(app.name)\" has no project ID in its configuration.")
@@ -254,7 +244,6 @@ public final class FirebaseAI: Sendable {
254244
useLimitedUseAppCheckTokens: useLimitedUseAppCheckTokens
255245
)
256246
self.apiConfig = apiConfig
257-
self.location = location
258247
}
259248

260249
func modelResourceName(modelName: String) -> String {
@@ -268,17 +257,14 @@ public final class FirebaseAI: Sendable {
268257
}
269258

270259
switch apiConfig.service {
271-
case .vertexAI:
272-
return vertexAIModelResourceName(modelName: modelName)
260+
case let .vertexAI(endpoint: _, location: location):
261+
return vertexAIModelResourceName(modelName: modelName, location: location)
273262
case .googleAI:
274263
return developerModelResourceName(modelName: modelName)
275264
}
276265
}
277266

278-
private func vertexAIModelResourceName(modelName: String) -> String {
279-
guard let location else {
280-
fatalError("Location must be specified for the Firebase AI service.")
281-
}
267+
private func vertexAIModelResourceName(modelName: String, location: String) -> String {
282268
guard !location.isEmpty && location
283269
.allSatisfy({ !$0.isWhitespace && !$0.isNewline && $0 != "/" }) else {
284270
fatalError("""
@@ -307,7 +293,6 @@ public final class FirebaseAI: Sendable {
307293
/// This type is `Hashable` so that it can be used as a key in the `instances` dictionary.
308294
private struct InstanceKey: Sendable, Hashable {
309295
let appName: String
310-
let location: String?
311296
let apiConfig: APIConfig
312297
let useLimitedUseAppCheckTokens: Bool
313298
}

FirebaseAI/Sources/Types/Internal/APIConfig.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ extension APIConfig {
4545
/// See the [Cloud
4646
/// docs](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference) for
4747
/// more details.
48-
case vertexAI(endpoint: Endpoint)
48+
case vertexAI(endpoint: Endpoint, location: String)
4949

5050
/// The Gemini Developer API provided by Google AI.
5151
///
@@ -57,7 +57,7 @@ extension APIConfig {
5757
/// This must correspond with the API set in `service`.
5858
var endpoint: Endpoint {
5959
switch self {
60-
case let .vertexAI(endpoint: endpoint):
60+
case let .vertexAI(endpoint: endpoint, _):
6161
return endpoint
6262
case let .googleAI(endpoint: endpoint):
6363
return endpoint

FirebaseAI/Sources/Types/Internal/Live/LiveSessionService.swift

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -309,12 +309,11 @@ actor LiveSessionService {
309309
/// Will apply the required app check and auth headers, as the backend expects them.
310310
private nonisolated func createWebsocket() async throws -> AsyncWebSocket {
311311
let host = apiConfig.service.endpoint.rawValue.withoutPrefix("https://")
312-
// TODO: (b/448722577) Set a location based on the api config
313312
let urlString = switch apiConfig.service {
314-
case .vertexAI:
315-
"wss://\(host)/ws/google.firebase.vertexai.v1beta.LlmBidiService/BidiGenerateContent/locations/us-central1"
313+
case let .vertexAI(_, location: location):
314+
"wss://\(host)/ws/google.firebase.vertexai.\(apiConfig.version.rawValue).LlmBidiService/BidiGenerateContent/locations/\(location)"
316315
case .googleAI:
317-
"wss://\(host)/ws/google.firebase.vertexai.v1beta.GenerativeService/BidiGenerateContent"
316+
"wss://\(host)/ws/google.firebase.vertexai.\(apiConfig.version.rawValue).GenerativeService/BidiGenerateContent"
318317
}
319318
guard let url = URL(string: urlString) else {
320319
throw NSError(

FirebaseAI/Sources/Types/Public/Backend.swift

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,26 +25,28 @@ public struct Backend {
2525
/// for a list of supported locations.
2626
public static func vertexAI(location: String = "us-central1") -> Backend {
2727
return Backend(
28-
apiConfig: APIConfig(service: .vertexAI(endpoint: .firebaseProxyProd), version: .v1beta),
29-
location: location
28+
apiConfig: APIConfig(
29+
service: .vertexAI(endpoint: .firebaseProxyProd, location: location),
30+
version: .v1beta
31+
)
3032
)
3133
}
3234

3335
/// Initializes a `Backend` configured for the Google Developer API.
3436
public static func googleAI() -> Backend {
3537
return Backend(
36-
apiConfig: APIConfig(service: .googleAI(endpoint: .firebaseProxyProd), version: .v1beta),
37-
location: nil
38+
apiConfig: APIConfig(
39+
service: .googleAI(endpoint: .firebaseProxyProd),
40+
version: .v1beta
41+
)
3842
)
3943
}
4044

4145
// MARK: - Internal
4246

4347
let apiConfig: APIConfig
44-
let location: String?
4548

46-
init(apiConfig: APIConfig, location: String?) {
49+
init(apiConfig: APIConfig) {
4750
self.apiConfig = apiConfig
48-
self.location = location
4951
}
5052
}

FirebaseAI/Tests/TestApp/Tests/Utilities/InstanceConfig.swift

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,29 @@ import Testing
2121

2222
struct InstanceConfig: Equatable, Encodable {
2323
static let vertexAI_v1beta = InstanceConfig(
24-
apiConfig: APIConfig(service: .vertexAI(endpoint: .firebaseProxyProd), version: .v1beta)
24+
apiConfig: APIConfig(
25+
service: .vertexAI(endpoint: .firebaseProxyProd, location: "us-central1"),
26+
version: .v1beta
27+
)
2528
)
2629
static let vertexAI_v1beta_global = InstanceConfig(
27-
location: "global",
28-
apiConfig: APIConfig(service: .vertexAI(endpoint: .firebaseProxyProd), version: .v1beta)
30+
apiConfig: APIConfig(
31+
service: .vertexAI(endpoint: .firebaseProxyProd, location: "global"),
32+
version: .v1beta
33+
)
2934
)
3035
static let vertexAI_v1beta_global_appCheckLimitedUse = InstanceConfig(
31-
location: "global",
3236
useLimitedUseAppCheckTokens: true,
33-
apiConfig: APIConfig(service: .vertexAI(endpoint: .firebaseProxyProd), version: .v1beta)
37+
apiConfig: APIConfig(
38+
service: .vertexAI(endpoint: .firebaseProxyProd, location: "global"),
39+
version: .v1beta
40+
)
3441
)
3542
static let vertexAI_v1beta_staging = InstanceConfig(
36-
apiConfig: APIConfig(service: .vertexAI(endpoint: .firebaseProxyStaging), version: .v1beta)
43+
apiConfig: APIConfig(
44+
service: .vertexAI(endpoint: .firebaseProxyStaging, location: "us-central1"),
45+
version: .v1beta
46+
)
3747
)
3848
static let googleAI_v1beta = InstanceConfig(
3949
apiConfig: APIConfig(service: .googleAI(endpoint: .firebaseProxyProd), version: .v1beta)
@@ -68,12 +78,18 @@ struct InstanceConfig: Equatable, Encodable {
6878

6979
static let vertexAI_v1beta_appCheckNotConfigured = InstanceConfig(
7080
appName: FirebaseAppNames.appCheckNotConfigured,
71-
apiConfig: APIConfig(service: .vertexAI(endpoint: .firebaseProxyProd), version: .v1beta)
81+
apiConfig: APIConfig(
82+
service: .vertexAI(endpoint: .firebaseProxyProd, location: "us-central1"),
83+
version: .v1beta
84+
)
7285
)
7386
static let vertexAI_v1beta_appCheckNotConfigured_limitedUseTokens = InstanceConfig(
7487
appName: FirebaseAppNames.appCheckNotConfigured,
7588
useLimitedUseAppCheckTokens: true,
76-
apiConfig: APIConfig(service: .vertexAI(endpoint: .firebaseProxyProd), version: .v1beta)
89+
apiConfig: APIConfig(
90+
service: .vertexAI(endpoint: .firebaseProxyProd, location: "us-central1"),
91+
version: .v1beta
92+
)
7793
)
7894
static let googleAI_v1beta_appCheckNotConfigured = InstanceConfig(
7995
appName: FirebaseAppNames.appCheckNotConfigured,
@@ -93,16 +109,11 @@ struct InstanceConfig: Equatable, Encodable {
93109
]
94110

95111
let appName: String?
96-
let location: String?
97112
let useLimitedUseAppCheckTokens: Bool
98113
let apiConfig: APIConfig
99114

100-
init(appName: String? = nil,
101-
location: String? = nil,
102-
useLimitedUseAppCheckTokens: Bool = false,
103-
apiConfig: APIConfig) {
115+
init(appName: String? = nil, useLimitedUseAppCheckTokens: Bool = false, apiConfig: APIConfig) {
104116
self.appName = appName
105-
self.location = location
106117
self.useLimitedUseAppCheckTokens = useLimitedUseAppCheckTokens
107118
self.apiConfig = apiConfig
108119
}
@@ -136,7 +147,12 @@ extension InstanceConfig: CustomTestStringConvertible {
136147
case .googleAIBypassProxy:
137148
" - Bypass Proxy"
138149
}
139-
let locationSuffix = location.map { " - \($0)" } ?? ""
150+
let locationSuffix: String
151+
if case let .vertexAI(_, location: location) = apiConfig.service {
152+
locationSuffix = location
153+
} else {
154+
locationSuffix = ""
155+
}
140156
let appCheckLimitedUseDesignator = useLimitedUseAppCheckTokens ? " - FAC Limited-Use" : ""
141157

142158
return """
@@ -150,21 +166,14 @@ extension FirebaseAI {
150166
static func componentInstance(_ instanceConfig: InstanceConfig) -> FirebaseAI {
151167
switch instanceConfig.apiConfig.service {
152168
case .vertexAI:
153-
let location = instanceConfig.location ?? "us-central1"
154169
return FirebaseAI.createInstance(
155170
app: instanceConfig.app,
156-
location: location,
157171
apiConfig: instanceConfig.apiConfig,
158172
useLimitedUseAppCheckTokens: instanceConfig.useLimitedUseAppCheckTokens
159173
)
160174
case .googleAI:
161-
assert(
162-
instanceConfig.location == nil,
163-
"The Developer API is global and does not support `location`."
164-
)
165175
return FirebaseAI.createInstance(
166176
app: instanceConfig.app,
167-
location: nil,
168177
apiConfig: instanceConfig.apiConfig,
169178
useLimitedUseAppCheckTokens: instanceConfig.useLimitedUseAppCheckTokens
170179
)
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
@testable import FirebaseAI
16+
17+
extension FirebaseAI {
18+
static let defaultVertexAIAPIConfig = APIConfig(
19+
service: .vertexAI(endpoint: .firebaseProxyProd, location: "us-central1"),
20+
version: .v1beta
21+
)
22+
}

FirebaseAI/Tests/Unit/Types/BackendTests.swift

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,27 +19,25 @@ import XCTest
1919
final class BackendTests: XCTestCase {
2020
func testVertexAI_defaultLocation() {
2121
let expectedAPIConfig = APIConfig(
22-
service: .vertexAI(endpoint: .firebaseProxyProd),
22+
service: .vertexAI(endpoint: .firebaseProxyProd, location: "us-central1"),
2323
version: .v1beta
2424
)
2525

2626
let backend = Backend.vertexAI()
2727

2828
XCTAssertEqual(backend.apiConfig, expectedAPIConfig)
29-
XCTAssertEqual(backend.location, "us-central1")
3029
}
3130

3231
func testVertexAI_customLocation() {
32+
let customLocation = "europe-west1"
3333
let expectedAPIConfig = APIConfig(
34-
service: .vertexAI(endpoint: .firebaseProxyProd),
34+
service: .vertexAI(endpoint: .firebaseProxyProd, location: customLocation),
3535
version: .v1beta
3636
)
37-
let customLocation = "europe-west1"
3837

3938
let backend = Backend.vertexAI(location: customLocation)
4039

4140
XCTAssertEqual(backend.apiConfig, expectedAPIConfig)
42-
XCTAssertEqual(backend.location, customLocation)
4341
}
4442

4543
func testGoogleAI() {
@@ -51,6 +49,5 @@ final class BackendTests: XCTestCase {
5149
let backend = Backend.googleAI()
5250

5351
XCTAssertEqual(backend.apiConfig, expectedAPIConfig)
54-
XCTAssertNil(backend.location)
5552
}
5653
}

0 commit comments

Comments
 (0)