Skip to content

Commit 1768919

Browse files
committed
Add support for streaming tests
1 parent e554d0c commit 1768919

File tree

3 files changed

+36
-9
lines changed

3 files changed

+36
-9
lines changed

firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Content.kt

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package com.google.firebase.ai.type
1818

1919
import android.graphics.Bitmap
20+
import kotlin.collections.filterNot
2021
import kotlinx.serialization.EncodeDefault
2122
import kotlinx.serialization.ExperimentalSerializationApi
2223
import kotlinx.serialization.Serializable
@@ -90,14 +91,21 @@ constructor(public val role: String? = "user", public val parts: List<Part>) {
9091
@Serializable
9192
internal data class Internal(
9293
@EncodeDefault val role: String? = "user",
93-
val parts: List<InternalPart>
94+
val parts: List<InternalPart>? = null
9495
) {
9596
internal fun toPublic(): Content {
97+
// Return empty if none of the parts is a known part
98+
if (parts == null || parts.filterNot { it is UnknownPart.Internal }.isEmpty()) {
99+
return Content(role, emptyList())
100+
}
101+
// From all the known parts, if they are all text and empty, we coalesce them into a single
102+
// one-character string part so the backend doesn't fail if we send this back as part of a
103+
// multi-turn interaction.
96104
val returnedParts =
97-
parts.filterNot { it is UnknownPart.Internal }.map { it.toPublic() }.filterNot { it is TextPart && it.text.isEmpty() }
98-
// If all returned parts were text and empty, we coalesce them into a single one-character
99-
// string
100-
// part so the backend doesn't fail if we send this back as part of a multi-turn interaction.
105+
parts
106+
.filterNot { it is UnknownPart.Internal }
107+
.map { it.toPublic() }
108+
.filterNot { it is TextPart && it.text.isEmpty() }
101109
return Content(role, returnedParts.ifEmpty { listOf(TextPart(" ")) })
102110
}
103111
}

firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Part.kt

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -271,10 +271,8 @@ internal constructor(
271271
}
272272
}
273273

274-
275-
internal data class UnknownPart(public override val isThought: Boolean = false): Part {
276-
@Serializable
277-
internal data class Internal(val thought: Boolean? = null): InternalPart
274+
internal data class UnknownPart(public override val isThought: Boolean = false) : Part {
275+
@Serializable internal data class Internal(val thought: Boolean? = null) : InternalPart
278276
}
279277

280278
/** Returns the part as a [String] if it represents text, and null otherwise */

firebase-ai/src/test/java/com/google/firebase/ai/DevAPIStreamingSnapshotTests.kt

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,18 @@ import com.google.firebase.ai.type.ResponseStoppedException
2323
import com.google.firebase.ai.type.ServerException
2424
import com.google.firebase.ai.util.goldenDevAPIStreamingFile
2525
import io.kotest.assertions.throwables.shouldThrow
26+
import io.kotest.matchers.collections.shouldBeEmpty
2627
import io.kotest.matchers.shouldBe
2728
import io.ktor.http.HttpStatusCode
2829
import kotlin.time.Duration.Companion.seconds
2930
import kotlinx.coroutines.flow.collect
3031
import kotlinx.coroutines.flow.toList
3132
import kotlinx.coroutines.withTimeout
3233
import org.junit.Test
34+
import org.junit.runner.RunWith
35+
import org.robolectric.RobolectricTestRunner
3336

37+
@RunWith(RobolectricTestRunner::class)
3438
internal class DevAPIStreamingSnapshotTests {
3539
private val testTimeout = 5.seconds
3640

@@ -64,6 +68,23 @@ internal class DevAPIStreamingSnapshotTests {
6468
}
6569
}
6670

71+
@Test
72+
fun `reply with a mostly empty part`() =
73+
goldenDevAPIStreamingFile("streaming-success-empty-parts.txt") {
74+
val responses = model.generateContentStream("prompt")
75+
76+
withTimeout(testTimeout) {
77+
val responseList = responses.toList()
78+
responseList.isEmpty() shouldBe false
79+
// Second to last response has no parts
80+
responseList[5].candidates.first().content.parts.shouldBeEmpty()
81+
responseList.last().candidates.first().apply {
82+
finishReason shouldBe FinishReason.STOP
83+
content.parts.isEmpty() shouldBe false
84+
}
85+
}
86+
}
87+
6788
@Test
6889
fun `prompt blocked for safety`() =
6990
goldenDevAPIStreamingFile("streaming-failure-prompt-blocked-safety.txt") {

0 commit comments

Comments
 (0)