diff --git a/firebase-ai/CHANGELOG.md b/firebase-ai/CHANGELOG.md index d63e4781242..4aeef792c44 100644 --- a/firebase-ai/CHANGELOG.md +++ b/firebase-ai/CHANGELOG.md @@ -5,6 +5,8 @@ - [feature] Added helper functions to `LiveSession` to allow developers to track the status of the audio session and the underlying websocket connection. - [changed] Added new values to `HarmCategory` (#7324) +- [fixed] Fixed an issue that caused unknown or empty `Part`s to throw an exception. Instead, we now + log them and filter them from the response (#7333) # 17.2.0 diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Content.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Content.kt index 4e9f1a860db..350d46e9063 100644 --- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Content.kt +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Content.kt @@ -17,6 +17,7 @@ package com.google.firebase.ai.type import android.graphics.Bitmap +import kotlin.collections.filterNot import kotlinx.serialization.EncodeDefault import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.Serializable @@ -90,14 +91,21 @@ constructor(public val role: String? = "user", public val parts: List) { @Serializable internal data class Internal( @EncodeDefault val role: String? = "user", - val parts: List + val parts: List? = null ) { internal fun toPublic(): Content { + // Return empty if none of the parts is a known part + if (parts == null || parts.filterNot { it is UnknownPart.Internal }.isEmpty()) { + return Content(role, emptyList()) + } + // From all the known parts, if they are all text and empty, we coalesce them into a single + // one-character string part so the backend doesn't fail if we send this back as part of a + // multi-turn interaction. val returnedParts = - parts.map { it.toPublic() }.filterNot { it is TextPart && it.text.isEmpty() } - // If all returned parts were text and empty, we coalesce them into a single one-character - // string - // part so the backend doesn't fail if we send this back as part of a multi-turn interaction. + parts + .filterNot { it is UnknownPart.Internal } + .map { it.toPublic() } + .filterNot { it is TextPart && it.text.isEmpty() } return Content(role, returnedParts.ifEmpty { listOf(TextPart(" ")) }) } } diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Part.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Part.kt index a8fdfa91fed..881bb90ce68 100644 --- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Part.kt +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Part.kt @@ -18,6 +18,7 @@ package com.google.firebase.ai.type import android.graphics.Bitmap import android.graphics.BitmapFactory +import android.util.Log import java.io.ByteArrayOutputStream import kotlinx.serialization.DeserializationStrategy import kotlinx.serialization.SerialName @@ -270,6 +271,10 @@ internal constructor( } } +internal data class UnknownPart(public override val isThought: Boolean = false) : Part { + @Serializable internal data class Internal(val thought: Boolean? = null) : InternalPart +} + /** Returns the part as a [String] if it represents text, and null otherwise */ public fun Part.asTextOrNull(): String? = (this as? TextPart)?.text @@ -290,6 +295,9 @@ internal const val BASE_64_FLAGS = android.util.Base64.NO_WRAP internal object PartSerializer : JsonContentPolymorphicSerializer(InternalPart::class) { + + private val TAG = PartSerializer::javaClass.name + override fun selectDeserializer(element: JsonElement): DeserializationStrategy { val jsonObject = element.jsonObject return when { @@ -300,7 +308,10 @@ internal object PartSerializer : "functionResponse" in jsonObject -> FunctionResponsePart.Internal.serializer() "inlineData" in jsonObject -> InlineDataPart.Internal.serializer() "fileData" in jsonObject -> FileDataPart.Internal.serializer() - else -> throw SerializationException("Unknown Part type") + else -> { + Log.w(TAG, "Unknown part type received, ignoring.") + UnknownPart.Internal.serializer() + } } } } @@ -410,6 +421,7 @@ internal fun InternalPart.toPublic(): Part { thought ?: false, thoughtSignature ) + is UnknownPart.Internal -> UnknownPart() else -> throw com.google.firebase.ai.type.SerializationException( "Unsupported part type \"${javaClass.simpleName}\" provided. This model may not be supported by this SDK." diff --git a/firebase-ai/src/test/java/com/google/firebase/ai/DevAPIStreamingSnapshotTests.kt b/firebase-ai/src/test/java/com/google/firebase/ai/DevAPIStreamingSnapshotTests.kt index 967254a096c..2aac8f7a0d2 100644 --- a/firebase-ai/src/test/java/com/google/firebase/ai/DevAPIStreamingSnapshotTests.kt +++ b/firebase-ai/src/test/java/com/google/firebase/ai/DevAPIStreamingSnapshotTests.kt @@ -23,6 +23,7 @@ import com.google.firebase.ai.type.ResponseStoppedException import com.google.firebase.ai.type.ServerException import com.google.firebase.ai.util.goldenDevAPIStreamingFile import io.kotest.assertions.throwables.shouldThrow +import io.kotest.matchers.collections.shouldBeEmpty import io.kotest.matchers.shouldBe import io.ktor.http.HttpStatusCode import kotlin.time.Duration.Companion.seconds @@ -30,7 +31,10 @@ import kotlinx.coroutines.flow.collect import kotlinx.coroutines.flow.toList import kotlinx.coroutines.withTimeout import org.junit.Test +import org.junit.runner.RunWith +import org.robolectric.RobolectricTestRunner +@RunWith(RobolectricTestRunner::class) internal class DevAPIStreamingSnapshotTests { private val testTimeout = 5.seconds @@ -64,6 +68,23 @@ internal class DevAPIStreamingSnapshotTests { } } + @Test + fun `reply with a single empty part`() = + goldenDevAPIStreamingFile("streaming-success-empty-parts.txt") { + val responses = model.generateContentStream("prompt") + + withTimeout(testTimeout) { + val responseList = responses.toList() + responseList.isEmpty() shouldBe false + // Second to last response has no parts + responseList[5].candidates.first().content.parts.shouldBeEmpty() + responseList.last().candidates.first().apply { + finishReason shouldBe FinishReason.STOP + content.parts.isEmpty() shouldBe false + } + } + } + @Test fun `prompt blocked for safety`() = goldenDevAPIStreamingFile("streaming-failure-prompt-blocked-safety.txt") { diff --git a/firebase-ai/src/test/java/com/google/firebase/ai/VertexAIStreamingSnapshotTests.kt b/firebase-ai/src/test/java/com/google/firebase/ai/VertexAIStreamingSnapshotTests.kt index e6331401fde..d54cda6c7aa 100644 --- a/firebase-ai/src/test/java/com/google/firebase/ai/VertexAIStreamingSnapshotTests.kt +++ b/firebase-ai/src/test/java/com/google/firebase/ai/VertexAIStreamingSnapshotTests.kt @@ -27,6 +27,8 @@ import com.google.firebase.ai.type.ServerException import com.google.firebase.ai.type.TextPart import com.google.firebase.ai.util.goldenVertexStreamingFile import io.kotest.assertions.throwables.shouldThrow +import io.kotest.matchers.collections.shouldBeEmpty +import io.kotest.matchers.collections.shouldHaveSize import io.kotest.matchers.nulls.shouldNotBeNull import io.kotest.matchers.shouldBe import io.kotest.matchers.string.shouldContain @@ -155,7 +157,13 @@ internal class VertexAIStreamingSnapshotTests { goldenVertexStreamingFile("streaming-failure-empty-content.txt") { val responses = model.generateContentStream("prompt") - withTimeout(testTimeout) { shouldThrow { responses.collect() } } + withTimeout(testTimeout) { + withTimeout(testTimeout) { + val responseList = responses.toList() + responseList.shouldHaveSize(1) + responseList.first().candidates.first().content.parts.shouldBeEmpty() + } + } } @Test @@ -241,6 +249,10 @@ internal class VertexAIStreamingSnapshotTests { goldenVertexStreamingFile("streaming-failure-malformed-content.txt") { val responses = model.generateContentStream("prompt") - withTimeout(testTimeout) { shouldThrow { responses.collect() } } + withTimeout(testTimeout) { + val responseList = responses.toList() + responseList.shouldHaveSize(1) + responseList.first().candidates.first().content.parts.shouldBeEmpty() + } } } diff --git a/firebase-ai/src/test/java/com/google/firebase/ai/VertexAIUnarySnapshotTests.kt b/firebase-ai/src/test/java/com/google/firebase/ai/VertexAIUnarySnapshotTests.kt index e34a1e1db0d..ddcd5a7d6ea 100644 --- a/firebase-ai/src/test/java/com/google/firebase/ai/VertexAIUnarySnapshotTests.kt +++ b/firebase-ai/src/test/java/com/google/firebase/ai/VertexAIUnarySnapshotTests.kt @@ -38,6 +38,7 @@ import com.google.firebase.ai.util.goldenVertexUnaryFile import com.google.firebase.ai.util.shouldNotBeNullOrEmpty import io.kotest.assertions.throwables.shouldThrow import io.kotest.inspectors.forAtLeastOne +import io.kotest.matchers.collections.shouldBeEmpty import io.kotest.matchers.collections.shouldNotBeEmpty import io.kotest.matchers.nulls.shouldBeNull import io.kotest.matchers.nulls.shouldNotBeNull @@ -90,6 +91,19 @@ internal class VertexAIUnarySnapshotTests { } } + @Test + fun `response including an empty part is handled gracefully`() = + goldenVertexUnaryFile("unary-success-empty-part.json") { + withTimeout(testTimeout) { + val response = model.generateContent("prompt") + + response.candidates.isEmpty() shouldBe false + response.text.shouldNotBeEmpty() + response.candidates.first().finishReason shouldBe FinishReason.STOP + response.candidates.first().content.parts.isEmpty() shouldBe false + } + } + @Test fun `response with detailed token-based usageMetadata`() = goldenVertexUnaryFile("unary-success-basic-response-long-usage-metadata.json") { @@ -246,7 +260,9 @@ internal class VertexAIUnarySnapshotTests { fun `empty content`() = goldenVertexUnaryFile("unary-failure-empty-content.json") { withTimeout(testTimeout) { - shouldThrow { model.generateContent("prompt") } + val response = model.generateContent("prompt") + response.candidates.shouldNotBeEmpty() + response.candidates.first().content.parts.shouldBeEmpty() } } @@ -389,10 +405,12 @@ internal class VertexAIUnarySnapshotTests { } @Test - fun `malformed content`() = + fun `response including an unknown part is handled gracefully`() = goldenVertexUnaryFile("unary-failure-malformed-content.json") { withTimeout(testTimeout) { - shouldThrow { model.generateContent("prompt") } + val response = model.generateContent("prompt") + response.candidates.shouldNotBeEmpty() + response.candidates.first().content.parts.shouldBeEmpty() } }