Skip to content

Commit 23a1feb

Browse files
committed
Use generative-ai-swift tests in Vertex AI (#12585)
1 parent f2760d4 commit 23a1feb

11 files changed

+178
-130
lines changed

.github/workflows/vertexai.yml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ jobs:
3131
- name: Xcode
3232
run: sudo xcode-select -s /Applications/${{ matrix.xcode }}.app/Contents/Developer
3333
- name: Initialize xcodebuild
34-
run: xcodebuild -list
35-
# TODO: Add unit tests and switch from `spmbuildonly` to `spm`.
36-
- name: Build
37-
run: scripts/third_party/travis/retry.sh scripts/build.sh FirebaseVertexAI ${{ matrix.target }} spmbuildonly
34+
run: scripts/setup_spm_tests.sh
35+
- name: Build and run tests
36+
run: scripts/third_party/travis/retry.sh scripts/build.sh FirebaseVertexAIUnit ${{ matrix.target }} spm

FirebaseVertexAI/Tests/Unit/ChatTests.swift

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@
1313
// limitations under the License.
1414

1515
import Foundation
16-
@testable import GoogleGenerativeAI
1716
import XCTest
1817

18+
@testable import FirebaseVertexAI
19+
1920
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, *)
2021
final class ChatTests: XCTestCase {
2122
var urlSession: URLSession!
@@ -46,7 +47,13 @@ final class ChatTests: XCTestCase {
4647
return (response, fileURL.lines)
4748
}
4849

49-
let model = GenerativeModel(name: "my-model", apiKey: "API_KEY", urlSession: urlSession)
50+
let model = GenerativeModel(
51+
name: "my-model",
52+
apiKey: "API_KEY",
53+
requestOptions: RequestOptions(),
54+
appCheck: nil,
55+
urlSession: urlSession
56+
)
5057
let chat = Chat(model: model, history: [])
5158
let input = "Test input"
5259
let stream = chat.sendMessageStream(input)

FirebaseVertexAI/Tests/Unit/GenerateContentResponses/streaming-success-citations2.txt

Lines changed: 0 additions & 5 deletions
This file was deleted.

FirebaseVertexAI/Tests/Unit/GenerateContentResponses/unary-failure-api-key.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
"reason": "API_KEY_INVALID",
1010
"domain": "googleapis.com",
1111
"metadata": {
12-
"service": "generativelanguage.googleapis.com"
12+
"service": "staging-firebaseml.sandbox.googleapis.com"
1313
}
1414
},
1515
{

FirebaseVertexAI/Tests/Unit/GenerateContentResponses/unary-failure-finish-reason-recitation-no-content.json

Lines changed: 0 additions & 46 deletions
This file was deleted.

FirebaseVertexAI/Tests/Unit/GenerateContentResponses/unary-failure-unsupported-user-location.json

Lines changed: 0 additions & 13 deletions
This file was deleted.

FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift

Lines changed: 42 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
@testable import GoogleGenerativeAI
1615
import XCTest
1716

17+
@testable import FirebaseVertexAI
18+
1819
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, *)
1920
final class GenerativeModelTests: XCTestCase {
2021
let testPrompt = "What sorts of questions can I ask you?"
@@ -32,7 +33,13 @@ final class GenerativeModelTests: XCTestCase {
3233
let configuration = URLSessionConfiguration.default
3334
configuration.protocolClasses = [MockURLProtocol.self]
3435
urlSession = try XCTUnwrap(URLSession(configuration: configuration))
35-
model = GenerativeModel(name: "my-model", apiKey: "API_KEY", urlSession: urlSession)
36+
model = GenerativeModel(
37+
name: "my-model",
38+
apiKey: "API_KEY",
39+
requestOptions: RequestOptions(),
40+
appCheck: nil,
41+
urlSession: urlSession
42+
)
3643
}
3744

3845
override func tearDown() {
@@ -163,6 +170,8 @@ final class GenerativeModelTests: XCTestCase {
163170
// Model name is prefixed with "models/".
164171
name: "models/test-model",
165172
apiKey: "API_KEY",
173+
requestOptions: RequestOptions(),
174+
appCheck: nil,
166175
urlSession: urlSession
167176
)
168177

@@ -181,10 +190,13 @@ final class GenerativeModelTests: XCTestCase {
181190
do {
182191
_ = try await model.generateContent(testPrompt)
183192
XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
184-
} catch let GenerateContentError.invalidAPIKey(message) {
185-
XCTAssertEqual(message, "API key not valid. Please pass a valid API key.")
193+
} catch let GenerateContentError.internalError(error as RPCError) {
194+
XCTAssertEqual(error.httpResponseCode, 400)
195+
XCTAssertEqual(error.status, .invalidArgument)
196+
XCTAssertEqual(error.message, "API key not valid. Please pass a valid API key.")
197+
return
186198
} catch {
187-
XCTFail("Should throw GenerateContentError.invalidAPIKey; error thrown: \(error)")
199+
XCTFail("Should throw GenerateContentError.internalError(RPCError); error thrown: \(error)")
188200
}
189201
}
190202

@@ -342,24 +354,6 @@ final class GenerativeModelTests: XCTestCase {
342354
}
343355
}
344356

345-
func testGenerateContent_failure_unsupportedUserLocation() async throws {
346-
MockURLProtocol
347-
.requestHandler = try httpRequestHandler(
348-
forResource: "unary-failure-unsupported-user-location",
349-
withExtension: "json",
350-
statusCode: 400
351-
)
352-
353-
do {
354-
_ = try await model.generateContent(testPrompt)
355-
XCTFail("Should throw GenerateContentError.unsupportedUserLocation; no error thrown.")
356-
} catch GenerateContentError.unsupportedUserLocation {
357-
return
358-
}
359-
360-
XCTFail("Expected an unsupported user location error.")
361-
}
362-
363357
func testGenerateContent_failure_nonHTTPResponse() async throws {
364358
MockURLProtocol.requestHandler = try nonHTTPRequestHandler()
365359

@@ -468,6 +462,7 @@ final class GenerativeModelTests: XCTestCase {
468462
name: "my-model",
469463
apiKey: "API_KEY",
470464
requestOptions: requestOptions,
465+
appCheck: nil,
471466
urlSession: urlSession
472467
)
473468

@@ -490,8 +485,10 @@ final class GenerativeModelTests: XCTestCase {
490485
for try await _ in stream {
491486
XCTFail("No content is there, this shouldn't happen.")
492487
}
493-
} catch GenerateContentError.invalidAPIKey {
494-
// invalidAPIKey error is as expected, nothing else to check.
488+
} catch let GenerateContentError.internalError(error as RPCError) {
489+
XCTAssertEqual(error.httpResponseCode, 400)
490+
XCTAssertEqual(error.status, .invalidArgument)
491+
XCTAssertEqual(error.message, "API key not valid. Please pass a valid API key.")
495492
return
496493
}
497494

@@ -747,26 +744,6 @@ final class GenerativeModelTests: XCTestCase {
747744
XCTFail("Expected an internal decoding error.")
748745
}
749746

750-
func testGenerateContentStream_failure_unsupportedUserLocation() async throws {
751-
MockURLProtocol
752-
.requestHandler = try httpRequestHandler(
753-
forResource: "unary-failure-unsupported-user-location",
754-
withExtension: "json",
755-
statusCode: 400
756-
)
757-
758-
let stream = model.generateContentStream(testPrompt)
759-
do {
760-
for try await content in stream {
761-
XCTFail("Unexpected content in stream: \(content)")
762-
}
763-
} catch GenerateContentError.unsupportedUserLocation {
764-
return
765-
}
766-
767-
XCTFail("Expected an unsupported user location error.")
768-
}
769-
770747
func testGenerateContentStream_requestOptions_customTimeout() async throws {
771748
let expectedTimeout = 150.0
772749
MockURLProtocol
@@ -780,6 +757,7 @@ final class GenerativeModelTests: XCTestCase {
780757
name: "my-model",
781758
apiKey: "API_KEY",
782759
requestOptions: requestOptions,
760+
appCheck: nil,
783761
urlSession: urlSession
784762
)
785763

@@ -837,6 +815,7 @@ final class GenerativeModelTests: XCTestCase {
837815
name: "my-model",
838816
apiKey: "API_KEY",
839817
requestOptions: requestOptions,
818+
appCheck: nil,
840819
urlSession: urlSession
841820
)
842821

@@ -851,23 +830,38 @@ final class GenerativeModelTests: XCTestCase {
851830
let modelName = "my-model"
852831
let modelResourceName = "models/\(modelName)"
853832

854-
model = GenerativeModel(name: modelName, apiKey: "API_KEY")
833+
model = GenerativeModel(
834+
name: modelName,
835+
apiKey: "API_KEY",
836+
requestOptions: RequestOptions(),
837+
appCheck: nil
838+
)
855839

856840
XCTAssertEqual(model.modelResourceName, modelResourceName)
857841
}
858842

859843
func testModelResourceName_modelsPrefix() async throws {
860844
let modelResourceName = "models/my-model"
861845

862-
model = GenerativeModel(name: modelResourceName, apiKey: "API_KEY")
846+
model = GenerativeModel(
847+
name: modelResourceName,
848+
apiKey: "API_KEY",
849+
requestOptions: RequestOptions(),
850+
appCheck: nil
851+
)
863852

864853
XCTAssertEqual(model.modelResourceName, modelResourceName)
865854
}
866855

867856
func testModelResourceName_tunedModelsPrefix() async throws {
868857
let tunedModelResourceName = "tunedModels/my-model"
869858

870-
model = GenerativeModel(name: tunedModelResourceName, apiKey: "API_KEY")
859+
model = GenerativeModel(
860+
name: tunedModelResourceName,
861+
apiKey: "API_KEY",
862+
requestOptions: RequestOptions(),
863+
appCheck: nil
864+
)
871865

872866
XCTAssertEqual(model.modelResourceName, tunedModelResourceName)
873867
}

FirebaseVertexAI/Tests/Unit/PartsRepresentableTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import CoreGraphics
1616
import CoreImage
17-
import GoogleGenerativeAI
17+
import FirebaseVertexAI
1818
import XCTest
1919
#if canImport(UIKit)
2020
import UIKit

FirebaseVertexAI/Tests/Unit/GoogleAITests.swift renamed to FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
import GoogleGenerativeAI
15+
import FirebaseCore
16+
import FirebaseVertexAI
1617
import XCTest
1718
#if canImport(AppKit)
1819
import AppKit // For NSImage extensions.
@@ -21,8 +22,9 @@ import XCTest
2122
#endif
2223

2324
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
24-
final class GoogleGenerativeAITests: XCTestCase {
25+
final class VertexAIAPITests: XCTestCase {
2526
func codeSamples() async throws {
27+
let app = FirebaseApp.app()
2628
let config = GenerationConfig(temperature: 0.2,
2729
topP: 0.1,
2830
topK: 16,
@@ -32,16 +34,40 @@ final class GoogleGenerativeAITests: XCTestCase {
3234
let filters = [SafetySetting(harmCategory: .dangerousContent, threshold: .blockOnlyHigh)]
3335

3436
// Permutations without optional arguments.
35-
let _ = GenerativeModel(name: "gemini-1.0-pro", apiKey: "API_KEY")
36-
let _ = GenerativeModel(name: "gemini-1.0-pro", apiKey: "API_KEY", safetySettings: filters)
37-
let _ = GenerativeModel(name: "gemini-1.0-pro", apiKey: "API_KEY", generationConfig: config)
3837

39-
// All arguments passed.
40-
let genAI = GenerativeModel(name: "gemini-1.0-pro",
41-
apiKey: "API_KEY",
42-
generationConfig: config, // Optional
43-
safetySettings: filters // Optional
38+
// TODO: Change `genAI` to `_` when safetySettings and generationConfig are added to public API.
39+
let genAI = VertexAI.generativeModel(modelName: "gemini-1.0-pro", location: "us-central1")
40+
let _ = VertexAI.generativeModel(
41+
app: app!,
42+
modelName: "gemini-1.0-pro",
43+
location: "us-central1"
4444
)
45+
46+
// TODO: Add safetySettings to public API.
47+
// TODO: Add permutation with `app` specified.
48+
// let _ = VertexAI.generativeModel(
49+
// modelName: "gemini-1.0-pro",
50+
// location: "us-central1",
51+
// safetySettings: filters
52+
// )
53+
// TODO: Add generationConfig to public API.
54+
// TODO: Add permutation with `app` specified.
55+
// let _ = VertexAI.generativeModel(
56+
// modelName: "gemini-1.0-pro",
57+
// location: "us-central1",
58+
// generationConfig: config
59+
// )
60+
61+
// All arguments passed.
62+
// TODO: Add safetySettings and generationConfig to public API.
63+
// TODO: Add permutation with `app` specified.
64+
// let genAI = VertexAI.generativeModel(
65+
// modelName: "gemini-1.0-pro",
66+
// location: "us-central1",
67+
// generationConfig: config, // Optional
68+
// safetySettings: filters // Optional
69+
// )
70+
4571
// Full Typed Usage
4672
let pngData = Data() // ....
4773
let contents = [ModelContent(role: "user",

Package.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1373,6 +1373,15 @@ let package = Package(
13731373
],
13741374
path: "FirebaseVertexAI/Sources"
13751375
),
1376+
.testTarget(
1377+
name: "FirebaseVertexAIUnit",
1378+
dependencies: ["FirebaseVertexAI"],
1379+
path: "FirebaseVertexAI/Tests/Unit",
1380+
resources: [
1381+
.process("CountTokenResponses"),
1382+
.process("GenerateContentResponses"),
1383+
]
1384+
),
13761385
] + firestoreTargets(),
13771386
cLanguageStandard: .c99,
13781387
cxxLanguageStandard: CXXLanguageStandard.gnucxx14

0 commit comments

Comments
 (0)