Skip to content

Commit fa68048

Browse files
author
David Motsonashvili
committed
add first devapi tests
1 parent 2885e13 commit fa68048

File tree

3 files changed

+285
-5
lines changed

3 files changed

+285
-5
lines changed
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.firebase.vertexai
18+
19+
import com.google.firebase.vertexai.type.BlockReason
20+
import com.google.firebase.vertexai.type.FinishReason
21+
import com.google.firebase.vertexai.type.HarmCategory
22+
import com.google.firebase.vertexai.type.InvalidAPIKeyException
23+
import com.google.firebase.vertexai.type.PromptBlockedException
24+
import com.google.firebase.vertexai.type.ResponseStoppedException
25+
import com.google.firebase.vertexai.type.SerializationException
26+
import com.google.firebase.vertexai.type.ServerException
27+
import com.google.firebase.vertexai.type.TextPart
28+
import com.google.firebase.vertexai.util.goldenDevAPIStreamingFile
29+
import io.kotest.assertions.throwables.shouldThrow
30+
import io.kotest.matchers.nulls.shouldNotBeNull
31+
import io.kotest.matchers.shouldBe
32+
import io.kotest.matchers.string.shouldContain
33+
import io.ktor.http.HttpStatusCode
34+
import kotlin.time.Duration.Companion.seconds
35+
import kotlinx.coroutines.flow.collect
36+
import kotlinx.coroutines.flow.toList
37+
import kotlinx.coroutines.withTimeout
38+
import org.junit.Test
39+
40+
internal class DevAPIStreamingSnapshotTests {
41+
private val testTimeout = 5.seconds
42+
43+
@Test
44+
fun `short reply`() =
45+
goldenDevAPIStreamingFile("streaming-success-basic-reply-short.txt") {
46+
val responses = model.generateContentStream("prompt")
47+
48+
withTimeout(testTimeout) {
49+
val responseList = responses.toList()
50+
responseList.isEmpty() shouldBe false
51+
responseList.first().candidates.first().finishReason shouldBe FinishReason.STOP
52+
responseList.first().candidates.first().content.parts.isEmpty() shouldBe false
53+
responseList.first().candidates.first().safetyRatings.isEmpty() shouldBe false
54+
}
55+
}
56+
57+
@Test
58+
fun `long reply`() =
59+
goldenDevAPIStreamingFile("streaming-success-basic-reply-long.txt") {
60+
val responses = model.generateContentStream("prompt")
61+
62+
withTimeout(testTimeout) {
63+
val responseList = responses.toList()
64+
responseList.isEmpty() shouldBe false
65+
responseList.forEach {
66+
it.candidates.first().finishReason shouldBe FinishReason.STOP
67+
it.candidates.first().content.parts.isEmpty() shouldBe false
68+
it.candidates.first().safetyRatings.isEmpty() shouldBe false
69+
}
70+
}
71+
}
72+
73+
@Test
74+
fun `prompt blocked for safety`() =
75+
goldenDevAPIStreamingFile("streaming-failure-prompt-blocked-safety.txt") {
76+
val responses = model.generateContentStream("prompt")
77+
78+
withTimeout(testTimeout) {
79+
val exception = shouldThrow<PromptBlockedException> { responses.collect() }
80+
exception.response?.promptFeedback?.blockReason shouldBe BlockReason.SAFETY
81+
}
82+
}
83+
84+
@Test
85+
fun `citation parsed correctly`() =
86+
goldenDevAPIStreamingFile("streaming-success-citations.txt") {
87+
val responses = model.generateContentStream("prompt")
88+
89+
withTimeout(testTimeout) {
90+
val responseList = responses.toList()
91+
responseList.any {
92+
it.candidates.any { it.citationMetadata?.citations?.isNotEmpty() ?: false }
93+
} shouldBe true
94+
}
95+
}
96+
97+
@Test
98+
fun `stopped for recitation`() =
99+
goldenDevAPIStreamingFile("streaming-failure-recitation-no-content.txt") {
100+
val responses = model.generateContentStream("prompt")
101+
102+
withTimeout(testTimeout) {
103+
val exception = shouldThrow<ResponseStoppedException> { responses.collect() }
104+
exception.response.candidates.first().finishReason shouldBe FinishReason.RECITATION
105+
}
106+
}
107+
108+
@Test
109+
fun `image rejected`() =
110+
goldenDevAPIStreamingFile("streaming-failure-image-rejected.txt", HttpStatusCode.BadRequest) {
111+
val responses = model.generateContentStream("prompt")
112+
113+
withTimeout(testTimeout) { shouldThrow<ServerException> { responses.collect() } }
114+
}
115+
}
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.firebase.vertexai
18+
19+
import com.google.firebase.vertexai.type.BlockReason
20+
import com.google.firebase.vertexai.type.FinishReason
21+
import com.google.firebase.vertexai.type.InvalidAPIKeyException
22+
import com.google.firebase.vertexai.type.PromptBlockedException
23+
import com.google.firebase.vertexai.type.ResponseStoppedException
24+
import com.google.firebase.vertexai.type.ServerException
25+
import com.google.firebase.vertexai.type.TextPart
26+
import com.google.firebase.vertexai.util.goldenDevAPIUnaryFile
27+
import io.kotest.assertions.throwables.shouldThrow
28+
import io.kotest.inspectors.forAtLeastOne
29+
import io.kotest.matchers.should
30+
import io.kotest.matchers.shouldBe
31+
import io.kotest.matchers.shouldNotBe
32+
import io.kotest.matchers.string.shouldContain
33+
import io.ktor.http.HttpStatusCode
34+
import kotlinx.coroutines.withTimeout
35+
import org.junit.Test
36+
import java.util.Calendar
37+
import kotlin.time.Duration.Companion.seconds
38+
39+
internal class DevAPIUnarySnapshotTests {
40+
private val testTimeout = 5.seconds
41+
42+
@Test
43+
fun `short reply`() =
44+
goldenDevAPIUnaryFile("unary-success-basic-reply-short.txt") {
45+
withTimeout(testTimeout) {
46+
val response = model.generateContent("prompt")
47+
48+
response.candidates.isEmpty() shouldBe false
49+
response.candidates.first().finishReason shouldBe FinishReason.STOP
50+
response.candidates.first().content.parts.isEmpty() shouldBe false
51+
}
52+
}
53+
54+
@Test
55+
fun `long reply`() =
56+
goldenDevAPIUnaryFile("unary-success-basic-reply-long.txt") {
57+
withTimeout(testTimeout) {
58+
val response = model.generateContent("prompt")
59+
60+
response.candidates.isEmpty() shouldBe false
61+
response.candidates.first().finishReason shouldBe FinishReason.STOP
62+
response.candidates.first().content.parts.isEmpty() shouldBe false
63+
}
64+
}
65+
66+
@Test
67+
fun `quotes escaped`() =
68+
goldenDevAPIUnaryFile("unary-success-quote-reply.txt") {
69+
withTimeout(testTimeout) {
70+
val response = model.generateContent("prompt")
71+
72+
response.candidates.isEmpty() shouldBe false
73+
response.candidates.first().content.parts.isEmpty() shouldBe false
74+
val part = response.candidates.first().content.parts.first() as TextPart
75+
part.text shouldContain "\""
76+
}
77+
}
78+
79+
80+
@Test
81+
fun `prompt blocked for safety`() =
82+
goldenDevAPIUnaryFile("unary-failure-prompt-blocked-safety.txt") {
83+
withTimeout(testTimeout) {
84+
shouldThrow<ResponseStoppedException> { model.generateContent("prompt") } should
85+
{
86+
it.response.candidates[0].finishReason shouldBe FinishReason.MAX_TOKENS
87+
}
88+
}
89+
}
90+
91+
@Test
92+
fun `response blocked for safety`() =
93+
goldenDevAPIUnaryFile("unary-failure-finish-reason-safety.txt") {
94+
withTimeout(testTimeout) {
95+
shouldThrow<ResponseStoppedException> { model.generateContent("prompt") } should
96+
{
97+
it.response.candidates[0].finishReason shouldBe FinishReason.MAX_TOKENS
98+
}
99+
}
100+
}
101+
102+
@Test
103+
fun `citation returns correctly`() =
104+
goldenDevAPIUnaryFile("unary-success-citations.txt") {
105+
withTimeout(testTimeout) {
106+
val response = model.generateContent("prompt")
107+
108+
response.candidates.isEmpty() shouldBe false
109+
response.candidates.first().citationMetadata?.citations?.size shouldBe 4
110+
response.candidates.first().citationMetadata?.citations?.forEach {
111+
it.startIndex shouldNotBe null
112+
it.endIndex shouldNotBe null
113+
}
114+
}
115+
}
116+
117+
@Test
118+
fun `invalid api key`() =
119+
goldenDevAPIUnaryFile("unary-failure-api-key.txt", HttpStatusCode.BadRequest) {
120+
withTimeout(testTimeout) {
121+
shouldThrow<InvalidAPIKeyException> { model.generateContent("prompt") }
122+
}
123+
}
124+
@Test
125+
fun `unknown model`() =
126+
goldenDevAPIUnaryFile("unary-failure-unknown-model.txt", HttpStatusCode.NotFound) {
127+
withTimeout(testTimeout) { shouldThrow<ServerException> { model.generateContent("prompt") } }
128+
}
129+
}

firebase-vertexai/src/test/java/com/google/firebase/vertexai/util/tests.kt

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ package com.google.firebase.vertexai.util
2121
import com.google.firebase.vertexai.GenerativeModel
2222
import com.google.firebase.vertexai.ImagenModel
2323
import com.google.firebase.vertexai.common.APIController
24+
import com.google.firebase.vertexai.type.GenerativeBackend
2425
import com.google.firebase.vertexai.type.PublicPreviewAPI
2526
import com.google.firebase.vertexai.type.RequestOptions
2627
import io.kotest.matchers.collections.shouldNotBeEmpty
@@ -97,6 +98,7 @@ internal typealias CommonTest = suspend CommonTestScope.() -> Unit
9798
internal fun commonTest(
9899
status: HttpStatusCode = HttpStatusCode.OK,
99100
requestOptions: RequestOptions = RequestOptions(),
101+
backend: GenerativeBackend = GenerativeBackend.VERTEX_AI,
100102
block: CommonTest,
101103
) = doBlocking {
102104
val channel = ByteChannel(autoFlush = true)
@@ -111,7 +113,7 @@ internal fun commonTest(
111113
TEST_CLIENT_ID,
112114
null,
113115
)
114-
val model = GenerativeModel("cool-model-name", controller = apiController)
116+
val model = GenerativeModel("cool-model-name", generativeBackend = backend, controller = apiController)
115117
val imagenModel = ImagenModel("cooler-model-name", controller = apiController)
116118
CommonTestScope(channel, model, imagenModel).block()
117119
}
@@ -130,12 +132,13 @@ internal fun commonTest(
130132
internal fun goldenStreamingFile(
131133
name: String,
132134
httpStatusCode: HttpStatusCode = HttpStatusCode.OK,
135+
backend: GenerativeBackend = GenerativeBackend.VERTEX_AI,
133136
block: CommonTest,
134137
) = doBlocking {
135138
val goldenFile = loadGoldenFile(name)
136139
val messages = goldenFile.readLines().filter { it.isNotBlank() }
137140

138-
commonTest(httpStatusCode) {
141+
commonTest(httpStatusCode, backend = backend) {
139142
launch {
140143
for (message in messages) {
141144
channel.writeFully("$message$SSE_SEPARATOR".toByteArray())
@@ -162,7 +165,24 @@ internal fun goldenVertexStreamingFile(
162165
name: String,
163166
httpStatusCode: HttpStatusCode = HttpStatusCode.OK,
164167
block: CommonTest,
165-
) = goldenStreamingFile("vertexai/$name", httpStatusCode, block)
168+
) = goldenStreamingFile("vertexai/$name", httpStatusCode, block = block)
169+
170+
/**
171+
* A variant of [goldenStreamingFile] for testing the developer api
172+
*
173+
* Loads the *Golden File* and automatically parses the messages from it; providing it to the
174+
* channel.
175+
*
176+
* @param name The name of the *Golden File* to load
177+
* @param httpStatusCode An optional [HttpStatusCode] to return as a response
178+
* @param block The test contents themselves, with a [CommonTestScope] implicitly provided
179+
* @see goldenStreamingFile
180+
*/
181+
internal fun goldenDevAPIStreamingFile(
182+
name: String,
183+
httpStatusCode: HttpStatusCode = HttpStatusCode.OK,
184+
block: CommonTest,
185+
) = goldenStreamingFile("vertexai/$name", httpStatusCode, GenerativeBackend.DEVELOPER_API, block)
166186

167187
/**
168188
* A variant of [commonTest] for performing snapshot tests.
@@ -177,9 +197,10 @@ internal fun goldenVertexStreamingFile(
177197
internal fun goldenUnaryFile(
178198
name: String,
179199
httpStatusCode: HttpStatusCode = HttpStatusCode.OK,
200+
backend: GenerativeBackend = GenerativeBackend.VERTEX_AI,
180201
block: CommonTest,
181202
) =
182-
commonTest(httpStatusCode) {
203+
commonTest(httpStatusCode, backend = backend) {
183204
val goldenFile = loadGoldenFile(name)
184205
val message = goldenFile.readText()
185206

@@ -201,7 +222,22 @@ internal fun goldenVertexUnaryFile(
201222
name: String,
202223
httpStatusCode: HttpStatusCode = HttpStatusCode.OK,
203224
block: CommonTest,
204-
) = goldenUnaryFile("vertexai/$name", httpStatusCode, block)
225+
) = goldenUnaryFile("vertexai/$name", httpStatusCode, block = block)
226+
227+
/**
228+
* A variant of [goldenUnaryFile] for developer api tests Loads the *Golden File* and automatically
229+
* provides it to the channel.
230+
*
231+
* @param name The name of the *Golden File* to load
232+
* @param httpStatusCode An optional [HttpStatusCode] to return as a response
233+
* @param block The test contents themselves, with a [CommonTestScope] implicitly provided
234+
* @see goldenUnaryFile
235+
*/
236+
internal fun goldenDevAPIUnaryFile(
237+
name: String,
238+
httpStatusCode: HttpStatusCode = HttpStatusCode.OK,
239+
block: CommonTest,
240+
) = goldenUnaryFile("developerapi/$name", httpStatusCode, GenerativeBackend.DEVELOPER_API, block)
205241

206242
/**
207243
* Loads a *Golden File* from the resource directory.

0 commit comments

Comments
 (0)