Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions firebase-ai/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
- [feature] Added helper functions to `LiveSession` to allow developers to track the status of the
audio session and the underlying websocket connection.
- [changed] Added new values to `HarmCategory` (#7324)
- [fixed] Fixed an issue that caused unknown or empty `Part`s to throw an exception. Instead, we now
log them and filter them from the response (#7333)

# 17.2.0

Expand Down
18 changes: 13 additions & 5 deletions firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Content.kt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.google.firebase.ai.type

import android.graphics.Bitmap
import kotlin.collections.filterNot
import kotlinx.serialization.EncodeDefault
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.Serializable
Expand Down Expand Up @@ -90,14 +91,21 @@ constructor(public val role: String? = "user", public val parts: List<Part>) {
@Serializable
internal data class Internal(
@EncodeDefault val role: String? = "user",
val parts: List<InternalPart>
val parts: List<InternalPart>? = null
) {
internal fun toPublic(): Content {
// Return empty if none of the parts is a known part
if (parts == null || parts.filterNot { it is UnknownPart.Internal }.isEmpty()) {
return Content(role, emptyList())
}
// From all the known parts, if they are all text and empty, we coalesce them into a single
// one-character string part so the backend doesn't fail if we send this back as part of a
// multi-turn interaction.
val returnedParts =
parts.map { it.toPublic() }.filterNot { it is TextPart && it.text.isEmpty() }
// If all returned parts were text and empty, we coalesce them into a single one-character
// string
// part so the backend doesn't fail if we send this back as part of a multi-turn interaction.
parts
.filterNot { it is UnknownPart.Internal }
.map { it.toPublic() }
.filterNot { it is TextPart && it.text.isEmpty() }
return Content(role, returnedParts.ifEmpty { listOf(TextPart(" ")) })
}
}
Expand Down
14 changes: 13 additions & 1 deletion firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Part.kt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package com.google.firebase.ai.type

import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.util.Log
import java.io.ByteArrayOutputStream
import kotlinx.serialization.DeserializationStrategy
import kotlinx.serialization.SerialName
Expand Down Expand Up @@ -270,6 +271,10 @@ internal constructor(
}
}

internal data class UnknownPart(public override val isThought: Boolean = false) : Part {
@Serializable internal data class Internal(val thought: Boolean? = null) : InternalPart
}

/** Returns the part as a [String] if it represents text, and null otherwise */
public fun Part.asTextOrNull(): String? = (this as? TextPart)?.text

Expand All @@ -290,6 +295,9 @@ internal const val BASE_64_FLAGS = android.util.Base64.NO_WRAP

internal object PartSerializer :
JsonContentPolymorphicSerializer<InternalPart>(InternalPart::class) {

private val TAG = PartSerializer::javaClass.name

override fun selectDeserializer(element: JsonElement): DeserializationStrategy<InternalPart> {
val jsonObject = element.jsonObject
return when {
Expand All @@ -300,7 +308,10 @@ internal object PartSerializer :
"functionResponse" in jsonObject -> FunctionResponsePart.Internal.serializer()
"inlineData" in jsonObject -> InlineDataPart.Internal.serializer()
"fileData" in jsonObject -> FileDataPart.Internal.serializer()
else -> throw SerializationException("Unknown Part type")
else -> {
Log.w(TAG, "Unknown part type received, ignoring.")
UnknownPart.Internal.serializer()
}
}
}
}
Expand Down Expand Up @@ -410,6 +421,7 @@ internal fun InternalPart.toPublic(): Part {
thought ?: false,
thoughtSignature
)
is UnknownPart.Internal -> UnknownPart()
else ->
throw com.google.firebase.ai.type.SerializationException(
"Unsupported part type \"${javaClass.simpleName}\" provided. This model may not be supported by this SDK."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,18 @@ import com.google.firebase.ai.type.ResponseStoppedException
import com.google.firebase.ai.type.ServerException
import com.google.firebase.ai.util.goldenDevAPIStreamingFile
import io.kotest.assertions.throwables.shouldThrow
import io.kotest.matchers.collections.shouldBeEmpty
import io.kotest.matchers.shouldBe
import io.ktor.http.HttpStatusCode
import kotlin.time.Duration.Companion.seconds
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.withTimeout
import org.junit.Test
import org.junit.runner.RunWith
import org.robolectric.RobolectricTestRunner

@RunWith(RobolectricTestRunner::class)
internal class DevAPIStreamingSnapshotTests {
private val testTimeout = 5.seconds

Expand Down Expand Up @@ -64,6 +68,23 @@ internal class DevAPIStreamingSnapshotTests {
}
}

@Test
fun `reply with a single empty part`() =
goldenDevAPIStreamingFile("streaming-success-empty-parts.txt") {
val responses = model.generateContentStream("prompt")

withTimeout(testTimeout) {
val responseList = responses.toList()
responseList.isEmpty() shouldBe false
// Second to last response has no parts
responseList[5].candidates.first().content.parts.shouldBeEmpty()
responseList.last().candidates.first().apply {
finishReason shouldBe FinishReason.STOP
content.parts.isEmpty() shouldBe false
}
}
}

@Test
fun `prompt blocked for safety`() =
goldenDevAPIStreamingFile("streaming-failure-prompt-blocked-safety.txt") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ import com.google.firebase.ai.type.ServerException
import com.google.firebase.ai.type.TextPart
import com.google.firebase.ai.util.goldenVertexStreamingFile
import io.kotest.assertions.throwables.shouldThrow
import io.kotest.matchers.collections.shouldBeEmpty
import io.kotest.matchers.collections.shouldHaveSize
import io.kotest.matchers.nulls.shouldNotBeNull
import io.kotest.matchers.shouldBe
import io.kotest.matchers.string.shouldContain
Expand Down Expand Up @@ -155,7 +157,13 @@ internal class VertexAIStreamingSnapshotTests {
goldenVertexStreamingFile("streaming-failure-empty-content.txt") {
val responses = model.generateContentStream("prompt")

withTimeout(testTimeout) { shouldThrow<SerializationException> { responses.collect() } }
withTimeout(testTimeout) {
withTimeout(testTimeout) {
val responseList = responses.toList()
responseList.shouldHaveSize(1)
responseList.first().candidates.first().content.parts.shouldBeEmpty()
}
}
}

@Test
Expand Down Expand Up @@ -241,6 +249,10 @@ internal class VertexAIStreamingSnapshotTests {
goldenVertexStreamingFile("streaming-failure-malformed-content.txt") {
val responses = model.generateContentStream("prompt")

withTimeout(testTimeout) { shouldThrow<SerializationException> { responses.collect() } }
withTimeout(testTimeout) {
val responseList = responses.toList()
responseList.shouldHaveSize(1)
responseList.first().candidates.first().content.parts.shouldBeEmpty()
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import com.google.firebase.ai.util.goldenVertexUnaryFile
import com.google.firebase.ai.util.shouldNotBeNullOrEmpty
import io.kotest.assertions.throwables.shouldThrow
import io.kotest.inspectors.forAtLeastOne
import io.kotest.matchers.collections.shouldBeEmpty
import io.kotest.matchers.collections.shouldNotBeEmpty
import io.kotest.matchers.nulls.shouldBeNull
import io.kotest.matchers.nulls.shouldNotBeNull
Expand Down Expand Up @@ -90,6 +91,19 @@ internal class VertexAIUnarySnapshotTests {
}
}

@Test
fun `response including an empty part is handled gracefully`() =
goldenVertexUnaryFile("unary-success-empty-part.json") {
withTimeout(testTimeout) {
val response = model.generateContent("prompt")

response.candidates.isEmpty() shouldBe false
response.text.shouldNotBeEmpty()
response.candidates.first().finishReason shouldBe FinishReason.STOP
response.candidates.first().content.parts.isEmpty() shouldBe false
}
}

@Test
fun `response with detailed token-based usageMetadata`() =
goldenVertexUnaryFile("unary-success-basic-response-long-usage-metadata.json") {
Expand Down Expand Up @@ -246,7 +260,9 @@ internal class VertexAIUnarySnapshotTests {
fun `empty content`() =
goldenVertexUnaryFile("unary-failure-empty-content.json") {
withTimeout(testTimeout) {
shouldThrow<SerializationException> { model.generateContent("prompt") }
val response = model.generateContent("prompt")
response.candidates.shouldNotBeEmpty()
response.candidates.first().content.parts.shouldBeEmpty()
}
}

Expand Down Expand Up @@ -389,10 +405,12 @@ internal class VertexAIUnarySnapshotTests {
}

@Test
fun `malformed content`() =
fun `response including an unknown part is handled gracefully`() =
goldenVertexUnaryFile("unary-failure-malformed-content.json") {
withTimeout(testTimeout) {
shouldThrow<SerializationException> { model.generateContent("prompt") }
val response = model.generateContent("prompt")
response.candidates.shouldNotBeEmpty()
response.candidates.first().content.parts.shouldBeEmpty()
}
}

Expand Down