Skip to content

Commit b986d2f

Browse files
authored
[AI] Ignore, and log, unknown parts. (#7333)
The previoius approach, throwing an exception, was based on the assumption that unknown parts would be important parts of the response. That would mean that new parts would force devs to update to SDK versions supporting them before using them. Actually, some times the parts returned by the model can be safely ignored. This change modifies the behaviour to Log, as a warning, and ignore, rather than fail, if an unknown part is found when parsing the response. Internal b/441783876
1 parent 5c7a993 commit b986d2f

File tree

6 files changed

+84
-11
lines changed

6 files changed

+84
-11
lines changed

firebase-ai/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
- [feature] Added helper functions to `LiveSession` to allow developers to track the status of the
66
audio session and the underlying websocket connection.
77
- [changed] Added new values to `HarmCategory` (#7324)
8+
- [fixed] Fixed an issue that caused unknown or empty `Part`s to throw an exception. Instead, we now
9+
log them and filter them from the response (#7333)
810

911
# 17.2.0
1012

firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Content.kt

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package com.google.firebase.ai.type
1818

1919
import android.graphics.Bitmap
20+
import kotlin.collections.filterNot
2021
import kotlinx.serialization.EncodeDefault
2122
import kotlinx.serialization.ExperimentalSerializationApi
2223
import kotlinx.serialization.Serializable
@@ -90,14 +91,21 @@ constructor(public val role: String? = "user", public val parts: List<Part>) {
9091
@Serializable
9192
internal data class Internal(
9293
@EncodeDefault val role: String? = "user",
93-
val parts: List<InternalPart>
94+
val parts: List<InternalPart>? = null
9495
) {
9596
internal fun toPublic(): Content {
97+
// Return empty if none of the parts is a known part
98+
if (parts == null || parts.filterNot { it is UnknownPart.Internal }.isEmpty()) {
99+
return Content(role, emptyList())
100+
}
101+
// From all the known parts, if they are all text and empty, we coalesce them into a single
102+
// one-character string part so the backend doesn't fail if we send this back as part of a
103+
// multi-turn interaction.
96104
val returnedParts =
97-
parts.map { it.toPublic() }.filterNot { it is TextPart && it.text.isEmpty() }
98-
// If all returned parts were text and empty, we coalesce them into a single one-character
99-
// string
100-
// part so the backend doesn't fail if we send this back as part of a multi-turn interaction.
105+
parts
106+
.filterNot { it is UnknownPart.Internal }
107+
.map { it.toPublic() }
108+
.filterNot { it is TextPart && it.text.isEmpty() }
101109
return Content(role, returnedParts.ifEmpty { listOf(TextPart(" ")) })
102110
}
103111
}

firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Part.kt

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package com.google.firebase.ai.type
1818

1919
import android.graphics.Bitmap
2020
import android.graphics.BitmapFactory
21+
import android.util.Log
2122
import java.io.ByteArrayOutputStream
2223
import kotlinx.serialization.DeserializationStrategy
2324
import kotlinx.serialization.SerialName
@@ -270,6 +271,10 @@ internal constructor(
270271
}
271272
}
272273

274+
internal data class UnknownPart(public override val isThought: Boolean = false) : Part {
275+
@Serializable internal data class Internal(val thought: Boolean? = null) : InternalPart
276+
}
277+
273278
/** Returns the part as a [String] if it represents text, and null otherwise */
274279
public fun Part.asTextOrNull(): String? = (this as? TextPart)?.text
275280

@@ -290,6 +295,9 @@ internal const val BASE_64_FLAGS = android.util.Base64.NO_WRAP
290295

291296
internal object PartSerializer :
292297
JsonContentPolymorphicSerializer<InternalPart>(InternalPart::class) {
298+
299+
private val TAG = PartSerializer::javaClass.name
300+
293301
override fun selectDeserializer(element: JsonElement): DeserializationStrategy<InternalPart> {
294302
val jsonObject = element.jsonObject
295303
return when {
@@ -300,7 +308,10 @@ internal object PartSerializer :
300308
"functionResponse" in jsonObject -> FunctionResponsePart.Internal.serializer()
301309
"inlineData" in jsonObject -> InlineDataPart.Internal.serializer()
302310
"fileData" in jsonObject -> FileDataPart.Internal.serializer()
303-
else -> throw SerializationException("Unknown Part type")
311+
else -> {
312+
Log.w(TAG, "Unknown part type received, ignoring.")
313+
UnknownPart.Internal.serializer()
314+
}
304315
}
305316
}
306317
}
@@ -410,6 +421,7 @@ internal fun InternalPart.toPublic(): Part {
410421
thought ?: false,
411422
thoughtSignature
412423
)
424+
is UnknownPart.Internal -> UnknownPart()
413425
else ->
414426
throw com.google.firebase.ai.type.SerializationException(
415427
"Unsupported part type \"${javaClass.simpleName}\" provided. This model may not be supported by this SDK."

firebase-ai/src/test/java/com/google/firebase/ai/DevAPIStreamingSnapshotTests.kt

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,18 @@ import com.google.firebase.ai.type.ResponseStoppedException
2323
import com.google.firebase.ai.type.ServerException
2424
import com.google.firebase.ai.util.goldenDevAPIStreamingFile
2525
import io.kotest.assertions.throwables.shouldThrow
26+
import io.kotest.matchers.collections.shouldBeEmpty
2627
import io.kotest.matchers.shouldBe
2728
import io.ktor.http.HttpStatusCode
2829
import kotlin.time.Duration.Companion.seconds
2930
import kotlinx.coroutines.flow.collect
3031
import kotlinx.coroutines.flow.toList
3132
import kotlinx.coroutines.withTimeout
3233
import org.junit.Test
34+
import org.junit.runner.RunWith
35+
import org.robolectric.RobolectricTestRunner
3336

37+
@RunWith(RobolectricTestRunner::class)
3438
internal class DevAPIStreamingSnapshotTests {
3539
private val testTimeout = 5.seconds
3640

@@ -64,6 +68,23 @@ internal class DevAPIStreamingSnapshotTests {
6468
}
6569
}
6670

71+
@Test
72+
fun `reply with a single empty part`() =
73+
goldenDevAPIStreamingFile("streaming-success-empty-parts.txt") {
74+
val responses = model.generateContentStream("prompt")
75+
76+
withTimeout(testTimeout) {
77+
val responseList = responses.toList()
78+
responseList.isEmpty() shouldBe false
79+
// Second to last response has no parts
80+
responseList[5].candidates.first().content.parts.shouldBeEmpty()
81+
responseList.last().candidates.first().apply {
82+
finishReason shouldBe FinishReason.STOP
83+
content.parts.isEmpty() shouldBe false
84+
}
85+
}
86+
}
87+
6788
@Test
6889
fun `prompt blocked for safety`() =
6990
goldenDevAPIStreamingFile("streaming-failure-prompt-blocked-safety.txt") {

firebase-ai/src/test/java/com/google/firebase/ai/VertexAIStreamingSnapshotTests.kt

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ import com.google.firebase.ai.type.ServerException
2727
import com.google.firebase.ai.type.TextPart
2828
import com.google.firebase.ai.util.goldenVertexStreamingFile
2929
import io.kotest.assertions.throwables.shouldThrow
30+
import io.kotest.matchers.collections.shouldBeEmpty
31+
import io.kotest.matchers.collections.shouldHaveSize
3032
import io.kotest.matchers.nulls.shouldNotBeNull
3133
import io.kotest.matchers.shouldBe
3234
import io.kotest.matchers.string.shouldContain
@@ -155,7 +157,13 @@ internal class VertexAIStreamingSnapshotTests {
155157
goldenVertexStreamingFile("streaming-failure-empty-content.txt") {
156158
val responses = model.generateContentStream("prompt")
157159

158-
withTimeout(testTimeout) { shouldThrow<SerializationException> { responses.collect() } }
160+
withTimeout(testTimeout) {
161+
withTimeout(testTimeout) {
162+
val responseList = responses.toList()
163+
responseList.shouldHaveSize(1)
164+
responseList.first().candidates.first().content.parts.shouldBeEmpty()
165+
}
166+
}
159167
}
160168

161169
@Test
@@ -241,6 +249,10 @@ internal class VertexAIStreamingSnapshotTests {
241249
goldenVertexStreamingFile("streaming-failure-malformed-content.txt") {
242250
val responses = model.generateContentStream("prompt")
243251

244-
withTimeout(testTimeout) { shouldThrow<SerializationException> { responses.collect() } }
252+
withTimeout(testTimeout) {
253+
val responseList = responses.toList()
254+
responseList.shouldHaveSize(1)
255+
responseList.first().candidates.first().content.parts.shouldBeEmpty()
256+
}
245257
}
246258
}

firebase-ai/src/test/java/com/google/firebase/ai/VertexAIUnarySnapshotTests.kt

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ import com.google.firebase.ai.util.goldenVertexUnaryFile
3838
import com.google.firebase.ai.util.shouldNotBeNullOrEmpty
3939
import io.kotest.assertions.throwables.shouldThrow
4040
import io.kotest.inspectors.forAtLeastOne
41+
import io.kotest.matchers.collections.shouldBeEmpty
4142
import io.kotest.matchers.collections.shouldNotBeEmpty
4243
import io.kotest.matchers.nulls.shouldBeNull
4344
import io.kotest.matchers.nulls.shouldNotBeNull
@@ -90,6 +91,19 @@ internal class VertexAIUnarySnapshotTests {
9091
}
9192
}
9293

94+
@Test
95+
fun `response including an empty part is handled gracefully`() =
96+
goldenVertexUnaryFile("unary-success-empty-part.json") {
97+
withTimeout(testTimeout) {
98+
val response = model.generateContent("prompt")
99+
100+
response.candidates.isEmpty() shouldBe false
101+
response.text.shouldNotBeEmpty()
102+
response.candidates.first().finishReason shouldBe FinishReason.STOP
103+
response.candidates.first().content.parts.isEmpty() shouldBe false
104+
}
105+
}
106+
93107
@Test
94108
fun `response with detailed token-based usageMetadata`() =
95109
goldenVertexUnaryFile("unary-success-basic-response-long-usage-metadata.json") {
@@ -246,7 +260,9 @@ internal class VertexAIUnarySnapshotTests {
246260
fun `empty content`() =
247261
goldenVertexUnaryFile("unary-failure-empty-content.json") {
248262
withTimeout(testTimeout) {
249-
shouldThrow<SerializationException> { model.generateContent("prompt") }
263+
val response = model.generateContent("prompt")
264+
response.candidates.shouldNotBeEmpty()
265+
response.candidates.first().content.parts.shouldBeEmpty()
250266
}
251267
}
252268

@@ -389,10 +405,12 @@ internal class VertexAIUnarySnapshotTests {
389405
}
390406

391407
@Test
392-
fun `malformed content`() =
408+
fun `response including an unknown part is handled gracefully`() =
393409
goldenVertexUnaryFile("unary-failure-malformed-content.json") {
394410
withTimeout(testTimeout) {
395-
shouldThrow<SerializationException> { model.generateContent("prompt") }
411+
val response = model.generateContent("prompt")
412+
response.candidates.shouldNotBeEmpty()
413+
response.candidates.first().content.parts.shouldBeEmpty()
396414
}
397415
}
398416

0 commit comments

Comments
 (0)