Skip to content

Commit 4f44ef2

Browse files
committed
Add App Check tests in GenerativeModelTests (#12590)
1 parent 23a1feb commit 4f44ef2

File tree

1 file changed

+129
-2
lines changed

1 file changed

+129
-2
lines changed

FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift

Lines changed: 129 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
import FirebaseAppCheckInterop
1516
import XCTest
1617

1718
@testable import FirebaseVertexAI
@@ -178,6 +179,43 @@ final class GenerativeModelTests: XCTestCase {
178179
_ = try await model.generateContent(testPrompt)
179180
}
180181

182+
func testGenerateContent_appCheck_validToken() async throws {
183+
let appCheckToken = "test-valid-token"
184+
model = GenerativeModel(
185+
name: "my-model",
186+
apiKey: "API_KEY",
187+
requestOptions: RequestOptions(),
188+
appCheck: AppCheckInteropFake(token: appCheckToken),
189+
urlSession: urlSession
190+
)
191+
MockURLProtocol
192+
.requestHandler = try httpRequestHandler(
193+
forResource: "unary-success-basic-reply-short",
194+
withExtension: "json",
195+
appCheckToken: appCheckToken
196+
)
197+
198+
_ = try await model.generateContent(testPrompt)
199+
}
200+
201+
func testGenerateContent_appCheck_tokenRefreshError() async throws {
202+
model = GenerativeModel(
203+
name: "my-model",
204+
apiKey: "API_KEY",
205+
requestOptions: RequestOptions(),
206+
appCheck: AppCheckInteropFake(error: AppCheckErrorFake()),
207+
urlSession: urlSession
208+
)
209+
MockURLProtocol
210+
.requestHandler = try httpRequestHandler(
211+
forResource: "unary-success-basic-reply-short",
212+
withExtension: "json",
213+
appCheckToken: AppCheckInteropFake.placeholderTokenValue
214+
)
215+
216+
_ = try await model.generateContent(testPrompt)
217+
}
218+
181219
func testGenerateContent_failure_invalidAPIKey() async throws {
182220
let expectedStatusCode = 400
183221
MockURLProtocol
@@ -654,6 +692,45 @@ final class GenerativeModelTests: XCTestCase {
654692
.contains(where: { $0.startIndex == 899 && $0.endIndex == 1026 && !$0.uri.isEmpty }))
655693
}
656694

695+
func testGenerateContentStream_appCheck_validToken() async throws {
696+
let appCheckToken = "test-valid-token"
697+
model = GenerativeModel(
698+
name: "my-model",
699+
apiKey: "API_KEY",
700+
requestOptions: RequestOptions(),
701+
appCheck: AppCheckInteropFake(token: appCheckToken),
702+
urlSession: urlSession
703+
)
704+
MockURLProtocol
705+
.requestHandler = try httpRequestHandler(
706+
forResource: "streaming-success-basic-reply-short",
707+
withExtension: "txt",
708+
appCheckToken: appCheckToken
709+
)
710+
711+
let stream = model.generateContentStream(testPrompt)
712+
for try await _ in stream {}
713+
}
714+
715+
func testGenerateContentStream_appCheck_tokenRefreshError() async throws {
716+
model = GenerativeModel(
717+
name: "my-model",
718+
apiKey: "API_KEY",
719+
requestOptions: RequestOptions(),
720+
appCheck: AppCheckInteropFake(error: AppCheckErrorFake()),
721+
urlSession: urlSession
722+
)
723+
MockURLProtocol
724+
.requestHandler = try httpRequestHandler(
725+
forResource: "streaming-success-basic-reply-short",
726+
withExtension: "txt",
727+
appCheckToken: AppCheckInteropFake.placeholderTokenValue
728+
)
729+
730+
let stream = model.generateContentStream(testPrompt)
731+
for try await _ in stream {}
732+
}
733+
657734
func testGenerateContentStream_errorMidStream() async throws {
658735
MockURLProtocol.requestHandler = try httpRequestHandler(
659736
forResource: "streaming-failure-error-mid-stream",
@@ -887,8 +964,8 @@ final class GenerativeModelTests: XCTestCase {
887964
private func httpRequestHandler(forResource name: String,
888965
withExtension ext: String,
889966
statusCode: Int = 200,
890-
timeout: TimeInterval = URLRequest
891-
.defaultTimeoutInterval()) throws -> ((URLRequest) throws -> (
967+
timeout: TimeInterval = URLRequest.defaultTimeoutInterval(),
968+
appCheckToken: String? = nil) throws -> ((URLRequest) throws -> (
892969
URLResponse,
893970
AsyncLineSequence<URL.AsyncBytes>?
894971
)) {
@@ -897,6 +974,7 @@ final class GenerativeModelTests: XCTestCase {
897974
let requestURL = try XCTUnwrap(request.url)
898975
XCTAssertEqual(requestURL.path.occurrenceCount(of: "models/"), 1)
899976
XCTAssertEqual(request.timeoutInterval, timeout)
977+
XCTAssertEqual(request.value(forHTTPHeaderField: "X-Firebase-AppCheck"), appCheckToken)
900978
let response = try XCTUnwrap(HTTPURLResponse(
901979
url: requestURL,
902980
statusCode: statusCode,
@@ -922,3 +1000,52 @@ private extension URLRequest {
9221000
return URLRequest(url: placeholderURL).timeoutInterval
9231001
}
9241002
}
1003+
1004+
class AppCheckInteropFake: NSObject, AppCheckInterop {
1005+
/// The placeholder token value returned when an error occurs
1006+
static let placeholderTokenValue = "placeholder-token"
1007+
1008+
var token: String
1009+
var error: Error?
1010+
1011+
private init(token: String, error: Error?) {
1012+
self.token = token
1013+
self.error = error
1014+
}
1015+
1016+
convenience init(token: String) {
1017+
self.init(token: token, error: nil)
1018+
}
1019+
1020+
convenience init(error: Error) {
1021+
self.init(token: AppCheckInteropFake.placeholderTokenValue, error: error)
1022+
}
1023+
1024+
func getToken(forcingRefresh: Bool) async -> any FIRAppCheckTokenResultInterop {
1025+
return AppCheckTokenResultInteropFake(token: token, error: error)
1026+
}
1027+
1028+
func tokenDidChangeNotificationName() -> String {
1029+
fatalError("\(#function) not implemented.")
1030+
}
1031+
1032+
func notificationTokenKey() -> String {
1033+
fatalError("\(#function) not implemented.")
1034+
}
1035+
1036+
func notificationAppNameKey() -> String {
1037+
fatalError("\(#function) not implemented.")
1038+
}
1039+
1040+
private class AppCheckTokenResultInteropFake: NSObject, FIRAppCheckTokenResultInterop {
1041+
var token: String
1042+
var error: Error?
1043+
1044+
init(token: String, error: Error?) {
1045+
self.token = token
1046+
self.error = error
1047+
}
1048+
}
1049+
}
1050+
1051+
struct AppCheckErrorFake: Error {}

0 commit comments

Comments
 (0)