Skip to content

Commit 92f6a1d

Browse files
committed
Use forked google-generative-ai in Vertex AI samples (#12574)
1 parent 50d0d29 commit 92f6a1d

File tree

9 files changed

+48
-125
lines changed

9 files changed

+48
-125
lines changed

FirebaseVertexAI/Sample/ChatSample/Views/ErrorDetailsView.swift

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -142,22 +142,6 @@ struct ErrorDetailsView: View {
142142
SafetyRatingsSection(ratings: ratings)
143143
}
144144

145-
case GenerateContentError.invalidAPIKey:
146-
Section("Error Type") {
147-
Text("Invalid API Key")
148-
}
149-
150-
Section("Details") {
151-
SubtitleFormRow(title: "Error description", value: error.localizedDescription)
152-
SubtitleMarkdownFormRow(
153-
title: "Help",
154-
value: """
155-
The `API_KEY` provided in the `GoogleService-Info.plist` file is invalid. Download a
156-
new copy of the file from the [Firebase Console](https://console.firebase.google.com).
157-
"""
158-
)
159-
}
160-
161145
default:
162146
Section("Error Type") {
163147
Text("Some other error")
@@ -222,11 +206,3 @@ struct ErrorDetailsView: View {
222206

223207
return ErrorDetailsView(error: error)
224208
}
225-
226-
#Preview("Invalid API Key") {
227-
ErrorDetailsView(error: GenerateContentError.invalidAPIKey)
228-
}
229-
230-
#Preview("Unsupported User Location") {
231-
ErrorDetailsView(error: GenerateContentError.unsupportedUserLocation)
232-
}

FirebaseVertexAI/Sources/Errors.swift

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,6 @@ struct RPCError: Error {
3030
self.status = status
3131
self.details = details
3232
}
33-
34-
func isInvalidAPIKeyError() -> Bool {
35-
return errorInfo?.reason == "API_KEY_INVALID"
36-
}
37-
38-
func isUnsupportedUserLocationError() -> Bool {
39-
return message == RPCErrorMessage.unsupportedUserLocation.rawValue
40-
}
4133
}
4234

4335
extension RPCError: Decodable {
@@ -179,10 +171,6 @@ enum RPCStatus: String, Decodable {
179171
case dataLoss = "DATA_LOSS"
180172
}
181173

182-
enum RPCErrorMessage: String {
183-
case unsupportedUserLocation = "User location is not supported for the API use."
184-
}
185-
186174
enum InvalidCandidateError: Error {
187175
case emptyContent(underlyingError: Error)
188176
case malformedContent(underlyingError: Error)

FirebaseVertexAI/Sources/GenerateContentError.swift

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,4 @@ public enum GenerateContentError: Error {
2828

2929
/// A response didn't fully complete. See the `FinishReason` for more information.
3030
case responseStoppedEarly(reason: FinishReason, response: GenerateContentResponse)
31-
32-
/// The provided API key is invalid.
33-
case invalidAPIKey(message: String)
34-
35-
/// The user's location (region) is not supported by the API.
36-
///
37-
/// See the Google documentation for a
38-
/// [list of regions](https://ai.google.dev/available_regions#available_regions)
39-
/// (countries and territories) where the API is available.
40-
///
41-
/// - Important: The API is only available in
42-
/// [specific regions](https://ai.google.dev/available_regions#available_regions).
43-
case unsupportedUserLocation
4431
}

FirebaseVertexAI/Sources/GenerativeAIRequest.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ public struct RequestOptions {
3838
/// - Parameters:
3939
/// - timeout The request’s timeout interval in seconds; if not specified uses the default value
4040
/// for a `URLRequest`.
41-
/// - apiVersion The API version to use in requests to the backend; defaults to "v1".
42-
public init(timeout: TimeInterval? = nil, apiVersion: String = "v1") {
41+
/// - apiVersion The API version to use in requests to the backend; defaults to "v2beta".
42+
public init(timeout: TimeInterval? = nil, apiVersion: String = "v2beta") {
4343
self.timeout = timeout
4444
self.apiVersion = apiVersion
4545
}

FirebaseVertexAI/Sources/GenerativeAIService.swift

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,27 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
import FirebaseAppCheckInterop
16+
import FirebaseCore
1517
import Foundation
1618

1719
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
1820
struct GenerativeAIService {
1921
/// Gives permission to talk to the backend.
2022
private let apiKey: String
2123

24+
private let appCheck: AppCheckInterop?
25+
2226
private let urlSession: URLSession
2327

24-
init(apiKey: String, urlSession: URLSession) {
28+
init(apiKey: String, appCheck: AppCheckInterop?, urlSession: URLSession) {
2529
self.apiKey = apiKey
30+
self.appCheck = appCheck
2631
self.urlSession = urlSession
2732
}
2833

2934
func loadRequest<T: GenerativeAIRequest>(request: T) async throws -> T.Response {
30-
let urlRequest = try urlRequest(request: request)
35+
let urlRequest = try await urlRequest(request: request)
3136

3237
#if DEBUG
3338
printCURLCommand(from: urlRequest)
@@ -59,7 +64,7 @@ struct GenerativeAIService {
5964
Task {
6065
let urlRequest: URLRequest
6166
do {
62-
urlRequest = try self.urlRequest(request: request)
67+
urlRequest = try await self.urlRequest(request: request)
6368
} catch {
6469
continuation.finish(throwing: error)
6570
return
@@ -146,13 +151,24 @@ struct GenerativeAIService {
146151

147152
// MARK: - Private Helpers
148153

149-
private func urlRequest<T: GenerativeAIRequest>(request: T) throws -> URLRequest {
154+
private func urlRequest<T: GenerativeAIRequest>(request: T) async throws -> URLRequest {
150155
var urlRequest = URLRequest(url: request.url)
151156
urlRequest.httpMethod = "POST"
152157
urlRequest.setValue(apiKey, forHTTPHeaderField: "x-goog-api-key")
153-
urlRequest.setValue("genai-swift/\(GenerativeAISwift.version)",
154-
forHTTPHeaderField: "x-goog-api-client")
158+
// TODO: Determine the right client header to use.
159+
// urlRequest.setValue("genai-swift/\(GenerativeAISwift.version))",
160+
// forHTTPHeaderField: "x-goog-api-client")
155161
urlRequest.setValue("application/json", forHTTPHeaderField: "Content-Type")
162+
163+
if let appCheck {
164+
let tokenResult = await appCheck.getToken(forcingRefresh: false)
165+
urlRequest.setValue(tokenResult.token, forHTTPHeaderField: "X-Firebase-AppCheck")
166+
if let error = tokenResult.error {
167+
Logging.default
168+
.debug("[GoogleGenerativeAI] Failed to fetch AppCheck token. Error: \(error)")
169+
}
170+
}
171+
156172
let encoder = JSONEncoder()
157173
encoder.keyEncodingStrategy = .convertToSnakeCase
158174
urlRequest.httpBody = try encoder.encode(request)

FirebaseVertexAI/Sources/GenerativeAISwift.swift

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
14+
1415
import Foundation
1516

1617
#if !os(macOS) && !os(iOS)
@@ -20,7 +21,5 @@ import Foundation
2021
/// Constants associated with the GenerativeAISwift SDK
2122
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
2223
public enum GenerativeAISwift {
23-
/// String value of the SDK version
24-
public static let version = "0.4.8"
25-
static let baseURL = "https://generativelanguage.googleapis.com"
24+
static let baseURL = "https://staging-firebaseml.sandbox.googleapis.com"
2625
}

FirebaseVertexAI/Sources/GenerativeModel.swift

Lines changed: 11 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
import FirebaseAppCheckInterop
1516
import Foundation
1617

1718
/// A type that represents a remote multimodal model (like Gemini), with the ability to generate
@@ -44,31 +45,21 @@ public final class GenerativeModel {
4445
/// - apiKey: The API key for your project.
4546
/// - generationConfig: The content generation parameters your model should use.
4647
/// - safetySettings: A value describing what types of harmful content your model should allow.
47-
/// - requestOptions Configuration parameters for sending requests to the backend.
48-
public convenience init(name: String,
49-
apiKey: String,
50-
generationConfig: GenerationConfig? = nil,
51-
safetySettings: [SafetySetting]? = nil,
52-
requestOptions: RequestOptions = RequestOptions()) {
53-
self.init(
54-
name: name,
55-
apiKey: apiKey,
56-
generationConfig: generationConfig,
57-
safetySettings: safetySettings,
58-
requestOptions: requestOptions,
59-
urlSession: .shared
60-
)
61-
}
62-
63-
/// The designated initializer for this class.
48+
/// - requestOptions: Configuration parameters for sending requests to the backend.
49+
/// - urlSession: The `URLSession` to use for requests; defaults to `URLSession.shared`.
6450
init(name: String,
6551
apiKey: String,
6652
generationConfig: GenerationConfig? = nil,
6753
safetySettings: [SafetySetting]? = nil,
68-
requestOptions: RequestOptions = RequestOptions(),
69-
urlSession: URLSession) {
54+
requestOptions: RequestOptions,
55+
appCheck: AppCheckInterop?,
56+
urlSession: URLSession = .shared) {
7057
modelResourceName = GenerativeModel.modelResourceName(name: name)
71-
generativeAIService = GenerativeAIService(apiKey: apiKey, urlSession: urlSession)
58+
generativeAIService = GenerativeAIService(
59+
apiKey: apiKey,
60+
appCheck: appCheck,
61+
urlSession: urlSession
62+
)
7263
self.generationConfig = generationConfig
7364
self.safetySettings = safetySettings
7465
self.requestOptions = requestOptions
@@ -282,10 +273,6 @@ public final class GenerativeModel {
282273
private static func generateContentError(from error: Error) -> GenerateContentError {
283274
if let error = error as? GenerateContentError {
284275
return error
285-
} else if let error = error as? RPCError, error.isInvalidAPIKeyError() {
286-
return GenerateContentError.invalidAPIKey(message: error.message)
287-
} else if let error = error as? RPCError, error.isUnsupportedUserLocationError() {
288-
return GenerateContentError.unsupportedUserLocation
289276
}
290277
return GenerateContentError.internalError(underlying: error)
291278
}

FirebaseVertexAI/Sources/VertexAI.swift

Lines changed: 10 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,9 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
import Foundation
16-
1715
import FirebaseAppCheckInterop
1816
import FirebaseCore
19-
20-
// Exports the GoogleGenerativeAI module to users of the SDK.
21-
@_exported import GoogleGenerativeAI
17+
import Foundation
2218

2319
// Avoids exposing internal FirebaseCore APIs to Swift users.
2420
@_implementationOnly import FirebaseCoreExtension
@@ -31,17 +27,20 @@ open class VertexAI: NSObject {
3127
/// Returns an instance of `GoogleGenerativeAI.GenerativeModel` that uses the Vertex AI API.
3228
///
3329
/// This instance is configured with the default `FirebaseApp`.
34-
public static func generativeModel(modelName: String, location: String) -> GoogleGenerativeAI
35-
.GenerativeModel {
30+
///
31+
/// TODO: Add RequestOptions to public API.
32+
public static func generativeModel(modelName: String, location: String) -> GenerativeModel {
3633
guard let app = FirebaseApp.app() else {
3734
fatalError("No instance of the default Firebase app was found.")
3835
}
3936
return generativeModel(app: app, modelName: modelName, location: location)
4037
}
4138

4239
/// Returns an instance of `GoogleGenerativeAI.GenerativeModel` that uses the Vertex AI API.
40+
///
41+
/// TODO: Add RequestOptions to public API.
4342
public static func generativeModel(app: FirebaseApp, modelName: String,
44-
location: String) -> GoogleGenerativeAI.GenerativeModel {
43+
location: String) -> GenerativeModel {
4544
guard let provider = ComponentType<VertexAIProvider>.instance(for: VertexAIProvider.self,
4645
in: app.container) else {
4746
fatalError("No \(VertexAIProvider.self) instance found for Firebase app: \(app.name)")
@@ -64,18 +63,15 @@ open class VertexAI: NSObject {
6463
private let modelResouceName: String
6564

6665
lazy var model: GenerativeModel = {
67-
let options = RequestOptions(
68-
apiVersion: "v2beta",
69-
endpoint: "staging-firebaseml.sandbox.googleapis.com",
70-
hooks: [addAppCheckHeader]
71-
)
7266
guard let apiKey = app.options.apiKey else {
7367
fatalError("The Firebase app named \"\(app.name)\" has no API key in its configuration.")
7468
}
7569
return GenerativeModel(
7670
name: modelResouceName,
7771
apiKey: apiKey,
78-
requestOptions: options
72+
// TODO: Add RequestOptions to public API.
73+
requestOptions: RequestOptions(),
74+
appCheck: appCheck
7975
)
8076
}()
8177

@@ -104,21 +100,4 @@ open class VertexAI: NSObject {
104100

105101
return "projects/\(projectID)/locations/\(location)/publishers/google/models/\(modelName)"
106102
}
107-
108-
// MARK: Request Hooks
109-
110-
/// Adds an App Check token to the provided request if App Check is included in the app.
111-
///
112-
/// This demonstrates how an App Check token can be added to requests; it is currently ignored by
113-
/// the backend.
114-
///
115-
/// - Parameter request: The `URLRequest` to modify by adding an App Check token header.
116-
func addAppCheckHeader(request: inout URLRequest) async {
117-
guard let appCheck else {
118-
return
119-
}
120-
121-
let tokenResult = await appCheck.getToken(forcingRefresh: false)
122-
request.addValue(tokenResult.token, forHTTPHeaderField: "X-Firebase-AppCheck")
123-
}
124103
}

Package.swift

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -192,10 +192,6 @@ let package = Package(
192192
"100.0.0" ..< "101.0.0"
193193
),
194194
.package(url: "https://github.com/google/app-check.git", "10.19.0" ..< "11.0.0"),
195-
.package(
196-
url: "https://github.com/google/generative-ai-swift.git",
197-
revision: "c9f2c4913bc65aa267815962c7e91358c2d8463f"
198-
),
199195
],
200196
targets: [
201197
.target(
@@ -1374,13 +1370,8 @@ let package = Package(
13741370
"FirebaseAppCheckInterop",
13751371
"FirebaseCore",
13761372
"FirebaseCoreExtension",
1377-
.product(name: "GoogleGenerativeAI", package: "generative-ai-swift"),
13781373
],
1379-
path: "FirebaseVertexAI/Sources",
1380-
sources: [
1381-
"VertexAI.swift",
1382-
"VertexAIComponent.swift",
1383-
]
1374+
path: "FirebaseVertexAI/Sources"
13841375
),
13851376
] + firestoreTargets(),
13861377
cLanguageStandard: .c99,

0 commit comments

Comments
 (0)