diff --git a/firebase-vertexai/CHANGELOG.md b/firebase-vertexai/CHANGELOG.md index 863b520831f..62db530d71f 100644 --- a/firebase-vertexai/CHANGELOG.md +++ b/firebase-vertexai/CHANGELOG.md @@ -1,6 +1,8 @@ # Unreleased +* [changed] Added new exception type for quota exceeded scenarios. * [feature] `CountTokenRequest` now includes `GenerationConfig` from the model. + # 16.2.0 * [fixed] Added support for new values sent by the server for `FinishReason` and `BlockReason`. * [changed] Added support for modality-based token count. (#6658) diff --git a/firebase-vertexai/api.txt b/firebase-vertexai/api.txt index ecf5ab8eefc..76491378d88 100644 --- a/firebase-vertexai/api.txt +++ b/firebase-vertexai/api.txt @@ -557,6 +557,9 @@ package com.google.firebase.vertexai.type { @kotlin.RequiresOptIn(level=kotlin.RequiresOptIn.Level.ERROR, message="This API is part of an experimental public preview and may change in " + "backwards-incompatible ways without notice.") @kotlin.annotation.Retention(kotlin.annotation.AnnotationRetention.BINARY) public @interface PublicPreviewAPI { } + public final class QuotaExceededException extends com.google.firebase.vertexai.type.FirebaseVertexAIException { + } + public final class RequestOptions { ctor public RequestOptions(); ctor public RequestOptions(long timeoutInMillis = 180.seconds.inWholeMilliseconds); diff --git a/firebase-vertexai/gradle.properties b/firebase-vertexai/gradle.properties index 546c015493e..c0a96853e52 100644 --- a/firebase-vertexai/gradle.properties +++ b/firebase-vertexai/gradle.properties @@ -12,5 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -version=16.2.1 +version=16.3.0 latestReleasedVersion=16.2.0 diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Exceptions.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Exceptions.kt index 4890cd7ada3..4a29e5c37ea 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Exceptions.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Exceptions.kt @@ -59,6 +59,8 @@ internal constructor(message: String, cause: Throwable? = null) : RuntimeExcepti UnknownException(cause.message ?: "", cause.cause) is com.google.firebase.vertexai.common.ContentBlockedException -> ContentBlockedException(cause.message ?: "", cause.cause) + is com.google.firebase.vertexai.common.QuotaExceededException -> + QuotaExceededException(cause.message ?: "", cause.cause) else -> UnknownException(cause.message ?: "", cause) } is TimeoutCancellationException -> @@ -165,6 +167,14 @@ public class ServiceDisabledException internal constructor(message: String, cause: Throwable? = null) : FirebaseVertexAIException(message, cause) +/** + * The request has hit a quota limit. Learn more about quotas in the + * [Firebase documentation.](https://firebase.google.com/docs/vertex-ai/quotas) + */ +public class QuotaExceededException +internal constructor(message: String, cause: Throwable? = null) : + FirebaseVertexAIException(message, cause) + /** Catch all case for exceptions not explicitly expected. */ public class UnknownException internal constructor(message: String, cause: Throwable? = null) : FirebaseVertexAIException(message, cause) diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/UnarySnapshotTests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/UnarySnapshotTests.kt index 1724b3788cb..a7ed1c4ed7f 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/UnarySnapshotTests.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/UnarySnapshotTests.kt @@ -27,6 +27,7 @@ import com.google.firebase.vertexai.type.HarmSeverity import com.google.firebase.vertexai.type.InvalidAPIKeyException import com.google.firebase.vertexai.type.PromptBlockedException import com.google.firebase.vertexai.type.PublicPreviewAPI +import com.google.firebase.vertexai.type.QuotaExceededException import com.google.firebase.vertexai.type.ResponseStoppedException import com.google.firebase.vertexai.type.SerializationException import com.google.firebase.vertexai.type.ServerException @@ -72,6 +73,19 @@ internal class UnarySnapshotTests { } } + @Test + fun `long reply`() = + goldenUnaryFile("unary-success-basic-reply-long.json") { + withTimeout(testTimeout) { + val response = model.generateContent("prompt") + + response.candidates.isEmpty() shouldBe false + response.candidates.first().finishReason shouldBe FinishReason.STOP + response.candidates.first().content.parts.isEmpty() shouldBe false + response.candidates.first().safetyRatings.isEmpty() shouldBe false + } + } + @Test fun `response with detailed token-based usageMetadata`() = goldenUnaryFile("unary-success-basic-response-long-usage-metadata.json") { @@ -177,6 +191,20 @@ internal class UnarySnapshotTests { } } + @Test + fun `function call has no arguments field`() = + goldenUnaryFile("unary-success-function-call-empty-arguments.json") { + withTimeout(testTimeout) { + val response = model.generateContent("prompt") + val content = response.candidates.shouldNotBeNullOrEmpty().first().content + content.shouldNotBeNull() + val callPart = content.parts.shouldNotBeNullOrEmpty().first() as FunctionCallPart + + callPart.name shouldBe "current_time" + callPart.args shouldBe emptyMap() + } + } + @Test fun `prompt blocked for safety`() = goldenUnaryFile("unary-failure-prompt-blocked-safety.json") { @@ -239,6 +267,14 @@ internal class UnarySnapshotTests { } } + @Test + fun `quota exceeded`() = + goldenUnaryFile("unary-failure-quota-exceeded.json", HttpStatusCode.BadRequest) { + withTimeout(testTimeout) { + shouldThrow { model.generateContent("prompt") } + } + } + @Test fun `stopped for safety with no content`() = goldenUnaryFile("unary-failure-finish-reason-safety-no-content.json") { diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/StreamingSnapshotTests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/StreamingSnapshotTests.kt deleted file mode 100644 index 8b421edfa50..00000000000 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/StreamingSnapshotTests.kt +++ /dev/null @@ -1,190 +0,0 @@ -/* - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.firebase.vertexai.common - -import com.google.firebase.vertexai.common.util.goldenStreamingFile -import com.google.firebase.vertexai.type.BlockReason -import com.google.firebase.vertexai.type.FinishReason -import com.google.firebase.vertexai.type.HarmCategory -import com.google.firebase.vertexai.type.TextPart -import io.kotest.assertions.throwables.shouldThrow -import io.kotest.matchers.nulls.shouldNotBeNull -import io.kotest.matchers.shouldBe -import io.kotest.matchers.string.shouldContain -import io.ktor.http.HttpStatusCode -import kotlin.time.Duration.Companion.seconds -import kotlinx.coroutines.flow.collect -import kotlinx.coroutines.flow.toList -import kotlinx.coroutines.withTimeout -import kotlinx.serialization.ExperimentalSerializationApi -import org.junit.Test - -@OptIn(ExperimentalSerializationApi::class) -internal class StreamingSnapshotTests { - private val testTimeout = 5.seconds - - @Test - fun `short reply`() = - goldenStreamingFile("success-basic-reply-short.txt") { - val responses = apiController.generateContentStream(textGenerateContentRequest("prompt")) - - withTimeout(testTimeout) { - val responseList = responses.toList() - responseList.isEmpty() shouldBe false - responseList.first().candidates?.first()?.finishReason shouldBe FinishReason.Internal.STOP - responseList.first().candidates?.first()?.content?.parts?.isEmpty() shouldBe false - responseList.first().candidates?.first()?.safetyRatings?.isEmpty() shouldBe false - } - } - - @Test - fun `long reply`() = - goldenStreamingFile("success-basic-reply-long.txt") { - val responses = apiController.generateContentStream(textGenerateContentRequest("prompt")) - - withTimeout(testTimeout) { - val responseList = responses.toList() - responseList.isEmpty() shouldBe false - responseList.forEach { - it.candidates?.first()?.finishReason shouldBe FinishReason.Internal.STOP - it.candidates?.first()?.content?.parts?.isEmpty() shouldBe false - it.candidates?.first()?.safetyRatings?.isEmpty() shouldBe false - } - } - } - - @Test - fun `unknown enum`() = - goldenStreamingFile("success-unknown-safety-enum.txt") { - val responses = apiController.generateContentStream(textGenerateContentRequest("prompt")) - - withTimeout(testTimeout) { - val responseList = responses.toList() - responseList.isEmpty() shouldBe false - responseList.any { - it.candidates?.any { - it.safetyRatings?.any { it.category == HarmCategory.Internal.UNKNOWN } ?: false - } - ?: false - } shouldBe true - } - } - - @Test - fun `quotes escaped`() = - goldenStreamingFile("success-quotes-escaped.txt") { - val responses = apiController.generateContentStream(textGenerateContentRequest("prompt")) - - withTimeout(testTimeout) { - val responseList = responses.toList() - - responseList.isEmpty() shouldBe false - val part = - responseList.first().candidates?.first()?.content?.parts?.first() as? TextPart.Internal - part.shouldNotBeNull() - part.text shouldContain "\"" - } - } - - @Test - fun `prompt blocked for safety`() = - goldenStreamingFile("failure-prompt-blocked-safety.txt") { - val responses = apiController.generateContentStream(textGenerateContentRequest("prompt")) - - withTimeout(testTimeout) { - val exception = shouldThrow { responses.collect() } - exception.response?.promptFeedback?.blockReason shouldBe BlockReason.Internal.SAFETY - } - } - - @Test - fun `empty content`() = - goldenStreamingFile("failure-empty-content.txt") { - val responses = apiController.generateContentStream(textGenerateContentRequest("prompt")) - - withTimeout(testTimeout) { shouldThrow { responses.collect() } } - } - - @Test - fun `http errors`() = - goldenStreamingFile("failure-http-error.txt", HttpStatusCode.PreconditionFailed) { - val responses = apiController.generateContentStream(textGenerateContentRequest("prompt")) - - withTimeout(testTimeout) { shouldThrow { responses.collect() } } - } - - @Test - fun `stopped for safety`() = - goldenStreamingFile("failure-finish-reason-safety.txt") { - val responses = apiController.generateContentStream(textGenerateContentRequest("prompt")) - - withTimeout(testTimeout) { - val exception = shouldThrow { responses.collect() } - exception.response.candidates?.first()?.finishReason shouldBe FinishReason.Internal.SAFETY - } - } - - @Test - fun `citation parsed correctly`() = - goldenStreamingFile("success-citations.txt") { - val responses = apiController.generateContentStream(textGenerateContentRequest("prompt")) - - withTimeout(testTimeout) { - val responseList = responses.toList() - responseList.any { - it.candidates?.any { it.citationMetadata?.citationSources?.isNotEmpty() ?: false } - ?: false - } shouldBe true - } - } - - @Test - fun `stopped for recitation`() = - goldenStreamingFile("failure-recitation-no-content.txt") { - val responses = apiController.generateContentStream(textGenerateContentRequest("prompt")) - - withTimeout(testTimeout) { - val exception = shouldThrow { responses.collect() } - exception.response.candidates?.first()?.finishReason shouldBe - FinishReason.Internal.RECITATION - } - } - - @Test - fun `image rejected`() = - goldenStreamingFile("failure-image-rejected.txt", HttpStatusCode.BadRequest) { - val responses = apiController.generateContentStream(textGenerateContentRequest("prompt")) - - withTimeout(testTimeout) { shouldThrow { responses.collect() } } - } - - @Test - fun `unknown model`() = - goldenStreamingFile("failure-unknown-model.txt", HttpStatusCode.NotFound) { - val responses = apiController.generateContentStream(textGenerateContentRequest("prompt")) - - withTimeout(testTimeout) { shouldThrow { responses.collect() } } - } - - @Test - fun `invalid api key`() = - goldenStreamingFile("failure-api-key.txt", HttpStatusCode.BadRequest) { - val responses = apiController.generateContentStream(textGenerateContentRequest("prompt")) - - withTimeout(testTimeout) { shouldThrow { responses.collect() } } - } -} diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/UnarySnapshotTests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/UnarySnapshotTests.kt deleted file mode 100644 index 49a24201c3f..00000000000 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/UnarySnapshotTests.kt +++ /dev/null @@ -1,353 +0,0 @@ -/* - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.firebase.vertexai.common - -import com.google.firebase.vertexai.common.util.goldenUnaryFile -import com.google.firebase.vertexai.common.util.shouldNotBeNullOrEmpty -import com.google.firebase.vertexai.type.BlockReason -import com.google.firebase.vertexai.type.FinishReason -import com.google.firebase.vertexai.type.FunctionCallPart -import com.google.firebase.vertexai.type.HarmCategory -import com.google.firebase.vertexai.type.HarmProbability -import com.google.firebase.vertexai.type.HarmSeverity -import com.google.firebase.vertexai.type.TextPart -import io.kotest.assertions.throwables.shouldThrow -import io.kotest.matchers.collections.shouldNotBeEmpty -import io.kotest.matchers.nulls.shouldNotBeNull -import io.kotest.matchers.should -import io.kotest.matchers.shouldBe -import io.kotest.matchers.shouldNotBe -import io.kotest.matchers.types.shouldBeInstanceOf -import io.ktor.http.HttpStatusCode -import kotlin.time.Duration.Companion.seconds -import kotlinx.coroutines.withTimeout -import kotlinx.serialization.ExperimentalSerializationApi -import kotlinx.serialization.Serializable -import kotlinx.serialization.json.JsonPrimitive -import org.junit.Test - -@Serializable internal data class MountainColors(val name: String, val colors: List) - -internal class UnarySnapshotTests { - private val testTimeout = 5.seconds - - @OptIn(ExperimentalSerializationApi::class) - @Test - fun `short reply`() = - goldenUnaryFile("success-basic-reply-short.json") { - withTimeout(testTimeout) { - val response = apiController.generateContent(textGenerateContentRequest("prompt")) - - response.candidates?.isEmpty() shouldBe false - response.candidates?.first()?.finishReason shouldBe FinishReason.Internal.STOP - response.candidates?.first()?.content?.parts?.isEmpty() shouldBe false - response.candidates?.first()?.safetyRatings?.isEmpty() shouldBe false - } - } - - @OptIn(ExperimentalSerializationApi::class) - @Test - fun `long reply`() = - goldenUnaryFile("success-basic-reply-long.json") { - withTimeout(testTimeout) { - val response = apiController.generateContent(textGenerateContentRequest("prompt")) - - response.candidates?.isEmpty() shouldBe false - response.candidates?.first()?.finishReason shouldBe FinishReason.Internal.STOP - response.candidates?.first()?.content?.parts?.isEmpty() shouldBe false - response.candidates?.first()?.safetyRatings?.isEmpty() shouldBe false - } - } - - @Test - fun `unknown enum`() = - goldenUnaryFile("success-unknown-enum-safety-ratings.json") { - withTimeout(testTimeout) { - val response = apiController.generateContent(textGenerateContentRequest("prompt")) - - response.candidates?.isNullOrEmpty() shouldBe false - val candidate = response.candidates?.first() - candidate?.safetyRatings?.any { it.category == HarmCategory.Internal.UNKNOWN } shouldBe true - } - } - - @Test - fun `safetyRatings including severity`() = - goldenUnaryFile("success-including-severity.json") { - withTimeout(testTimeout) { - val response = apiController.generateContent(textGenerateContentRequest("prompt")) - - response.candidates?.isEmpty() shouldBe false - response.candidates?.first()?.safetyRatings?.isEmpty() shouldBe false - response.candidates?.first()?.safetyRatings?.all { - it.probability == HarmProbability.Internal.NEGLIGIBLE - } shouldBe true - response.candidates?.first()?.safetyRatings?.all { it.probabilityScore != null } shouldBe - true - response.candidates?.first()?.safetyRatings?.all { - it.severity == HarmSeverity.Internal.NEGLIGIBLE - } shouldBe true - response.candidates?.first()?.safetyRatings?.all { it.severityScore != null } shouldBe true - } - } - - @Test - fun `prompt blocked for safety`() = - goldenUnaryFile("failure-prompt-blocked-safety.json") { - withTimeout(testTimeout) { - shouldThrow { - apiController.generateContent(textGenerateContentRequest("prompt")) - } should { it.response?.promptFeedback?.blockReason shouldBe BlockReason.Internal.SAFETY } - } - } - - @Test - fun `empty content`() = - goldenUnaryFile("failure-empty-content.json") { - withTimeout(testTimeout) { - shouldThrow { - apiController.generateContent(textGenerateContentRequest("prompt")) - } - } - } - - @Test - fun `http error`() = - goldenUnaryFile("failure-http-error.json", HttpStatusCode.PreconditionFailed) { - withTimeout(testTimeout) { - shouldThrow { - apiController.generateContent(textGenerateContentRequest("prompt")) - } - } - } - - @Test - fun `user location error`() = - goldenUnaryFile("failure-unsupported-user-location.json", HttpStatusCode.PreconditionFailed) { - withTimeout(testTimeout) { - shouldThrow { - apiController.generateContent(textGenerateContentRequest("prompt")) - } - } - } - - @Test - fun `stopped for safety`() = - goldenUnaryFile("failure-finish-reason-safety.json") { - withTimeout(testTimeout) { - val exception = - shouldThrow { - apiController.generateContent(textGenerateContentRequest("prompt")) - } - exception.response.candidates?.first()?.finishReason shouldBe FinishReason.Internal.SAFETY - } - } - - @Test - fun `citation returns correctly`() = - goldenUnaryFile("success-citations.json") { - withTimeout(testTimeout) { - val response = apiController.generateContent(textGenerateContentRequest("prompt")) - - response.candidates?.isEmpty() shouldBe false - response.candidates?.first()?.citationMetadata?.citationSources?.isNotEmpty() shouldBe true - } - } - - @Test - fun `citation returns correctly with missing license and startIndex`() = - goldenUnaryFile("success-citations-nolicense.json") { - withTimeout(testTimeout) { - val response = apiController.generateContent(textGenerateContentRequest("prompt")) - - response.candidates?.isEmpty() shouldBe false - response.candidates?.first()?.citationMetadata?.citationSources?.isNotEmpty() shouldBe true - // Verify the values in the citation source - with(response.candidates?.first()?.citationMetadata?.citationSources?.first()!!) { - license shouldBe null - startIndex shouldBe 0 - } - } - } - - @Test - fun `response includes usage metadata`() = - goldenUnaryFile("success-usage-metadata.json") { - withTimeout(testTimeout) { - val response = apiController.generateContent(textGenerateContentRequest("prompt")) - - response.candidates?.isEmpty() shouldBe false - response.candidates?.first()?.finishReason shouldBe FinishReason.Internal.STOP - response.usageMetadata shouldNotBe null - response.usageMetadata?.totalTokenCount shouldBe 363 - } - } - - @Test - fun `response includes partial usage metadata`() = - goldenUnaryFile("success-partial-usage-metadata.json") { - withTimeout(testTimeout) { - val response = apiController.generateContent(textGenerateContentRequest("prompt")) - - response.candidates?.isEmpty() shouldBe false - response.candidates?.first()?.finishReason shouldBe FinishReason.Internal.STOP - response.usageMetadata shouldNotBe null - response.usageMetadata?.promptTokenCount shouldBe 6 - response.usageMetadata?.totalTokenCount shouldBe null - } - } - - @OptIn(ExperimentalSerializationApi::class) - @Test - fun `properly translates json text`() = - goldenUnaryFile("success-constraint-decoding-json.json") { - withTimeout(testTimeout) { - val response = apiController.generateContent(textGenerateContentRequest("prompt")) - - response.candidates?.isEmpty() shouldBe false - with( - response.candidates - ?.first() - ?.content - ?.parts - ?.first() - ?.shouldBeInstanceOf() - ) { - shouldNotBeNull() - JSON.decodeFromString>(text).shouldNotBeEmpty() - } - } - } - - @Test - fun `invalid response`() = - goldenUnaryFile("failure-invalid-response.json") { - withTimeout(testTimeout) { - shouldThrow { - apiController.generateContent(textGenerateContentRequest("prompt")) - } - } - } - - @Test - fun `malformed content`() = - goldenUnaryFile("failure-malformed-content.json") { - withTimeout(testTimeout) { - shouldThrow { - apiController.generateContent(textGenerateContentRequest("prompt")) - } - } - } - - @Test - fun `invalid api key`() = - goldenUnaryFile("failure-api-key.json", HttpStatusCode.BadRequest) { - withTimeout(testTimeout) { - shouldThrow { - apiController.generateContent(textGenerateContentRequest("prompt")) - } - } - } - - @Test - fun `quota exceeded`() = - goldenUnaryFile("failure-quota-exceeded.json", HttpStatusCode.BadRequest) { - withTimeout(testTimeout) { - shouldThrow { - apiController.generateContent(textGenerateContentRequest("prompt")) - } - } - } - - @Test - fun `image rejected`() = - goldenUnaryFile("failure-image-rejected.json", HttpStatusCode.BadRequest) { - withTimeout(testTimeout) { - shouldThrow { - apiController.generateContent(textGenerateContentRequest("prompt")) - } - } - } - - @Test - fun `unknown model`() = - goldenUnaryFile("failure-unknown-model.json", HttpStatusCode.NotFound) { - withTimeout(testTimeout) { - shouldThrow { - apiController.generateContent(textGenerateContentRequest("prompt")) - } - } - } - - @Test - fun `service disabled`() = - goldenUnaryFile("failure-firebaseml-api-not-enabled.json", HttpStatusCode.Forbidden) { - withTimeout(testTimeout) { - shouldThrow { - apiController.generateContent(textGenerateContentRequest("prompt")) - } - } - } - - @OptIn(ExperimentalSerializationApi::class) - @Test - fun `function call contains null param`() = - goldenUnaryFile("success-function-call-null.json") { - withTimeout(testTimeout) { - val response = apiController.generateContent(textGenerateContentRequest("prompt")) - val callPart = - (response.candidates!!.first().content!!.parts.first() as FunctionCallPart.Internal) - - callPart.functionCall.args shouldNotBe null - callPart.functionCall.args?.get("season") shouldBe null - } - } - - @OptIn(ExperimentalSerializationApi::class) - @Test - fun `function call contains json literal`() = - goldenUnaryFile("success-function-call-json-literal.json") { - withTimeout(testTimeout) { - val response = apiController.generateContent(textGenerateContentRequest("prompt")) - val content = response.candidates.shouldNotBeNullOrEmpty().first().content - val callPart = - content.let { - it.shouldNotBeNull() - it.parts.shouldNotBeEmpty() - it.parts.first().shouldBeInstanceOf() - } - - callPart.functionCall.args shouldNotBe null - callPart.functionCall.args?.get("current") shouldBe JsonPrimitive(true) - } - } - - @OptIn(ExperimentalSerializationApi::class) - @Test - fun `function call has no arguments field`() = - goldenUnaryFile("success-function-call-empty-arguments.json") { - withTimeout(testTimeout) { - val response = apiController.generateContent(textGenerateContentRequest("prompt")) - val content = response.candidates.shouldNotBeNullOrEmpty().first().content - content.shouldNotBeNull() - val callPart = content.parts.shouldNotBeNullOrEmpty().first() as FunctionCallPart.Internal - - callPart.functionCall.name shouldBe "current_time" - callPart.functionCall.args shouldBe null - } - } -} diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/util/tests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/util/tests.kt index bf79df56604..855c8aa4a8b 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/util/tests.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/util/tests.kt @@ -19,25 +19,18 @@ package com.google.firebase.vertexai.common.util import com.google.firebase.vertexai.common.APIController -import com.google.firebase.vertexai.common.GenerateContentRequest import com.google.firebase.vertexai.common.JSON import com.google.firebase.vertexai.type.Candidate import com.google.firebase.vertexai.type.Content import com.google.firebase.vertexai.type.GenerateContentResponse import com.google.firebase.vertexai.type.RequestOptions import com.google.firebase.vertexai.type.TextPart -import io.kotest.matchers.collections.shouldNotBeEmpty -import io.kotest.matchers.nulls.shouldNotBeNull import io.ktor.client.engine.mock.MockEngine import io.ktor.client.engine.mock.respond import io.ktor.http.HttpHeaders import io.ktor.http.HttpStatusCode import io.ktor.http.headersOf import io.ktor.utils.io.ByteChannel -import io.ktor.utils.io.close -import io.ktor.utils.io.writeFully -import java.io.File -import kotlinx.coroutines.launch import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.encodeToString @@ -47,18 +40,6 @@ internal fun prepareStreamingResponse( response: List ): List = response.map { "data: ${JSON.encodeToString(it)}$SSE_SEPARATOR".toByteArray() } -internal fun prepareResponse(response: GenerateContentResponse.Internal) = - JSON.encodeToString(response).toByteArray() - -@OptIn(ExperimentalSerializationApi::class) -internal fun createRequest(vararg text: String): GenerateContentRequest { - val contents = text.map { Content.Internal(parts = listOf(TextPart.Internal(it))) } - - return GenerateContentRequest("gemini", contents) -} - -internal fun createResponse(text: String) = createResponses(text).single() - @OptIn(ExperimentalSerializationApi::class) internal fun createResponses(vararg text: String): List { val candidates = @@ -123,82 +104,3 @@ internal fun commonTest( ) CommonTestScope(channel, apiController).block() } - -/** - * A variant of [commonTest] for performing *streaming-based* snapshot tests. - * - * Loads the *Golden File* and automatically parses the messages from it; providing it to the - * channel. - * - * @param name The name of the *Golden File* to load - * @param httpStatusCode An optional [HttpStatusCode] to return as a response - * @param block The test contents themselves, with a [CommonTestScope] implicitly provided - * @see goldenUnaryFile - */ -internal fun goldenStreamingFile( - name: String, - httpStatusCode: HttpStatusCode = HttpStatusCode.OK, - block: CommonTest, -) = doBlocking { - val goldenFile = loadGoldenFile("streaming-$name") - val messages = goldenFile.readLines().filter { it.isNotBlank() } - - commonTest(httpStatusCode) { - launch { - for (message in messages) { - channel.writeFully("$message$SSE_SEPARATOR".toByteArray()) - } - channel.close() - } - - block() - } -} - -/** - * A variant of [commonTest] for performing snapshot tests. - * - * Loads the *Golden File* and automatically provides it to the channel. - * - * @param name The name of the *Golden File* to load - * @param httpStatusCode An optional [HttpStatusCode] to return as a response - * @param block The test contents themselves, with a [CommonTestScope] implicitly provided - * @see goldenStreamingFile - */ -internal fun goldenUnaryFile( - name: String, - httpStatusCode: HttpStatusCode = HttpStatusCode.OK, - block: CommonTest, -) = - commonTest(httpStatusCode) { - val goldenFile = loadGoldenFile("unary-$name") - val message = goldenFile.readText() - - channel.send(message.toByteArray()) - - block() - } - -/** - * Loads a *Golden File* from the resource directory. - * - * Expects golden files to live under `golden-files` in the resource files. - * - * @see goldenUnaryFile - */ -internal fun loadGoldenFile(path: String): File = - loadResourceFile("vertexai-sdk-test-data/mock-responses/$path") - -/** Loads a file from the test resources directory. */ -internal fun loadResourceFile(path: String) = File("src/test/resources/$path") - -/** - * Ensures that a collection is neither null or empty. - * - * Syntax sugar for [shouldNotBeNull] and [shouldNotBeEmpty]. - */ -inline fun Collection?.shouldNotBeNullOrEmpty(): Collection { - shouldNotBeNull() - shouldNotBeEmpty() - return this -}