Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,24 @@ import com.google.firebase.vertexai.type.ResponseStoppedException
import com.google.firebase.vertexai.type.SerializationException
import com.google.firebase.vertexai.type.ServerException
import com.google.firebase.vertexai.type.TextPart
import com.google.firebase.vertexai.util.goldenStreamingFile
import com.google.firebase.vertexai.util.goldenVertexStreamingFile
import io.kotest.assertions.throwables.shouldThrow
import io.kotest.matchers.nulls.shouldNotBeNull
import io.kotest.matchers.shouldBe
import io.kotest.matchers.string.shouldContain
import io.ktor.http.HttpStatusCode
import kotlin.time.Duration.Companion.seconds
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.withTimeout
import org.junit.Test

internal class StreamingSnapshotTests {
internal class VertexAIStreamingSnapshotTests {
private val testTimeout = 5.seconds

@Test
fun `short reply`() =
goldenStreamingFile("streaming-success-basic-reply-short.txt") {
goldenVertexStreamingFile("streaming-success-basic-reply-short.txt") {
val responses = model.generateContentStream("prompt")

withTimeout(testTimeout) {
Expand All @@ -57,7 +56,7 @@ internal class StreamingSnapshotTests {

@Test
fun `long reply`() =
goldenStreamingFile("streaming-success-basic-reply-long.txt") {
goldenVertexStreamingFile("streaming-success-basic-reply-long.txt") {
val responses = model.generateContentStream("prompt")

withTimeout(testTimeout) {
Expand All @@ -73,7 +72,7 @@ internal class StreamingSnapshotTests {

@Test
fun `unknown enum in safety ratings`() =
goldenStreamingFile("streaming-success-unknown-safety-enum.txt") {
goldenVertexStreamingFile("streaming-success-unknown-safety-enum.txt") {
val responses = model.generateContentStream("prompt")

withTimeout(testTimeout) {
Expand All @@ -88,7 +87,7 @@ internal class StreamingSnapshotTests {

@Test
fun `unknown enum in finish reason`() =
goldenStreamingFile("streaming-failure-unknown-finish-enum.txt") {
goldenVertexStreamingFile("streaming-failure-unknown-finish-enum.txt") {
val responses = model.generateContentStream("prompt")

withTimeout(testTimeout) {
Expand All @@ -99,7 +98,7 @@ internal class StreamingSnapshotTests {

@Test
fun `quotes escaped`() =
goldenStreamingFile("streaming-success-quotes-escaped.txt") {
goldenVertexStreamingFile("streaming-success-quotes-escaped.txt") {
val responses = model.generateContentStream("prompt")

withTimeout(testTimeout) {
Expand All @@ -114,7 +113,7 @@ internal class StreamingSnapshotTests {

@Test
fun `prompt blocked for safety`() =
goldenStreamingFile("streaming-failure-prompt-blocked-safety.txt") {
goldenVertexStreamingFile("streaming-failure-prompt-blocked-safety.txt") {
val responses = model.generateContentStream("prompt")

withTimeout(testTimeout) {
Expand All @@ -125,7 +124,7 @@ internal class StreamingSnapshotTests {

@Test
fun `prompt blocked for safety with message`() =
goldenStreamingFile("streaming-failure-prompt-blocked-safety-with-message.txt") {
goldenVertexStreamingFile("streaming-failure-prompt-blocked-safety-with-message.txt") {
val responses = model.generateContentStream("prompt")

withTimeout(testTimeout) {
Expand All @@ -137,23 +136,26 @@ internal class StreamingSnapshotTests {

@Test
fun `empty content`() =
goldenStreamingFile("streaming-failure-empty-content.txt") {
goldenVertexStreamingFile("streaming-failure-empty-content.txt") {
val responses = model.generateContentStream("prompt")

withTimeout(testTimeout) { shouldThrow<SerializationException> { responses.collect() } }
}

@Test
fun `http errors`() =
goldenStreamingFile("streaming-failure-http-error.txt", HttpStatusCode.PreconditionFailed) {
goldenVertexStreamingFile(
"streaming-failure-http-error.txt",
HttpStatusCode.PreconditionFailed
) {
val responses = model.generateContentStream("prompt")

withTimeout(testTimeout) { shouldThrow<ServerException> { responses.collect() } }
}

@Test
fun `stopped for safety`() =
goldenStreamingFile("streaming-failure-finish-reason-safety.txt") {
goldenVertexStreamingFile("streaming-failure-finish-reason-safety.txt") {
val responses = model.generateContentStream("prompt")

withTimeout(testTimeout) {
Expand All @@ -164,7 +166,7 @@ internal class StreamingSnapshotTests {

@Test
fun `citation parsed correctly`() =
goldenStreamingFile("streaming-success-citations.txt") {
goldenVertexStreamingFile("streaming-success-citations.txt") {
val responses = model.generateContentStream("prompt")

withTimeout(testTimeout) {
Expand All @@ -177,7 +179,7 @@ internal class StreamingSnapshotTests {

@Test
fun `stopped for recitation`() =
goldenStreamingFile("streaming-failure-recitation-no-content.txt") {
goldenVertexStreamingFile("streaming-failure-recitation-no-content.txt") {
val responses = model.generateContentStream("prompt")

withTimeout(testTimeout) {
Expand All @@ -188,39 +190,39 @@ internal class StreamingSnapshotTests {

@Test
fun `image rejected`() =
goldenStreamingFile("streaming-failure-image-rejected.txt", HttpStatusCode.BadRequest) {
goldenVertexStreamingFile("streaming-failure-image-rejected.txt", HttpStatusCode.BadRequest) {
val responses = model.generateContentStream("prompt")

withTimeout(testTimeout) { shouldThrow<ServerException> { responses.collect() } }
}

@Test
fun `unknown model`() =
goldenStreamingFile("streaming-failure-unknown-model.txt", HttpStatusCode.NotFound) {
goldenVertexStreamingFile("streaming-failure-unknown-model.txt", HttpStatusCode.NotFound) {
val responses = model.generateContentStream("prompt")

withTimeout(testTimeout) { shouldThrow<ServerException> { responses.collect() } }
}

@Test
fun `invalid api key`() =
goldenStreamingFile("streaming-failure-api-key.txt", HttpStatusCode.BadRequest) {
goldenVertexStreamingFile("streaming-failure-api-key.txt", HttpStatusCode.BadRequest) {
val responses = model.generateContentStream("prompt")

withTimeout(testTimeout) { shouldThrow<InvalidAPIKeyException> { responses.collect() } }
}

@Test
fun `invalid json`() =
goldenStreamingFile("streaming-failure-invalid-json.txt") {
goldenVertexStreamingFile("streaming-failure-invalid-json.txt") {
val responses = model.generateContentStream("prompt")

withTimeout(testTimeout) { shouldThrow<SerializationException> { responses.collect() } }
}

@Test
fun `malformed content`() =
goldenStreamingFile("streaming-failure-malformed-content.txt") {
goldenVertexStreamingFile("streaming-failure-malformed-content.txt") {
val responses = model.generateContentStream("prompt")

withTimeout(testTimeout) { shouldThrow<SerializationException> { responses.collect() } }
Expand Down
Loading
Loading