Skip to content

Commit af52205

Browse files
committed
Make generativeModel an instance method of VertexAI (#12599)
1 parent 4f44ef2 commit af52205

File tree

6 files changed

+85
-70
lines changed

6 files changed

+85
-70
lines changed

FirebaseVertexAI/Sample/ChatSample/ViewModels/ConversationViewModel.swift

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,10 @@ class ConversationViewModel: ObservableObject {
3636
private var chatTask: Task<Void, Never>?
3737

3838
init() {
39-
model = VertexAI.generativeModel(modelName: "gemini-1.0-pro", location: "us-central1")
39+
model = VertexAI.vertexAI().generativeModel(
40+
modelName: "gemini-1.0-pro",
41+
location: "us-central1"
42+
)
4043
chat = model.startChat()
4144
}
4245

FirebaseVertexAI/Sample/GenerativeAIMultimodalSample/ViewModels/PhotoReasoningViewModel.swift

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,10 @@ class PhotoReasoningViewModel: ObservableObject {
4444
private var model: GenerativeModel?
4545

4646
init() {
47-
model = VertexAI.generativeModel(modelName: "gemini-1.0-pro-vision", location: "us-central1")
47+
model = VertexAI.vertexAI().generativeModel(
48+
modelName: "gemini-1.0-pro-vision",
49+
location: "us-central1"
50+
)
4851
}
4952

5053
func reason() async {

FirebaseVertexAI/Sample/GenerativeAITextSample/ViewModels/SummarizeViewModel.swift

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,10 @@ class SummarizeViewModel: ObservableObject {
3232
private var model: GenerativeModel?
3333

3434
init() {
35-
model = VertexAI.generativeModel(modelName: "gemini-1.0-pro", location: "us-central1")
35+
model = VertexAI.vertexAI().generativeModel(
36+
modelName: "gemini-1.0-pro",
37+
location: "us-central1"
38+
)
3639
}
3740

3841
func summarize(inputText: String) async {

FirebaseVertexAI/Sources/VertexAI.swift

Lines changed: 46 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -20,70 +20,81 @@ import Foundation
2020
@_implementationOnly import FirebaseCoreExtension
2121

2222
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
23-
@objc(FIRVertexAI)
24-
open class VertexAI: NSObject {
23+
public class VertexAI: NSObject {
2524
// MARK: - Public APIs
2625

27-
/// Returns an instance of `GoogleGenerativeAI.GenerativeModel` that uses the Vertex AI API.
26+
/// The default `VertexAI` instance.
2827
///
29-
/// This instance is configured with the default `FirebaseApp`.
30-
///
31-
/// TODO: Add RequestOptions to public API.
32-
public static func generativeModel(modelName: String, location: String) -> GenerativeModel {
28+
/// - Returns: An instance of `VertexAI`, configured with the default `FirebaseApp`.
29+
public static func vertexAI() -> VertexAI {
3330
guard let app = FirebaseApp.app() else {
3431
fatalError("No instance of the default Firebase app was found.")
3532
}
36-
return generativeModel(app: app, modelName: modelName, location: location)
33+
34+
return vertexAI(app: app)
3735
}
3836

39-
/// Returns an instance of `GoogleGenerativeAI.GenerativeModel` that uses the Vertex AI API.
37+
/// Creates an instance of `VertexAI` configured with a custom `FirebaseApp`.
4038
///
41-
/// TODO: Add RequestOptions to public API.
42-
public static func generativeModel(app: FirebaseApp, modelName: String,
43-
location: String) -> GenerativeModel {
39+
/// - Parameter app: The custom `FirebaseApp` used for initialization.
40+
/// - Returns: A `VertexAI` instance, configured with the custom `FirebaseApp`.
41+
public static func vertexAI(app: FirebaseApp) -> VertexAI {
4442
guard let provider = ComponentType<VertexAIProvider>.instance(for: VertexAIProvider.self,
4543
in: app.container) else {
4644
fatalError("No \(VertexAIProvider.self) instance found for Firebase app: \(app.name)")
4745
}
48-
let modelResourceName = modelResourceName(app: app, modelName: modelName, location: location)
49-
let vertexAI = provider.vertexAI(location: location, modelResourceName: modelResourceName)
5046

51-
return vertexAI.model
47+
return provider.vertexAI()
5248
}
5349

54-
// MARK: - Private
55-
56-
/// The `FirebaseApp` associated with this `VertexAI` instance.
57-
private let app: FirebaseApp
58-
59-
private let appCheck: AppCheckInterop?
60-
61-
private let location: String
62-
63-
private let modelResouceName: String
50+
/// Initializes a generative model with the given parameters.
51+
///
52+
/// - Parameters:
53+
/// - modelName: The name of the model to use, e.g., `"gemini-1.0-pro"`; see
54+
/// [Gemini
55+
/// models](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-models)
56+
/// for a list of supported model names.
57+
/// - location: The location identifier, e.g., `us-central1`; see
58+
/// [Vertex AI
59+
/// regions](https://cloud.google.com/vertex-ai/docs/general/locations#vertex-ai-regions)
60+
/// for a list of supported locations.
61+
/// - generationConfig: The content generation parameters your model should use.
62+
/// - safetySettings: A value describing what types of harmful content your model should allow.
63+
/// - requestOptions: Configuration parameters for sending requests to the backend.
64+
public func generativeModel(modelName: String, location: String,
65+
generationConfig: GenerationConfig? = nil,
66+
safetySettings: [SafetySetting]? = nil,
67+
requestOptions: RequestOptions = RequestOptions())
68+
-> GenerativeModel {
69+
let modelResourceName = modelResourceName(modelName: modelName, location: location)
6470

65-
lazy var model: GenerativeModel = {
6671
guard let apiKey = app.options.apiKey else {
6772
fatalError("The Firebase app named \"\(app.name)\" has no API key in its configuration.")
6873
}
74+
6975
return GenerativeModel(
70-
name: modelResouceName,
76+
name: modelResourceName,
7177
apiKey: apiKey,
72-
// TODO: Add RequestOptions to public API.
73-
requestOptions: RequestOptions(),
78+
generationConfig: generationConfig,
79+
safetySettings: safetySettings,
80+
requestOptions: requestOptions,
7481
appCheck: appCheck
7582
)
76-
}()
83+
}
84+
85+
// MARK: - Private
86+
87+
/// The `FirebaseApp` associated with this `VertexAI` instance.
88+
private let app: FirebaseApp
89+
90+
private let appCheck: AppCheckInterop?
7791

78-
init(app: FirebaseApp, location: String, modelResourceName: String) {
92+
init(app: FirebaseApp) {
7993
self.app = app
8094
appCheck = ComponentType<AppCheckInterop>.instance(for: AppCheckInterop.self, in: app.container)
81-
self.location = location
82-
modelResouceName = modelResourceName
8395
}
8496

85-
private static func modelResourceName(app: FirebaseApp, modelName: String,
86-
location: String) -> String {
97+
private func modelResourceName(modelName: String, location: String) -> String {
8798
if modelName.contains("/") {
8899
return modelName
89100
}

FirebaseVertexAI/Sources/VertexAIComponent.swift

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import Foundation
2222
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
2323
@objc(FIRVertexAIProvider)
2424
protocol VertexAIProvider {
25-
@objc func vertexAI(location: String, modelResourceName: String) -> VertexAI
25+
@objc func vertexAI() -> VertexAI
2626
}
2727

2828
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
@@ -64,17 +64,17 @@ class VertexAIComponent: NSObject, Library, VertexAIProvider {
6464

6565
// MARK: - VertexAIProvider conformance
6666

67-
func vertexAI(location: String, modelResourceName: String) -> VertexAI {
67+
func vertexAI() -> VertexAI {
6868
os_unfair_lock_lock(&instancesLock)
6969

7070
// Unlock before the function returns.
7171
defer { os_unfair_lock_unlock(&instancesLock) }
7272

73-
if let instance = instances[modelResourceName] {
73+
if let instance = instances[app.name] {
7474
return instance
7575
}
76-
let newInstance = VertexAI(app: app, location: location, modelResourceName: modelResourceName)
77-
instances[modelResourceName] = newInstance
76+
let newInstance = VertexAI(app: app)
77+
instances[app.name] = newInstance
7878
return newInstance
7979
}
8080
}

FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift

Lines changed: 22 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -33,40 +33,35 @@ final class VertexAIAPITests: XCTestCase {
3333
stopSequences: ["..."])
3434
let filters = [SafetySetting(harmCategory: .dangerousContent, threshold: .blockOnlyHigh)]
3535

36+
// Instantiate Vertex AI SDK - Default App
37+
let vertexAI = VertexAI.vertexAI()
38+
39+
// Instantiate Vertex AI SDK - Custom App
40+
let _ = VertexAI.vertexAI(app: app!)
41+
3642
// Permutations without optional arguments.
3743

38-
// TODO: Change `genAI` to `_` when safetySettings and generationConfig are added to public API.
39-
let genAI = VertexAI.generativeModel(modelName: "gemini-1.0-pro", location: "us-central1")
40-
let _ = VertexAI.generativeModel(
41-
app: app!,
44+
let _ = vertexAI.generativeModel(modelName: "gemini-1.0-pro", location: "us-central1")
45+
46+
let _ = vertexAI.generativeModel(
4247
modelName: "gemini-1.0-pro",
43-
location: "us-central1"
48+
location: "us-central1",
49+
safetySettings: filters
4450
)
4551

46-
// TODO: Add safetySettings to public API.
47-
// TODO: Add permutation with `app` specified.
48-
// let _ = VertexAI.generativeModel(
49-
// modelName: "gemini-1.0-pro",
50-
// location: "us-central1",
51-
// safetySettings: filters
52-
// )
53-
// TODO: Add generationConfig to public API.
54-
// TODO: Add permutation with `app` specified.
55-
// let _ = VertexAI.generativeModel(
56-
// modelName: "gemini-1.0-pro",
57-
// location: "us-central1",
58-
// generationConfig: config
59-
// )
52+
let _ = vertexAI.generativeModel(
53+
modelName: "gemini-1.0-pro",
54+
location: "us-central1",
55+
generationConfig: config
56+
)
6057

6158
// All arguments passed.
62-
// TODO: Add safetySettings and generationConfig to public API.
63-
// TODO: Add permutation with `app` specified.
64-
// let genAI = VertexAI.generativeModel(
65-
// modelName: "gemini-1.0-pro",
66-
// location: "us-central1",
67-
// generationConfig: config, // Optional
68-
// safetySettings: filters // Optional
69-
// )
59+
let genAI = vertexAI.generativeModel(
60+
modelName: "gemini-1.0-pro",
61+
location: "us-central1",
62+
generationConfig: config, // Optional
63+
safetySettings: filters // Optional
64+
)
7065

7166
// Full Typed Usage
7267
let pngData = Data() // ....

0 commit comments

Comments
 (0)