Skip to content

Commit 3a35cca

Browse files
committed
Merge remote-tracking branch 'origin/main' into MissingVersionsGradle
2 parents abdf968 + 5822579 commit 3a35cca

File tree

10 files changed

+115
-13
lines changed

10 files changed

+115
-13
lines changed

firebase-ai/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
- [feature] Added helper functions to `LiveSession` to allow developers to track the status of the
66
audio session and the underlying websocket connection.
77
- [changed] Added new values to `HarmCategory` (#7324)
8+
- [fixed] Fixed an issue that caused unknown or empty `Part`s to throw an exception. Instead, we now
9+
log them and filter them from the response (#7333)
810

911
# 17.2.0
1012

firebase-ai/src/main/kotlin/com/google/firebase/ai/ImagenModel.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ internal constructor(
177177
GenerateImageRequest.ImagenParameters(
178178
sampleCount = generationConfig?.numberOfImages ?: 1,
179179
includeRaiReason = true,
180+
includeSafetyAttributes = true,
180181
addWatermark = generationConfig?.addWatermark,
181182
personGeneration = safetySettings?.personFilterLevel?.internalVal,
182183
negativePrompt = generationConfig?.negativePrompt,
@@ -206,6 +207,7 @@ internal constructor(
206207
GenerateImageRequest.ImagenParameters(
207208
sampleCount = generationConfig?.numberOfImages ?: 1,
208209
includeRaiReason = true,
210+
includeSafetyAttributes = true,
209211
addWatermark = generationConfig?.addWatermark,
210212
personGeneration = safetySettings?.personFilterLevel?.internalVal,
211213
negativePrompt = generationConfig?.negativePrompt,

firebase-ai/src/main/kotlin/com/google/firebase/ai/common/Request.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ internal data class GenerateImageRequest(
9393
internal data class ImagenParameters(
9494
val sampleCount: Int,
9595
val includeRaiReason: Boolean,
96+
val includeSafetyAttributes: Boolean,
9697
val storageUri: String?,
9798
val negativePrompt: String?,
9899
val aspectRatio: String?,

firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Content.kt

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package com.google.firebase.ai.type
1818

1919
import android.graphics.Bitmap
20+
import kotlin.collections.filterNot
2021
import kotlinx.serialization.EncodeDefault
2122
import kotlinx.serialization.ExperimentalSerializationApi
2223
import kotlinx.serialization.Serializable
@@ -90,14 +91,21 @@ constructor(public val role: String? = "user", public val parts: List<Part>) {
9091
@Serializable
9192
internal data class Internal(
9293
@EncodeDefault val role: String? = "user",
93-
val parts: List<InternalPart>
94+
val parts: List<InternalPart>? = null
9495
) {
9596
internal fun toPublic(): Content {
97+
// Return empty if none of the parts is a known part
98+
if (parts == null || parts.filterNot { it is UnknownPart.Internal }.isEmpty()) {
99+
return Content(role, emptyList())
100+
}
101+
// From all the known parts, if they are all text and empty, we coalesce them into a single
102+
// one-character string part so the backend doesn't fail if we send this back as part of a
103+
// multi-turn interaction.
96104
val returnedParts =
97-
parts.map { it.toPublic() }.filterNot { it is TextPart && it.text.isEmpty() }
98-
// If all returned parts were text and empty, we coalesce them into a single one-character
99-
// string
100-
// part so the backend doesn't fail if we send this back as part of a multi-turn interaction.
105+
parts
106+
.filterNot { it is UnknownPart.Internal }
107+
.map { it.toPublic() }
108+
.filterNot { it is TextPart && it.text.isEmpty() }
101109
return Content(role, returnedParts.ifEmpty { listOf(TextPart(" ")) })
102110
}
103111
}

firebase-ai/src/main/kotlin/com/google/firebase/ai/type/ImagenGenerationResponse.kt

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ internal constructor(public val images: List<T>, public val filteredReason: Stri
4242
internal fun toPublicInline() =
4343
ImagenGenerationResponse(
4444
images = predictions.filter { it.mimeType != null }.map { it.toPublicInline() },
45-
null,
45+
predictions.firstNotNullOfOrNull { it.raiFilteredReason },
4646
)
4747
}
4848

@@ -52,10 +52,24 @@ internal constructor(public val images: List<T>, public val filteredReason: Stri
5252
val gcsUri: String? = null,
5353
val mimeType: String? = null,
5454
val raiFilteredReason: String? = null,
55+
val safetyAttributes: ImagenSafetyAttributes? = null,
5556
) {
5657
internal fun toPublicInline() =
5758
ImagenInlineImage(Base64.decode(bytesBase64Encoded!!, Base64.NO_WRAP), mimeType!!)
5859

5960
internal fun toPublicGCS() = ImagenGCSImage(gcsUri!!, mimeType!!)
6061
}
62+
63+
@Serializable
64+
internal data class ImagenSafetyAttributes(
65+
val categories: List<String>? = null,
66+
val scores: List<Double>? = null
67+
) {
68+
internal fun toPublic(): Map<String, Double> {
69+
if (categories == null || scores == null) {
70+
return emptyMap()
71+
}
72+
return categories.zip(scores).toMap()
73+
}
74+
}
6175
}

firebase-ai/src/main/kotlin/com/google/firebase/ai/type/ImagenInlineImage.kt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,10 @@ import kotlinx.serialization.Serializable
3131
*/
3232
@PublicPreviewAPI
3333
public class ImagenInlineImage
34-
internal constructor(public val data: ByteArray, public val mimeType: String) {
34+
internal constructor(
35+
public val data: ByteArray,
36+
public val mimeType: String,
37+
) {
3538

3639
/**
3740
* Returns the image as an Android OS native [Bitmap] so that it can be saved or sent to the UI.

firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Part.kt

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package com.google.firebase.ai.type
1818

1919
import android.graphics.Bitmap
2020
import android.graphics.BitmapFactory
21+
import android.util.Log
2122
import java.io.ByteArrayOutputStream
2223
import kotlinx.serialization.DeserializationStrategy
2324
import kotlinx.serialization.SerialName
@@ -270,6 +271,10 @@ internal constructor(
270271
}
271272
}
272273

274+
internal data class UnknownPart(public override val isThought: Boolean = false) : Part {
275+
@Serializable internal data class Internal(val thought: Boolean? = null) : InternalPart
276+
}
277+
273278
/** Returns the part as a [String] if it represents text, and null otherwise */
274279
public fun Part.asTextOrNull(): String? = (this as? TextPart)?.text
275280

@@ -290,6 +295,9 @@ internal const val BASE_64_FLAGS = android.util.Base64.NO_WRAP
290295

291296
internal object PartSerializer :
292297
JsonContentPolymorphicSerializer<InternalPart>(InternalPart::class) {
298+
299+
private val TAG = PartSerializer::javaClass.name
300+
293301
override fun selectDeserializer(element: JsonElement): DeserializationStrategy<InternalPart> {
294302
val jsonObject = element.jsonObject
295303
return when {
@@ -300,7 +308,10 @@ internal object PartSerializer :
300308
"functionResponse" in jsonObject -> FunctionResponsePart.Internal.serializer()
301309
"inlineData" in jsonObject -> InlineDataPart.Internal.serializer()
302310
"fileData" in jsonObject -> FileDataPart.Internal.serializer()
303-
else -> throw SerializationException("Unknown Part type")
311+
else -> {
312+
Log.w(TAG, "Unknown part type received, ignoring.")
313+
UnknownPart.Internal.serializer()
314+
}
304315
}
305316
}
306317
}
@@ -410,6 +421,7 @@ internal fun InternalPart.toPublic(): Part {
410421
thought ?: false,
411422
thoughtSignature
412423
)
424+
is UnknownPart.Internal -> UnknownPart()
413425
else ->
414426
throw com.google.firebase.ai.type.SerializationException(
415427
"Unsupported part type \"${javaClass.simpleName}\" provided. This model may not be supported by this SDK."

firebase-ai/src/test/java/com/google/firebase/ai/DevAPIStreamingSnapshotTests.kt

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,18 @@ import com.google.firebase.ai.type.ResponseStoppedException
2323
import com.google.firebase.ai.type.ServerException
2424
import com.google.firebase.ai.util.goldenDevAPIStreamingFile
2525
import io.kotest.assertions.throwables.shouldThrow
26+
import io.kotest.matchers.collections.shouldBeEmpty
2627
import io.kotest.matchers.shouldBe
2728
import io.ktor.http.HttpStatusCode
2829
import kotlin.time.Duration.Companion.seconds
2930
import kotlinx.coroutines.flow.collect
3031
import kotlinx.coroutines.flow.toList
3132
import kotlinx.coroutines.withTimeout
3233
import org.junit.Test
34+
import org.junit.runner.RunWith
35+
import org.robolectric.RobolectricTestRunner
3336

37+
@RunWith(RobolectricTestRunner::class)
3438
internal class DevAPIStreamingSnapshotTests {
3539
private val testTimeout = 5.seconds
3640

@@ -64,6 +68,23 @@ internal class DevAPIStreamingSnapshotTests {
6468
}
6569
}
6670

71+
@Test
72+
fun `reply with a single empty part`() =
73+
goldenDevAPIStreamingFile("streaming-success-empty-parts.txt") {
74+
val responses = model.generateContentStream("prompt")
75+
76+
withTimeout(testTimeout) {
77+
val responseList = responses.toList()
78+
responseList.isEmpty() shouldBe false
79+
// Second to last response has no parts
80+
responseList[5].candidates.first().content.parts.shouldBeEmpty()
81+
responseList.last().candidates.first().apply {
82+
finishReason shouldBe FinishReason.STOP
83+
content.parts.isEmpty() shouldBe false
84+
}
85+
}
86+
}
87+
6788
@Test
6889
fun `prompt blocked for safety`() =
6990
goldenDevAPIStreamingFile("streaming-failure-prompt-blocked-safety.txt") {

firebase-ai/src/test/java/com/google/firebase/ai/VertexAIStreamingSnapshotTests.kt

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ import com.google.firebase.ai.type.ServerException
2727
import com.google.firebase.ai.type.TextPart
2828
import com.google.firebase.ai.util.goldenVertexStreamingFile
2929
import io.kotest.assertions.throwables.shouldThrow
30+
import io.kotest.matchers.collections.shouldBeEmpty
31+
import io.kotest.matchers.collections.shouldHaveSize
3032
import io.kotest.matchers.nulls.shouldNotBeNull
3133
import io.kotest.matchers.shouldBe
3234
import io.kotest.matchers.string.shouldContain
@@ -155,7 +157,13 @@ internal class VertexAIStreamingSnapshotTests {
155157
goldenVertexStreamingFile("streaming-failure-empty-content.txt") {
156158
val responses = model.generateContentStream("prompt")
157159

158-
withTimeout(testTimeout) { shouldThrow<SerializationException> { responses.collect() } }
160+
withTimeout(testTimeout) {
161+
withTimeout(testTimeout) {
162+
val responseList = responses.toList()
163+
responseList.shouldHaveSize(1)
164+
responseList.first().candidates.first().content.parts.shouldBeEmpty()
165+
}
166+
}
159167
}
160168

161169
@Test
@@ -241,6 +249,10 @@ internal class VertexAIStreamingSnapshotTests {
241249
goldenVertexStreamingFile("streaming-failure-malformed-content.txt") {
242250
val responses = model.generateContentStream("prompt")
243251

244-
withTimeout(testTimeout) { shouldThrow<SerializationException> { responses.collect() } }
252+
withTimeout(testTimeout) {
253+
val responseList = responses.toList()
254+
responseList.shouldHaveSize(1)
255+
responseList.first().candidates.first().content.parts.shouldBeEmpty()
256+
}
245257
}
246258
}

firebase-ai/src/test/java/com/google/firebase/ai/VertexAIUnarySnapshotTests.kt

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ import com.google.firebase.ai.util.goldenVertexUnaryFile
3838
import com.google.firebase.ai.util.shouldNotBeNullOrEmpty
3939
import io.kotest.assertions.throwables.shouldThrow
4040
import io.kotest.inspectors.forAtLeastOne
41+
import io.kotest.matchers.collections.shouldBeEmpty
4142
import io.kotest.matchers.collections.shouldNotBeEmpty
4243
import io.kotest.matchers.nulls.shouldBeNull
4344
import io.kotest.matchers.nulls.shouldNotBeNull
@@ -90,6 +91,19 @@ internal class VertexAIUnarySnapshotTests {
9091
}
9192
}
9293

94+
@Test
95+
fun `response including an empty part is handled gracefully`() =
96+
goldenVertexUnaryFile("unary-success-empty-part.json") {
97+
withTimeout(testTimeout) {
98+
val response = model.generateContent("prompt")
99+
100+
response.candidates.isEmpty() shouldBe false
101+
response.text.shouldNotBeEmpty()
102+
response.candidates.first().finishReason shouldBe FinishReason.STOP
103+
response.candidates.first().content.parts.isEmpty() shouldBe false
104+
}
105+
}
106+
93107
@Test
94108
fun `response with detailed token-based usageMetadata`() =
95109
goldenVertexUnaryFile("unary-success-basic-response-long-usage-metadata.json") {
@@ -246,7 +260,9 @@ internal class VertexAIUnarySnapshotTests {
246260
fun `empty content`() =
247261
goldenVertexUnaryFile("unary-failure-empty-content.json") {
248262
withTimeout(testTimeout) {
249-
shouldThrow<SerializationException> { model.generateContent("prompt") }
263+
val response = model.generateContent("prompt")
264+
response.candidates.shouldNotBeEmpty()
265+
response.candidates.first().content.parts.shouldBeEmpty()
250266
}
251267
}
252268

@@ -389,10 +405,12 @@ internal class VertexAIUnarySnapshotTests {
389405
}
390406

391407
@Test
392-
fun `malformed content`() =
408+
fun `response including an unknown part is handled gracefully`() =
393409
goldenVertexUnaryFile("unary-failure-malformed-content.json") {
394410
withTimeout(testTimeout) {
395-
shouldThrow<SerializationException> { model.generateContent("prompt") }
411+
val response = model.generateContent("prompt")
412+
response.candidates.shouldNotBeEmpty()
413+
response.candidates.first().content.parts.shouldBeEmpty()
396414
}
397415
}
398416

@@ -591,6 +609,15 @@ internal class VertexAIUnarySnapshotTests {
591609
}
592610
}
593611

612+
@Test
613+
fun `generateImages should contain safety data`() =
614+
goldenVertexUnaryFile("unary-success-generate-images-safety_info.json") {
615+
withTimeout(testTimeout) {
616+
val response = imagenModel.generateImages("prompt")
617+
// There is no public API, but if it parses then success
618+
}
619+
}
620+
594621
@Test
595622
fun `google search grounding metadata is parsed correctly`() =
596623
goldenVertexUnaryFile("unary-success-google-search-grounding.json") {

0 commit comments

Comments
 (0)